dense_stack_tds
- tensordict.dense_stack_tds(td_list: Union[Sequence[TensorDictBase], LazyStackedTensorDict], dim: Optional[int] = None)
Densely stack a list of TensorDictBase objects (or a LazyStackedTensorDict), provided that they all have the same structure.

This function takes a list of TensorDictBase objects, either passed directly or obtained from a LazyStackedTensorDict. Instead of calling torch.stack(td_list), which would return a LazyStackedTensorDict, it expands the first element of the input list and stacks the input list onto that element. This works only when all elements of the input list have the same structure. The returned TensorDictBase has the same type as the elements of the input list.

This function is useful when some of the TensorDictBase objects to be stacked are LazyStackedTensorDict instances, or contain LazyStackedTensorDict instances among their entries (or nested entries). In those cases, calling torch.stack(td_list).to_tensordict() is infeasible, because the nested lazy stacks may hold entries whose shapes differ and therefore cannot be stacked densely. This function provides an alternative for densely stacking such a list.
- Parameters:
td_list (List of TensorDictBase or LazyStackedTensorDict) – the TensorDicts to stack.
dim (int, optional) – the dimension along which to stack them. If td_list is a LazyStackedTensorDict, dim is taken automatically from its stack_dim.
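The expand-then-stack strategy and the dim rule described above can be sketched without tensordict. This is a minimal sketch, assuming each TensorDict is modeled as a plain dict of tensors with an empty batch size, so dim indexes into each leaf's shape and only non-negative values are handled. LazyStack and dense_stack_records are hypothetical names for illustration, not part of the tensordict API.

```python
import torch
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Union

# Toy stand-in for a TensorDict with batch_size [] (hypothetical simplification).
Record = Dict[str, torch.Tensor]


@dataclass
class LazyStack:
    """Hypothetical stand-in for a LazyStackedTensorDict: items are kept apart."""
    items: List[Record]
    stack_dim: int


def dense_stack_records(
    records: Union[Sequence[Record], LazyStack], dim: Optional[int] = None
) -> Record:
    """Hypothetical helper: densely stack same-structure records of tensors."""
    # A lazy stack supplies its own stack dimension, like the dim parameter above.
    if isinstance(records, LazyStack):
        dim = records.stack_dim
        records = records.items
    elif dim is None:
        raise ValueError("dim is required for a plain sequence in this sketch")
    if dim < 0:
        raise ValueError("this sketch only handles non-negative dims")

    first = records[0]
    # Every element must share the same structure (here: the same keys).
    if any(set(r) != set(first) for r in records[1:]):
        raise ValueError("all records must have the same structure")

    out: Record = {}
    for key, value in first.items():
        shape = list(value.shape)
        shape.insert(dim, len(records))
        # Expand the first element along the new dimension to allocate the output...
        buffer = value.unsqueeze(dim).expand(shape).clone()
        # ...then stack every element onto that pre-allocated buffer.
        torch.stack([r[key] for r in records], dim=dim, out=buffer)
        out[key] = buffer
    return out
```

For example, dense_stack_records(LazyStack([a, b], stack_dim=1)) stacks along dim 1 without an explicit dim argument. In tensordict itself the same idea operates on whole TensorDictBase objects, which is why the result keeps the type of the stacked elements; a dict of tensors cannot show that part.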
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> td0 = TensorDict({"a": torch.zeros(3)}, [])
>>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)}, [])
>>> td_lazy = torch.stack([td0, td1], dim=0)
>>> td_container = TensorDict({"lazy": td_lazy}, [])
>>> td_container_clone = td_container.clone()
>>> td_stack = torch.stack([td_container, td_container_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
    fields={
        lazy: LazyStackedTensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, 2]),
            device=None,
            is_shared=False,
            stack_dim=0)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> td_stack = dense_stack_tds(td_stack)  # Automatically use the LazyStackedTensorDict stack_dim
>>> td_stack
TensorDict(
    fields={
        lazy: LazyStackedTensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
            exclusive_fields={
                1 -> b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 2]),
            device=None,
            is_shared=False,
            stack_dim=1)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
# Note that
# (1) td_stack is now a TensorDict
# (2) this has pushed the stack_dim of "lazy" (0 -> 1)
# (3) this has revealed the exclusive keys.
>>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0))
# This shows it is the same to pass a list or a LazyStackedTensorDict
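The notes in the example can be checked against plain tensor shapes. The sketch below uses only torch.stack and the leaf shapes from the example ("a" is torch.zeros(3) in td0 and torch.zeros(4) in td1); try_stack is a hypothetical helper. It suggests why the nested "lazy" entry cannot be densified on its own, while stacking td_container with its clone still succeeds: each inner element is paired with a same-shaped copy of itself.

```python
import torch
from typing import List, Optional


def try_stack(tensors: List[torch.Tensor], dim: int = 0) -> Optional[torch.Tensor]:
    """Hypothetical helper: return torch.stack(...) or None if the shapes mismatch."""
    try:
        return torch.stack(tensors, dim=dim)
    except RuntimeError:
        return None


# Leaf "a" of td0 and td1 in the example above.
a_td0 = torch.zeros(3)
a_td1 = torch.zeros(4)

# Densifying the inner lazy stack would require stacking shapes (3,) and (4,).
inner_dense = try_stack([a_td0, a_td1])

# Densely stacking the container with its clone pairs each inner element with a
# copy of itself, so every pair has matching shapes.
pair_td0 = try_stack([a_td0, a_td0.clone()])  # shape (2, 3)
pair_td1 = try_stack([a_td1, a_td1.clone()])  # shape (2, 4)
```

The two pairs still differ in their last dimension, so they can only sit side by side in a lazy stack. Because the new (outer) stack dimension comes first in each pair, this is one way to read why the stack_dim of "lazy" moves from 0 to 1 in the output above.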