TensorDict
- class tensordict.TensorDict(source: Optional[Union[T, dict[tensordict._nestedkey.NestedKey, torch.Tensor]]] = None, batch_size: Optional[Union[Sequence[int], Size, int]] = None, device: Optional[Union[device, str, int]] = None, names: Optional[Sequence[str]] = None, non_blocking: Optional[bool] = None, lock: bool = False, **kwargs: dict[str, Any] | None)
A batched dictionary of tensors.
TensorDict is a tensor container in which all tensors are stored as key-value pairs and every element shares the same first N leading dimensions, where N is an arbitrary number with N >= 0. Additionally, if the tensordict has a specified device, every element must be stored on that device.
TensorDict instances support many regular tensor operations, with the notable exception of algebraic operations:
Operations on shape: when a shape operation is called (indexing, reshape, view, expand, transpose, permute, unsqueeze, squeeze, masking, etc.), the operation is performed as if it were executed on a tensor whose shape equals the batch size, and the result is then expanded to the right over the feature dimensions, e.g.:
>>> td = TensorDict({'a': torch.zeros(3, 4, 5)}, batch_size=[3, 4])
>>> # returns a TensorDict of batch size [3, 4, 1]:
>>> td_unsqueeze = td.unsqueeze(-1)
>>> # returns a TensorDict of batch size [12]
>>> td_view = td.view(-1)
>>> # returns a tensor of shape [12, 5]
>>> a_view = td.view(-1).get("a")
Casting operations: a TensorDict can be cast to a different device with to(), or converted to a regular dictionary with to_dict():
>>> td_cpu = td.to("cpu")
>>> dictionary = td.to_dict()
A call to the .to() method with a dtype will raise an error.
Cloning (clone()) and contiguous copies (contiguous());
Reading: td.get(key), td.get_at(key, index);
Content modification: td.set(key, value), td.set_(key, value), td.update(td_or_dict), td.update_(td_or_dict), td.fill_(key, value), td.rename_key_(old_name, new_name), etc.;
Operations on multiple tensordicts: torch.cat(tensordict_list, dim), torch.stack(tensordict_list, dim), td1 == td2, td.apply(lambda x, y: x + y, other_td), etc.
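The shape rule for batch operations can be illustrated with plain shape tuples. The helper below, leaf_shape_after_batch_reshape, is hypothetical and not part of tensordict: it is a minimal sketch assuming that every leaf shape starts with the batch size, that the new batch size replaces only those leading dimensions, and that the trailing feature dimensions are carried over unchanged. It reproduces the leaf shapes implied by view(-1) and unsqueeze(-1) in the example above.

```python
from math import prod


def leaf_shape_after_batch_reshape(leaf_shape, batch_size, new_batch_size):
    """Hypothetical helper: shape of a leaf once its batch dims are reshaped."""
    n = len(batch_size)
    if tuple(leaf_shape[:n]) != tuple(batch_size):
        raise ValueError("the leaf shape must start with the batch size")
    new_batch = list(new_batch_size)
    # Resolve a single -1 entry the same way torch.Tensor.view does.
    if new_batch.count(-1) == 1:
        known = prod(d for d in new_batch if d != -1)
        new_batch[new_batch.index(-1)] = prod(batch_size) // known
    if prod(new_batch) != prod(batch_size):
        raise ValueError("the new batch size has a different number of elements")
    # Feature dimensions (everything after the batch dims) are kept unchanged.
    return tuple(new_batch) + tuple(leaf_shape[n:])


# td.view(-1) on batch size [3, 4]: leaf "a" of shape [3, 4, 5] becomes [12, 5]
print(leaf_shape_after_batch_reshape((3, 4, 5), (3, 4), (-1,)))  # (12, 5)
# td.unsqueeze(-1) gives batch size [3, 4, 1]: leaf "a" becomes [3, 4, 1, 5]
print(leaf_shape_after_batch_reshape((3, 4, 5), (3, 4), (3, 4, 1)))  # (3, 4, 1, 5)
```

The point of the sketch is that only the batch dimensions are ever reinterpreted; the feature dimensions of each leaf never move.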
- Parameters:
source (TensorDict or Dict[NestedKey, Union[Tensor, TensorDictBase]]) – a data source. If empty, the tensordict can be populated subsequently. A TensorDict can also be built from a sequence of keyword arguments, as is the case for dict(...).
batch_size (iterable of int, optional) – a batch size for the tensordict. The batch size can be modified subsequently as long as it is compatible with its content. If no batch size is provided, an empty batch size is assumed (it is not inferred automatically from the data). To set the batch size automatically, refer to auto_batch_size_().
device (torch.device or compatible type, optional) – a device for the TensorDict. If provided, all tensors will be stored on that device. If not, tensors on different devices are allowed.
names (list of str, optional) – the names of the dimensions of the tensordict. If provided, its length must match that of the batch_size. Defaults to None (no dimension names, or None for every dimension).
non_blocking (bool, optional) – if True and a device is passed, the tensordict is delivered without synchronization. This is the fastest option but is only safe when casting from CPU to CUDA (otherwise a synchronization call must be implemented by the user). If False is passed, every tensor movement will be done synchronously. If None (default), the device casting will be done asynchronously but a synchronization will be executed after creation if required. This option should generally be faster than False and potentially slower than True.
lock (bool, optional) – if True, the resulting tensordict will be locked.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> source = {'random': torch.randn(3, 4),
...     'zeros': torch.zeros(3, 4, 5)}
>>> batch_size = [3]
>>> td = TensorDict(source, batch_size=batch_size)
>>> print(td.shape)  # equivalent to td.batch_size
torch.Size([3])
>>> td_unsqueeze = td.unsqueeze(-1)
>>> print(td_unsqueeze.get("zeros").shape)
torch.Size([3, 1, 4, 5])
>>> print(td_unsqueeze[0].shape)
torch.Size([1])
>>> print(td_unsqueeze.view(-1).shape)
torch.Size([3])
>>> print((td.clone()==td).all())
True
- abs() → T
Computes the absolute value of each element of the TensorDict.
- abs_() → T
Computes the absolute value of each element of the TensorDict in-place.
- acos() → T
Computes the acos() value of each element of the TensorDict.
- acos_() → T
Computes the acos() value of each element of the TensorDict in-place.
- add(other: tensordict.base.TensorDictBase | torch.Tensor, *, alpha: Optional[float] = None, default: Optional[Union[str, Tensor]] = None) → TensorDictBase
Adds other, scaled by alpha, to self.
\[\text{out}_i = \text{input}_i + \text{alpha} \times \text{other}_i\]
- Parameters:
other (TensorDictBase or torch.Tensor) – the tensor or TensorDict to add to self.
- Keyword Arguments:
alpha (Number, optional) – the multiplier for other.
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
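The default rule above can be modelled with flat Python dictionaries of numbers. The sketch below uses a hypothetical helper, binary_op, which is not a tensordict function; it assumes flat keys and scalar values and ignores nesting, devices and broadcasting, so it only shows which keys take part in the operation and where the default is substituted.

```python
import operator


def binary_op(lhs, rhs, op, default=None):
    """Hypothetical model of the `default` rule for flat dicts of numbers."""
    if default is None:
        # No default: the two key sets must match exactly.
        if set(lhs) != set(rhs):
            raise KeyError(f"mismatched keys: {sorted(set(lhs) ^ set(rhs))}")
        keys, fill = list(lhs), None
    elif isinstance(default, str) and default == "intersection":
        # Only keys present on both sides are used; the others are dropped.
        keys, fill = [k for k in lhs if k in rhs], None
    else:
        # Any other value stands in for the missing entry on either side.
        keys = list(lhs) + [k for k in rhs if k not in lhs]
        fill = default
    return {k: op(lhs.get(k, fill), rhs.get(k, fill)) for k in keys}


lhs = {"a": 1.0, "b": 2.0}
rhs = {"a": 10.0, "c": 30.0}
print(binary_op(lhs, rhs, operator.add, default="intersection"))  # {'a': 11.0}
print(binary_op(lhs, rhs, operator.add, default=0.0))  # {'a': 11.0, 'b': 2.0, 'c': 30.0}
```

Calling binary_op(lhs, rhs, operator.add) without a default raises KeyError here, mirroring the requirement that the key lists match exactly.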
- add_(other: tensordict.base.TensorDictBase | float, *, alpha: Optional[float] = None)
In-place version of add().
Note
In-place add does not support the default keyword argument.
- addcdiv(other1: tensordict.base.TensorDictBase | torch.Tensor, other2: tensordict.base.TensorDictBase | torch.Tensor, value: float | None = 1)
Performs the element-wise division of other1 by other2, multiplies the result by the scalar value and adds it to self.
\[\text{out}_i = \text{input}_i + \text{value} \times \frac{\text{other1}_i}{\text{other2}_i}\]
The shapes of the elements of self, other1, and other2 must be broadcastable.
For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.
- Parameters:
other1 (TensorDict or Tensor) – the numerator tensordict (or tensor).
other2 (TensorDict or Tensor) – the denominator tensordict (or tensor).
- Keyword Arguments:
value (Number, optional) – multiplier for \(\text{other1} / \text{other2}\).
- addcmul(other1, other2, *, value: float | None = 1)
Performs the element-wise multiplication of other1 by other2, multiplies the result by the scalar value and adds it to self.
\[\text{out}_i = \text{input}_i + \text{value} \times \text{other1}_i \times \text{other2}_i\]
The shapes of self, other1, and other2 must be broadcastable.
For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.
- Parameters:
other1 (TensorDict or Tensor) – the first tensordict or tensor to be multiplied.
other2 (TensorDict or Tensor) – the second tensordict or tensor to be multiplied.
- Keyword Arguments:
value (Number, optional) – multiplier for \(\text{other1} \times \text{other2}\).
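As a worked check of the addcdiv() and addcmul() formulas, the sketch below evaluates them entry by entry on flat dictionaries of equally sized lists of floats. It is a hypothetical illustration rather than the library code: it assumes matching keys and matching lengths and performs no broadcasting.

```python
def addcdiv(inp, other1, other2, value=1.0):
    """out_i = input_i + value * other1_i / other2_i, entry by entry."""
    return {
        k: [x + value * a / b for x, a, b in zip(inp[k], other1[k], other2[k])]
        for k in inp
    }


def addcmul(inp, other1, other2, value=1.0):
    """out_i = input_i + value * other1_i * other2_i, entry by entry."""
    return {
        k: [x + value * a * b for x, a, b in zip(inp[k], other1[k], other2[k])]
        for k in inp
    }


inp = {"x": [1.0, 2.0]}
other1 = {"x": [6.0, 3.0]}
other2 = {"x": [2.0, 4.0]}
print(addcdiv(inp, other1, other2, value=0.5))  # {'x': [2.5, 2.375]}
print(addcmul(inp, other1, other2, value=0.5))  # {'x': [7.0, 8.0]}
```

For the first element, addcdiv gives 1 + 0.5 × 6 / 2 = 2.5 and addcmul gives 1 + 0.5 × 6 × 2 = 7.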
- all(dim: Optional[int] = None) → bool | tensordict.base.TensorDictBase
Checks if all values are True/non-null in the tensordict.
- Parameters:
dim (int, optional) – if None, returns a boolean indicating whether all tensors return tensor.all() == True. If an integer, all is called along the specified dimension, if and only if this dimension is compatible with the tensordict shape.
- any(dim: Optional[int] = None) → bool | tensordict.base.TensorDictBase
Checks if any value is True/non-null in the tensordict.
- Parameters:
dim (int, optional) – if None, returns a boolean indicating whether any tensor returns tensor.any() == True. If an integer, any is called along the specified dimension, if and only if this dimension is compatible with the tensordict shape.
- apply(fn: Callable, *others: T, batch_size: Optional[Sequence[int]] = None, device: torch.device | None = _NoDefault.ZERO, names: Optional[Sequence[str]] = _NoDefault.ZERO, inplace: bool = False, default: Any = _NoDefault.ZERO, filter_empty: Optional[bool] = None, propagate_lock: bool = False, call_on_nested: bool = False, out: Optional[TensorDictBase] = None, **constructor_kwargs) → Optional[T]
Applies a callable to all values stored in the tensordict and sets them in a new tensordict.
The callable signature must be Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].
- Parameters:
fn (Callable) – function to be applied to the tensors in the tensordict.
*others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching that of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.
- Keyword Arguments:
batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword-only argument.
device (torch.device, optional) – the resulting device, if any.
names (list of str, optional) – the new dimension names, in case the batch_size is modified.
inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword-only argument.
default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.
filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost, as empty data structures won’t be created and destroyed. Non-tensor data is considered a leaf and will therefore be kept in the tensordict even if left untouched by the function. Defaults to False for backward compatibility.
propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.
call_on_nested (bool, optional) –
if True, the function will be called on first-level tensors and containers (TensorDict or tensorclass). In this scenario, fn is responsible for propagating its calls to nested levels. This allows a fine-grained behaviour when propagating the calls to nested tensordicts. If False, the function will only be called on leaves, and apply will take care of dispatching the function to all leaves.
>>> from tensordict import TensorDict, is_tensor_collection
>>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]})
>>> def mean_tensor_only(val):
...     if is_tensor_collection(val):
...         raise RuntimeError("Unexpected!")
...     return val.mean()
>>> td_mean = td.apply(mean_tensor_only)
>>> def mean_any(val):
...     if is_tensor_collection(val):
...         # Recurse
...         return val.apply(mean_any, call_on_nested=True)
...     return val.mean()
>>> td_mean = td.apply(mean_any, call_on_nested=True)
out (TensorDictBase, optional) –
a tensordict in which to write the results. This can be used to avoid creating a new tensordict:
>>> td = TensorDict({"a": 0})
>>> td.apply(lambda x: x+1, out=td)
>>> assert (td==1).all()
Warning
If the operation executed on the tensordict requires multiple keys to be accessed for a single computation, providing an out argument equal to self can cause the operation to silently produce wrong results. For instance:
>>> td = TensorDict({"a": 1, "b": 1})
>>> td.apply(lambda x: x+td["a"])["b"]  # Right!
tensor(2)
>>> td.apply(lambda x: x+td["a"], out=td)["b"]  # Wrong!
tensor(3)
**constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.
- Returns:
a new tensordict with transformed tensors.
Example
>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "b": {"c": torch.ones(3)}},
...     batch_size=[3])
>>> td_1 = td.apply(lambda x: x+1)
>>> assert (td_1["a"] == 0).all()
>>> assert (td_1["b", "c"] == 2).all()
>>> td_2 = td.apply(lambda x, y: x+y, td)
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()
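The warning about passing out=self comes from aliasing: the callable reads td["a"] while apply is overwriting the very same container. A plain-dict model reproduces both results from the warning. The apply_dict helper is hypothetical, and the sketch assumes that entries are processed in insertion order, as in the warning's example.

```python
def apply_dict(fn, td, out=None):
    """Hypothetical model: apply fn to every value of a flat dict."""
    result = {} if out is None else out
    for key, value in td.items():
        # When out is td, an entry written here is visible to later calls of fn.
        result[key] = fn(value)
    return result


td = {"a": 1, "b": 1}
right = apply_dict(lambda x: x + td["a"], td)          # fresh output dict
wrong = apply_dict(lambda x: x + td["a"], td, out=td)  # td["a"] becomes 2 before "b" is computed
print(right["b"], wrong["b"])  # 2 3
```

Writing to a fresh container (the default) keeps every read of td["a"] at its original value, which is why it is the safe choice whenever fn reads other entries.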
Note
If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:
>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def filter(tensor):
...     if tensor == 1:
...         return tensor
>>> td.apply(filter)
TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
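The filtering behaviour of the Note above, together with the filter_empty keyword, can be sketched on nested Python dicts. This is a hypothetical model (apply_nested and keep_ones are illustrative names, not tensordict functions) that assumes a nested dict stands in for a sub-tensordict and that filter_empty simply drops sub-dicts left empty after filtering.

```python
def apply_nested(fn, td, filter_empty=False):
    """Hypothetical model: apply fn to leaves; a None result drops the entry."""
    out = {}
    for key, value in td.items():
        if isinstance(value, dict):  # a nested dict stands for a sub-tensordict
            sub = apply_nested(fn, value, filter_empty)
            if sub or not filter_empty:
                out[key] = sub
        else:
            result = fn(value)
            if result is not None:
                out[key] = result
    return out


def keep_ones(x):
    # Returns None (and thereby filters the entry) unless x == 1.
    return x if x == 1 else None


data = {"1": 1, "2": 2, "b": {"2": 2, "1": 1}}
print(apply_nested(keep_ones, data))  # {'1': 1, 'b': {'1': 1}}
print(apply_nested(lambda x: None, data))  # {'b': {}}
print(apply_nested(lambda x: None, data, filter_empty=True))  # {}
```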
Note
The apply method will return a TensorDict instance, regardless of the input type. To keep the same type, one can execute
>>> out = td.clone(False).update(td.apply(...))
- apply_(fn: Callable, *others, **kwargs) → T
Applies a callable to all values stored in the tensordict and re-writes them in-place.
- Parameters:
fn (Callable) – function to be applied to the tensors in the tensordict.
*others (sequence of TensorDictBase, optional) – the other tensordicts to be used.
Keyword Args: See apply().
- Returns:
self or a copy of self with the function applied.
- asin() → T
Computes the asin() value of each element of the TensorDict.
- asin_() → T
Computes the asin() value of each element of the TensorDict in-place.
- atan() → T
Computes the atan() value of each element of the TensorDict.
- atan_() → T
Computes the atan() value of each element of the TensorDict in-place.
- auto_batch_size_(batch_dims: Optional[int] = None) → T
Sets the maximum batch size for the tensordict, up to an optional batch_dims.
- Parameters:
batch_dims (int, optional) – if provided, the batch size will be at most batch_dims long.
- Returns:
self
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[])
>>> td.auto_batch_size_()
>>> print(td.batch_size)
torch.Size([3, 4])
>>> td.auto_batch_size_(batch_dims=1)
>>> print(td.batch_size)
torch.Size([3])
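One way to picture "the maximum batch size" is as the longest leading prefix shared by every leaf shape, optionally truncated to batch_dims. The sketch below models that idea on shape tuples only; max_batch_size is a hypothetical helper, and it is an assumption that auto_batch_size_() is equivalent to this rule in every case (for instance with non-tensor data).

```python
def max_batch_size(leaf_shapes, batch_dims=None):
    """Hypothetical model: longest common leading prefix of all leaf shapes."""
    prefix = []
    # zip stops at the shortest shape, so the prefix never exceeds any leaf's ndim.
    for dims in zip(*leaf_shapes):
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])
    if batch_dims is not None:
        prefix = prefix[:batch_dims]
    return tuple(prefix)


shapes = [(3, 4, 5), (3, 4, 6)]  # leaves "a" and ("b", "c") from the example above
print(max_batch_size(shapes))  # (3, 4)
print(max_batch_size(shapes, batch_dims=1))  # (3,)
```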
- auto_device_() → T
Automatically sets the device, if it is unique.
Returns: self with the edited device attribute.
- property batch_dims: int
Length of the tensordict batch size.
- Returns:
int describing the number of dimensions of the tensordict.
- property batch_size: Size
Shape (or batch_size) of a TensorDict.
The shape of a tensordict corresponds to the common first N dimensions of the tensors it contains, where N is an arbitrary number. The batch size contrasts with the “feature size”, which represents the semantically relevant shapes of a tensor. For instance, a batch of videos may have shape [B, T, C, W, H], where [B, T] is the batch size (batch and time dimensions) and [C, W, H] are the feature dimensions (channels and spatial dimensions).
The TensorDict shape is controlled by the user upon initialization (i.e., it is not inferred from the tensor shapes).
The batch_size can be edited dynamically if the new size is compatible with the TensorDict content. For instance, setting the batch size to an empty value is always allowed.
- Returns:
a Size object describing the TensorDict batch size.
Examples
>>> data = TensorDict({
...     "key 0": torch.randn(3, 4),
...     "key 1": torch.randn(3, 5),
...     "nested": TensorDict({"key 0": torch.randn(3, 4)}, batch_size=[3, 4])},
...     batch_size=[3])
>>> data.batch_size = ()  # resets the batch-size to an empty value
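"Compatible with the TensorDict content" can be read as: the new batch size must be a prefix of every leaf shape, and of the batch size of every nested tensordict. The hypothetical is_valid_batch_size helper below checks that rule on the shapes from the example above; it is a sketch of the idea, not the library's validation code.

```python
def is_valid_batch_size(batch_size, shapes):
    """Hypothetical check: the batch size must be a prefix of every shape."""
    n = len(batch_size)
    return all(tuple(shape[:n]) == tuple(batch_size) for shape in shapes)


# "key 0", "key 1", the nested tensordict's batch size, and its leaf "key 0".
shapes = [(3, 4), (3, 5), (3, 4), (3, 4)]
print(is_valid_batch_size((3,), shapes))  # True
print(is_valid_batch_size((3, 4), shapes))  # False: "key 1" has shape (3, 5)
print(is_valid_batch_size((), shapes))  # True: an empty batch size always fits
```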
- bfloat16()
Casts all tensors to torch.bfloat16.
- bool()
Casts all tensors to torch.bool.
- bytes(*, count_duplicates: bool = True) → int
Counts the number of bytes of the contained tensors.
- Keyword Arguments:
count_duplicates (bool) – whether to count duplicated tensors as independent or not. If False, only strictly identical tensors will be discarded (views of a common base tensor that have different ids will still be counted twice). Defaults to True (each tensor is assumed to be a single copy).
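The count_duplicates rule can be illustrated with stand-in objects. In the sketch below, the Leaf dataclass and count_bytes are hypothetical: a Leaf records only an element count and an item size, and duplicates are detected by object identity (id), which models the "strictly identical tensors" case but not the shared-storage view case.

```python
from dataclasses import dataclass


@dataclass
class Leaf:
    """Stand-in for a tensor: number of elements and bytes per element."""
    numel: int
    itemsize: int


def count_bytes(leaves, count_duplicates=True):
    """Hypothetical model of bytes(count_duplicates=...)."""
    if not count_duplicates:
        # Keep one entry per object identity, i.e. drop strictly identical leaves.
        leaves = list({id(leaf): leaf for leaf in leaves}.values())
    return sum(leaf.numel * leaf.itemsize for leaf in leaves)


x = Leaf(numel=12, itemsize=4)  # e.g. a float32 tensor of shape (3, 4)
y = Leaf(numel=3, itemsize=8)   # e.g. a float64 tensor of shape (3,)
leaves = [x, x, y]              # x is stored under two keys
print(count_bytes(leaves))                          # 120
print(count_bytes(leaves, count_duplicates=False))  # 72
```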
- classmethod cat(input, dim=0, *, out=None)
Concatenates tensordicts into a single tensordict along the given dimension.
This call is equivalent to calling torch.cat() but is compatible with torch.compile.
- cat_from_tensordict(dim: int = 0, *, sorted: Optional[Union[bool, List[NestedKey]]] = None, out: Optional[Tensor] = None) → Tensor
Concatenates all entries of a tensordict in a single tensor.
- Parameters:
dim (int, optional) – the dimension along which the entries should be concatenated.
- Keyword Arguments:
sorted (bool or list of NestedKeys) – if True, the entries will be concatenated in alphabetical order. If False (default), the dict order will be used. Alternatively, a list of key names can be provided and the tensors will be concatenated accordingly. This incurs some overhead, as the list of keys will be checked against the list of leaf names in the tensordict.
out (torch.Tensor, optional) – an optional destination tensor for the cat operation.
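The three ordering modes of the sorted keyword can be sketched with one-dimensional lists standing in for tensors. The cat_leaves helper is hypothetical; its sorted_keys parameter plays the role of sorted (renamed so that it does not shadow Python's builtin), and the exact-match check on an explicit key list is an assumption about how the keys are "checked against the list of leaf names".

```python
def cat_leaves(leaves, sorted_keys=False):
    """Hypothetical model of the `sorted` option of cat_from_tensordict()."""
    if sorted_keys is True:
        keys = sorted(leaves)  # alphabetical order
    elif sorted_keys is False or sorted_keys is None:
        keys = list(leaves)  # dict (insertion) order
    else:
        keys = list(sorted_keys)  # explicit order, checked against the leaf names
        if sorted(keys) != sorted(leaves):
            raise KeyError("the key list does not match the leaf names")
    out = []
    for key in keys:
        out.extend(leaves[key])  # 1-D concatenation along dim 0
    return out


leaves = {"b": [1, 2], "a": [3]}
print(cat_leaves(leaves))                          # [1, 2, 3]
print(cat_leaves(leaves, sorted_keys=True))        # [3, 1, 2]
print(cat_leaves(leaves, sorted_keys=["a", "b"]))  # [3, 1, 2]
```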
- cat_tensors(*keys: NestedKey, out_key: NestedKey, dim: int = 0, keep_entries: bool = False) → T
Concatenates entries into a new entry and possibly removes the original values.
- Parameters:
keys (sequence of NestedKey) – entries to concatenate.
- Keyword Arguments:
out_key (NestedKey) – new key name for the concatenated inputs.
keep_entries (bool, optional) – if False, entries in keys will be deleted. Defaults to False.
dim (int, optional) – the dimension along which the concatenation must occur. Defaults to 0.
Returns: self
Examples
>>> td = TensorDict(a=torch.zeros(1), b=torch.ones(1))
>>> td.cat_tensors("a", "b", out_key="c")
>>> assert "a" not in td
>>> assert (td["c"] == torch.tensor([0, 1])).all()
- ceil() → T
Computes the ceil() value of each element of the TensorDict.
- ceil_() → T
Computes the ceil() value of each element of the TensorDict in-place.
- chunk(chunks: int, dim: int = 0) → tuple[tensordict.base.TensorDictBase, ...]
Splits a tensordict into the specified number of chunks, if possible.
Each chunk is a view of the input tensordict.
- Parameters:
chunks (int) – the number of chunks to return.
dim (int, optional) – the dimension along which to split the tensordict. Defaults to 0.
Examples
>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td0, td1 = td.chunk(dim=-1, chunks=2)
>>> td0['x']
tensor([[[ 0,  1],
         [ 2,  3]],
        [[ 8,  9],
         [10, 11]],
        [[16, 17],
         [18, 19]]])
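The "if possible" in the description refers to how the chunk lengths are chosen. Assuming TensorDict.chunk follows the same rule as torch.chunk (every chunk has length ceil(size / chunks), the last one may be shorter, and fewer chunks than requested may be returned), the resulting sizes can be computed as in the hypothetical helper below.

```python
def chunk_sizes(size, chunks):
    """Hypothetical model: lengths of the pieces, following torch.chunk's rule."""
    if size <= 0 or chunks <= 0:
        raise ValueError("size and chunks must be positive")
    step = -(-size // chunks)  # ceil(size / chunks)
    full, rest = divmod(size, step)
    return [step] * full + ([rest] if rest else [])


print(chunk_sizes(4, 2))  # [2, 2]  (the last batch dim of the example above)
print(chunk_sizes(5, 2))  # [3, 2]  (the last chunk is smaller)
print(chunk_sizes(6, 4))  # [2, 2, 2]  (fewer chunks than requested)
```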
- clamp_max(other: tensordict.base.TensorDictBase | torch.Tensor, *, default: Optional[Union[str, Tensor]] = None) → T
Clamps the elements of self to other if they are greater than that value.
- Parameters:
other (TensorDict or Tensor) – the other input tensordict or tensor.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- clamp_max_(other: tensordict.base.TensorDictBase | torch.Tensor) → T
In-place version of clamp_max().
Note
In-place clamp_max does not support the default keyword argument.
- clamp_min(other: tensordict.base.TensorDictBase | torch.Tensor, default: Optional[Union[str, Tensor]] = None) → T
Clamps the elements of self to other if they are less than that value.
- Parameters:
other (TensorDict or Tensor) – the other input tensordict or tensor.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- clamp_min_(other: tensordict.base.TensorDictBase | torch.Tensor) → T
In-place version of clamp_min().
Note
In-place clamp_min does not support the default keyword argument.
- clear() → T
Erases the content of the tensordict.
- clear_device_() → T
Clears the device of the tensordict.
Returns: self
- clone(recurse: bool = True, **kwargs) → T
Clones a TensorDictBase subclass instance onto a new TensorDictBase subclass of the same type.
To create a TensorDict instance from any other TensorDictBase subtype, call the to_tensordict() method instead.
- Parameters:
recurse (bool, optional) – if True, each tensor contained in the TensorDict will be copied too. Otherwise, only the TensorDict tree structure will be copied. Defaults to True.
Note
Unlike many other ops (pointwise arithmetic, shape operations, …), clone does not inherit the original lock attribute. This design choice is made so that a clone can be created to be modified, which is the most frequent usage.
- complex128()
Casts all tensors to torch.complex128.
- complex32()
Casts all tensors to torch.complex32.
- complex64()
Casts all tensors to torch.complex64.
- consolidate(filename: Optional[Union[Path, str]] = None, *, num_threads=0, device: Optional[device] = None, non_blocking: bool = False, inplace: bool = False, return_early: bool = False, use_buffer: bool = False, share_memory: bool = False, pin_memory: bool = False, metadata: bool = False) → None
Consolidates the tensordict content in a single storage for fast serialization.
- Parameters:
filename (Path, optional) – an optional file path for a memory-mapped tensor to use as a storage for the tensordict.
- Keyword Arguments:
num_threads (integer, optional) – the number of threads to use for populating the storage.
device (torch.device, optional) – an optional device where the storage must be instantiated.
non_blocking (bool, optional) – the non_blocking argument passed to copy_().
inplace (bool, optional) – if True, the resulting tensordict is the same as self with updated values. Defaults to False.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
use_buffer (bool, optional) – if True and a filename is passed, an intermediate local buffer will be created in shared memory, and the data will be copied to the storage location as a last step. This may be faster than writing directly to a distant physical memory (e.g., NFS). Defaults to False.
share_memory (bool, optional) – if True, the storage will be placed in shared memory. Defaults to False.
pin_memory (bool, optional) – whether the consolidated data should be placed in pinned memory. Defaults to False.
metadata (bool, optional) – if True, the metadata will be stored alongside the common storage. If a filename is provided, this has no effect. Storing the metadata can be useful when one wants to control how serialization is achieved, as TensorDict handles the pickling/unpickling of consolidated TDs differently depending on whether the metadata is available.
Note
If the tensordict is already consolidated, all arguments are ignored and self is returned. Call contiguous() to re-consolidate.
Examples
>>> import pickle
>>> import torch
>>> from torch.utils.benchmark import Timer
>>> from tensordict import TensorDict
>>> data = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> data_consolidated = data.consolidate()
>>> # check that the data has a single data_ptr()
>>> assert torch.tensor([
...     v.untyped_storage().data_ptr() for v in data_consolidated.values(True, True)
... ]).unique().numel() == 1
>>> # Serializing the tensordict will be faster with data_consolidated
>>> with open("data.pickle", "wb") as f:
...     print("regular", Timer("pickle.dump(data, f)", globals=globals()).adaptive_autorange())
>>> with open("data_c.pickle", "wb") as f:
...     print("consolidated", Timer("pickle.dump(data_consolidated, f)", globals=globals()).adaptive_autorange())
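Conceptually, consolidation packs every leaf into one contiguous buffer and keeps per-leaf (offset, length) metadata, so that each leaf becomes a view into shared storage. The stdlib sketch below illustrates that idea with array and memoryview; consolidate_lists is a hypothetical name, the flattened key "b.c" stands in for the nested entry ("b", "c"), and the sketch makes no claim about tensordict's actual memory layout or serialization format.

```python
from array import array


def consolidate_lists(leaves):
    """Hypothetical sketch: pack float lists into one buffer plus offset metadata."""
    buffer = array("d")
    metadata = {}
    for name, values in leaves.items():
        metadata[name] = (len(buffer), len(values))  # (offset, length)
        buffer.extend(values)
    # Views are created only after packing: an array cannot be resized once a
    # memoryview on it exists.
    flat = memoryview(buffer)
    views = {name: flat[start:start + n] for name, (start, n) in metadata.items()}
    return buffer, views, metadata


buffer, views, metadata = consolidate_lists({"a": [0.0], "b.c": [0.0, 0.0]})
print(metadata)  # {'a': (0, 1), 'b.c': (1, 2)}
buffer[1] = 5.0  # writing to the shared storage...
print(views["b.c"].tolist())  # [5.0, 0.0]  ...is visible through the view
```

Because every view points into the same buffer, serializing the data amounts to writing one block of bytes plus the small metadata dictionary, which is the intuition behind the faster pickling shown in the example above.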
- contiguous() → T
Returns a new tensordict of the same type with contiguous values (or self if values are already contiguous).
- copy()
Returns a shallow copy of the tensordict (i.e., copies the structure but not the data).
Equivalent to TensorDictBase.clone(recurse=False).
- copy_(tensordict: T, non_blocking: bool = False) → T
The non_blocking argument will be ignored and is only present for compatibility with torch.Tensor.copy_().
- copy_at_(tensordict: T, idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], non_blocking: bool = False) → T
- cos() → T
Computes the cos() value of each element of the TensorDict.
- cos_() → T
Computes the cos() value of each element of the TensorDict in-place.
- cosh() → T
Computes the cosh() value of each element of the TensorDict.
- cosh_() → T
Computes the cosh() value of each element of the TensorDict in-place.
- cpu(**kwargs) → T
Casts a tensordict to CPU.
This function also supports all the keyword arguments of to().
- create_nested(key)
Creates a nested tensordict of the same shape, device and dim names as the current tensordict.
If the value already exists, it will be overwritten by this operation. This operation is blocked in locked tensordicts.
Examples
>>> data = TensorDict({}, [3, 4, 5])
>>> data.create_nested("root")
>>> data.create_nested(("some", "nested", "value"))
>>> print(data)
TensorDict(
    fields={
        root: TensorDict(
            fields={
            },
            batch_size=torch.Size([3, 4, 5]),
            device=None,
            is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        value: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([3, 4, 5]),
                            device=None,
                            is_shared=False)},
                    batch_size=torch.Size([3, 4, 5]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([3, 4, 5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 4, 5]),
    device=None,
    is_shared=False)
- cuda(device: Optional[int] = None, **kwargs) → T
Casts a tensordict to a cuda device (if not already on it).
- Parameters:
device (int, optional) – if provided, the cuda device on which the tensor should be cast.
This function also supports all the keyword arguments of to().
- property data
Returns a tensordict containing the .data attributes of the leaf tensors.
- data_ptr(*, storage: bool = False)
Returns the data_ptr of the tensordict leaves.
This can be useful to check if two tensordicts share the same data_ptr().
- Keyword Arguments:
storage (bool, optional) – if True, tensor.untyped_storage().data_ptr() will be called instead. Defaults to False.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(a=torch.randn(2), b=torch.randn(2), batch_size=[2])
>>> td0 = td.clone(recurse=False)  # shallow copy: the leaf tensors are shared
>>> assert (td0.data_ptr() == td.data_ptr()).all()
Note
LazyStackedTensorDict instances will be displayed as nested tensordicts to reflect the true data_ptr() of their leaves:
>>> td0 = TensorDict(a=torch.randn(2), b=torch.randn(2), batch_size=[2])
>>> td1 = TensorDict(a=torch.randn(2), b=torch.randn(2), batch_size=[2])
>>> td = TensorDict.lazy_stack([td0, td1])
>>> td.data_ptr()
TensorDict(
    fields={
        0: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        1: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- del_(key: NestedKey) → T
Deletes a key of the tensordict.
- Parameters:
key (NestedKey) – key to be deleted.
- Returns:
self
- densify(layout: layout = torch.strided)
Attempts to represent the lazy stack with contiguous tensors (plain tensors or nested).
- Keyword Arguments:
layout (torch.layout) – the layout of the nested tensors, if any. Defaults to strided.
- property depth: int
Returns the depth (maximum number of nesting levels) of a tensordict.
The minimum depth is 0 (no nested tensordict).
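The depth can be computed recursively: a tensordict without sub-tensordicts has depth 0, and otherwise its depth is one more than the depth of its deepest child. The hypothetical depth helper below applies that rule to nested Python dicts, which stand in for nested tensordicts.

```python
def depth(td):
    """Hypothetical model: number of nested levels below the root dict."""
    nested = [value for value in td.values() if isinstance(value, dict)]
    return 1 + max(depth(value) for value in nested) if nested else 0


print(depth({"a": 1}))  # 0: no nested tensordict
print(depth({"a": {"b": 1}}))  # 1
print(depth({"a": {"b": {"c": 1}}, "d": {}}))  # 2: the deepest branch wins
```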
- detach() → T
Detaches the tensors in the tensordict.
- Returns:
a new tensordict with no tensor requiring gradient.
- detach_() → T
Detaches the tensors in the tensordict in-place.
- Returns:
self.
- property device: torch.device | None
Device of the tensordict.
Returns None if the device hasn’t been provided in the constructor or set via tensordict.to(device).
- dim() → int
See batch_dims().
- div(other: tensordict.base.TensorDictBase | torch.Tensor, *, default: Optional[Union[str, Tensor]] = None) → T
Divides each element of the input self by the corresponding element of other.
\[\text{out}_i = \frac{\text{input}_i}{\text{other}_i}\]
Supports broadcasting, type promotion and integer, float, tensordict or tensor inputs. Always promotes integer types to the default scalar type.
- Parameters:
other (TensorDict, Tensor or Number) – the divisor.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- div_(other: tensordict.base.TensorDictBase | torch.Tensor) → T
In-place version of div().
Note
In-place div does not support the default keyword argument.
- double()
Casts all tensors to torch.float64.
- property dtype
Returns the dtype of the values in the tensordict, if it is unique.
- dumps(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) → T
Saves the tensordict to disk.
This function is a proxy to memmap().
- empty(recurse=False, *, batch_size=None, device=_NoDefault.ZERO, names=_NoDefault.ZERO) → T
Returns a new, empty tensordict with the same device and batch size.
- Parameters:
recurse (bool, optional) – if True, the entire structure of the TensorDict will be reproduced without content. Otherwise, only the root will be duplicated. Defaults to False.
- Keyword Arguments:
batch_size (torch.Size, optional) – a new batch size for the tensordict.
device (torch.device, optional) – a new device.
names (list of str, optional) – dimension names.
- entry_class(key: NestedKey) → type
Returns the class of an entry, possibly avoiding a call to isinstance(td.get(key), type).
This method should be preferred to type(tensordict.get(key)) whenever get() can be expensive to execute.
- erf() → T
Computes the erf() value of each element of the TensorDict.
- erf_() → T
Computes the erf() value of each element of the TensorDict in-place.
- erfc() → T
Computes the erfc() value of each element of the TensorDict.
- erfc_() → T
Computes the erfc() value of each element of the TensorDict in-place.
- exclude(*keys: NestedKey, inplace: bool = False) → T
Excludes the keys of the tensordict and returns a new tensordict without these entries.
The values are not copied: in-place modifications of a tensor in either the original or the new tensordict will result in a change in both tensordicts.
- Parameters:
keys (NestedKey) – the keys of the entries to exclude.
inplace (bool, optional) – if True, the entries are removed from self directly. Defaults to False.
- Returns:
A new tensordict (or the same one if inplace=True) without the excluded entries.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.exclude("a", ("b", "c"))
TensorDict(
    fields={
        b: TensorDict(
            fields={
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.exclude("a", "b")
TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
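The handling of nested keys in exclude() can be sketched on nested Python dicts, where a tuple key such as ("b", "c") addresses an entry inside a sub-dict. The exclude function below is a hypothetical model, not the library code: it drops the addressed entries, rebuilds only the sub-dicts it has to touch, and shares every other value with the input, mirroring the "values are not copied" behaviour described above.

```python
def exclude(td, *keys):
    """Hypothetical model of exclude() on nested dicts; keys may be tuples."""
    paths = [key if isinstance(key, tuple) else (key,) for key in keys]
    out = {}
    for name, value in td.items():
        if (name,) in paths:
            continue  # the whole entry (leaf or sub-dict) is dropped
        sub_paths = [path[1:] for path in paths if len(path) > 1 and path[0] == name]
        if sub_paths and isinstance(value, dict):
            value = exclude(value, *sub_paths)
        out[name] = value  # leaves are shared with td, not copied
    return out


td = {"a": 0, "b": {"c": 1, "d": [2]}}
print(exclude(td, "a", ("b", "c")))  # {'b': {'d': [2]}}
print(exclude(td, "a", "b"))  # {}
print(exclude(td, "a")["b"] is td["b"])  # True: untouched values are not copied
```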
- exp() → T
Computes the exp() value of each element of the TensorDict.
- exp_() → T
Computes the exp() value of each element of the TensorDict in-place.
- expand(*args, **kwargs) → T
Expands each tensor of the tensordict according to the expand() function, ignoring the feature dimensions.
Supports iterables to specify the shape.
Examples
>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td_expand = td.expand(10, 3, 4)
>>> assert td_expand.shape == torch.Size([10, 3, 4])
>>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5])
- expand_as(other: tensordict.base.TensorDictBase | torch.Tensor) → TensorDictBase
Broadcasts the shape of the tensordict to the shape of other and expands it accordingly.
If the input is a tensor collection (tensordict or tensorclass), the leaves will be expanded on a one-to-one basis.
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> td0 = TensorDict({
...     "a": torch.ones(3, 1, 4),
...     "b": {"c": torch.ones(3, 2, 1, 4)}},
...     batch_size=[3],
... )
>>> td1 = TensorDict({
...     "a": torch.zeros(2, 3, 5, 4),
...     "b": {"c": torch.zeros(2, 3, 2, 6, 4)}},
...     batch_size=[2, 3],
... )
>>> expanded = td0.expand_as(td1)
>>> assert (expanded==1).all()
>>> print(expanded)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 3, 5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([2, 3, 2, 6, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)
- expm1() T ¶
Computes the
expm1()
value of each element of the TensorDict.
- expm1_() T ¶
Computes the
expm1()
value of each element of the TensorDict in-place.
- fill_(key: NestedKey, value: float | bool) T ¶
Fills the tensor pointed to by the key with the given scalar value.
- filter_empty_()¶
Filters out all empty tensordicts in-place.
- filter_non_tensor_data() T ¶
Filters out all non-tensor-data.
- flatten(start_dim=0, end_dim=-1)¶
Flattens all the tensors of a tensordict.
Examples
>>> td = TensorDict({ ... "a": torch.arange(60).view(3, 4, 5), ... "b": torch.arange(12).view(3, 4)}, batch_size=[3, 4]) >>> td_flat = td.flatten(0, 1) >>> td_flat.batch_size torch.Size([12]) >>> td_flat["a"] tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24], [25, 26, 27, 28, 29], [30, 31, 32, 33, 34], [35, 36, 37, 38, 39], [40, 41, 42, 43, 44], [45, 46, 47, 48, 49], [50, 51, 52, 53, 54], [55, 56, 57, 58, 59]]) >>> td_flat["b"] tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
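As the example shows, flattening acts on the batch dimensions: dims `start_dim` through `end_dim` (inclusive) collapse into one, and any feature dimensions of a leaf are carried along. The sketch below only computes shapes; `flatten_shape` is a hypothetical helper, and it assumes a non-empty shape.

```python
# Shape sketch for flatten(start_dim, end_dim) applied to the batch dims:
# dims start_dim..end_dim (inclusive, negative indices allowed) collapse into one.
# `flatten_shape` is a hypothetical helper for illustration.
from math import prod


def flatten_shape(shape, start_dim=0, end_dim=-1):
    ndim = len(shape)
    start = start_dim % ndim
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


batch_size = (3, 4)
print(flatten_shape(batch_size, 0, 1))  # (12,)
# leaf "a" has shape (3, 4, 5): its feature dim 5 is untouched
print(flatten_shape((3, 4, 5), 0, 1))  # (12, 5)
```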
- flatten_keys(separator: str = '.', inplace: bool = False, is_leaf: Optional[Callable[[Type], bool]] = None) T ¶
Converts a nested tensordict into a flat one, recursively.
The TensorDict type will be lost and the result will be a simple TensorDict instance.
- Parameters:
separator (str, optional) – the separator between the nested items. Defaults to '.'.
inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.
is_leaf (callable, optional) – a callable over a class type returning a bool indicating if this class has to be considered as a leaf.
Examples
>>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[]) >>> data.flatten_keys(separator=" - ") TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), e - f - g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
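To make the key transformation concrete, here is a pure-Python sketch of what flatten_keys() does to the key structure of a nested mapping: each leaf gets a single key made of its path joined by `separator`. `flatten_nested` is a hypothetical helper operating on plain dicts; it does not reproduce tensordict's handling of batch sizes or the `is_leaf` argument.

```python
# Pure-Python sketch of what flatten_keys() does to the key structure:
# nested entries get a single key made of the path joined by `separator`.
# `flatten_nested` is a hypothetical helper, not tensordict's implementation.

def flatten_nested(tree, separator=".", prefix=""):
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_nested(value, separator, name))  # recurse
        else:
            flat[name] = value
    return flat


data = {"a": 1, "b": {"c": 2}, "e": {"f": {"g": 3}}}
print(flatten_nested(data, separator=" - "))
# {'a': 1, 'b - c': 2, 'e - f - g': 3}
```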
This method and unflatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.
Examples
>>> model = torch.nn.Sequential(torch.nn.Linear(3 ,4)) >>> ddp_model = torch.ao.quantization.QuantWrapper(model) >>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".") >>> print(state_dict) TensorDict( fields={ module: TensorDict( fields={ 0: TensorDict( fields={ bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> model_state_dict = state_dict.get("module") >>> print(model_state_dict) TensorDict( fields={ 0: TensorDict( fields={ bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
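The reverse direction used in the state-dict example, unflatten_keys("."), can be sketched on a plain flat mapping: every dotted name becomes a path of nested dicts. `unflatten_nested` is a hypothetical helper, and the string values stand in for the tensors of a real state-dict.

```python
# Sketch of the reverse operation, unflatten_keys(".") on a flat state-dict-like
# mapping: each dotted name becomes a path of nested dicts.
# `unflatten_nested` is a hypothetical helper; strings stand in for tensors.

def unflatten_nested(flat, separator="."):
    tree = {}
    for name, value in flat.items():
        *parents, leaf = name.split(separator)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels
        node[leaf] = value
    return tree


state_dict = {"module.0.weight": "W", "module.0.bias": "b"}
nested = unflatten_nested(state_dict)
print(nested)  # {'module': {'0': {'weight': 'W', 'bias': 'b'}}}
# The "module" sub-tree mirrors the wrapped model's own structure
print(nested["module"])  # {'0': {'weight': 'W', 'bias': 'b'}}
```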
- float()¶
Casts all tensors to
torch.float
.
- float16()¶
Casts all tensors to
torch.float16
.
- float32()¶
Casts all tensors to
torch.float32
.
- float64()¶
Casts all tensors to
torch.float64
.
- floor() T ¶
Computes the
floor()
value of each element of the TensorDict.
- floor_() T ¶
Computes the
floor()
value of each element of the TensorDict in-place.
- frac() T ¶
Computes the
frac()
value of each element of the TensorDict.
- frac_() T ¶
Computes the
frac()
value of each element of the TensorDict in-place.
- classmethod from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None)¶
Returns a TensorDict created from a dictionary or another TensorDict.
If batch_size is not specified, returns the maximum batch size possible.
This function works on nested dictionaries too, or can be used to determine the batch-size of a nested tensordict.
- Parameters:
input_dict (dictionary, optional) – a dictionary to use as a data source (nested keys compatible).
batch_size (iterable of int, optional) – a batch size for the tensordict.
device (torch.device or compatible type, optional) – a device for the TensorDict.
batch_dims (int, optional) – the batch_dims (i.e., number of leading dimensions to be considered for batch_size). Exclusive with batch_size. Note that this is the maximum number of batch dims of the tensordict; a smaller number is tolerated.
names (list of str, optional) – the dimension names of the tensordict.
Examples
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} >>> print(TensorDict.from_dict(input_dict)) TensorDict( fields={ a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> # nested dict: the nested TensorDict can have a different batch-size >>> # as long as its leading dims match. >>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}} >>> print(TensorDict.from_dict(input_dict)) TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3, 4]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> # we can also use this to work out the batch size of a tensordict >>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, []) >>> print(TensorDict.from_dict(input_td)) TensorDict( fields={ a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3, 4]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
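One way to read "the maximum batch size possible" is as the longest leading prefix shared by every entry's shape, with batch_dims acting as an upper bound. The sketch below assumes that reading; `infer_batch_size` is a hypothetical helper, and tensordict's own logic may differ in details (for example, in how a nested tensordict contributes its batch size). It reproduces the shapes from the examples above.

```python
# Sketch of how a "maximum possible" batch size can be inferred when
# batch_size is omitted: the longest prefix shared by every entry's shape,
# optionally capped by batch_dims. `infer_batch_size` is hypothetical.

def infer_batch_size(shapes, batch_dims=None):
    shapes = [tuple(s) for s in shapes]
    common = []
    for dims in zip(*shapes):  # zip stops at the shortest shape
        if all(d == dims[0] for d in dims):
            common.append(dims[0])
        else:
            break
    if batch_dims is not None:
        common = common[:batch_dims]  # batch_dims is an upper bound
    return tuple(common)


print(infer_batch_size([(3, 4), (3,)]))          # (3,)
print(infer_batch_size([(3, 4)]))                # (3, 4): nested "b" alone
print(infer_batch_size([(3, 4, 5), (3, 4)], 1))  # (3,)
```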
- from_dict_instance(input_dict, batch_size=None, device=None, batch_dims=None, names=None)¶
Instance method version of from_dict().
Unlike from_dict(), this method will attempt to keep the tensordict types within the existing tree (for any existing leaf).
Examples
>>> from tensordict import TensorDict, tensorclass >>> import torch >>> >>> @tensorclass ... class MyClass: ... x: torch.Tensor ... y: int >>> >>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)}) >>> print(td.from_dict_instance(td.to_dict())) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: MyClass( x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> print(td.from_dict(td.to_dict())) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- classmethod from_h5(filename, mode='r')¶
Creates a PersistentTensorDict from an H5 file.
This function will automatically determine the batch-size for each nested tensordict.
- classmethod from_module(module: Module, as_module: bool = False, lock: bool = False, use_state_dict: bool = False, filter_empty: bool = True)¶
Copies the params and buffers of a module into a tensordict.
- Parameters:
module (nn.Module) – the module to get the parameters from.
as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.
lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to False.
use_state_dict (bool, optional) – if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.
Note
This is particularly useful when state-dict hooks have to be used.
Examples
>>> from torch import nn >>> module = nn.TransformerDecoder( ... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4), ... num_layers=1 ... ) >>> params = TensorDict.from_module(module) >>> print(params["layers", "0", "linear1"]) TensorDict( fields={ bias: Parameter(shape=torch.Size([2048]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2048, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- classmethod from_modules(*modules, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)¶
Retrieves the parameters of several modules for ensemble learning / mixture-of-experts applications through vmap.
- Parameters:
modules (sequence of nn.Module) – the modules to get the parameters from. If the modules differ in their structure, a lazy stack is needed (see the lazy_stack argument below).
- Keyword Arguments:
as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.
lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.
use_state_dict (bool, optional) – if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.
Note
This is particularly useful when state-dict hooks have to be used.
lazy_stack (bool, optional) – whether parameters should be densely or lazily stacked. Defaults to False (dense stack).
Note
lazy_stack and as_module are exclusive features.
Warning
There is a crucial difference between lazy and non-lazy outputs in that non-lazy output will reinstantiate parameters with the desired batch-size, while lazy_stack will just represent the parameters as lazily stacked. This means that whilst the original parameters can safely be passed to an optimizer when lazy_stack=True, the new parameters need to be passed when it is set to False.
Warning
Whilst it can be tempting to use a lazy stack to keep the original parameter references, remember that a lazy stack performs a stack each time get() is called. This will require memory (N times the size of the parameters, more if a graph is built) and time to be computed. It also means that the optimizer(s) will contain more parameters, and operations like step() or zero_grad() will take longer to be executed. In general, lazy_stack should be reserved to very few use cases.
expand_identical (bool, optional) – if True and the same parameter (same identity) is being stacked to itself, an expanded version of this parameter will be returned instead. This argument is ignored when lazy_stack=True.
Examples
>>> from torch import nn >>> from tensordict import TensorDict >>> torch.manual_seed(0) >>> empty_module = nn.Linear(3, 4, device="meta") >>> n_models = 2 >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] >>> params = TensorDict.from_modules(*modules) >>> print(params) TensorDict( fields={ bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) >>> # example of batch execution >>> def exec_module(params, x): ... with params.to_module(empty_module): ... return empty_module(x) >>> x = torch.randn(3) >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> # since lazy_stack = False, backprop leaves the original params untouched >>> y.sum().backward() >>> assert params["weight"].grad.norm() > 0 >>> assert modules[0].weight.grad is None
With lazy_stack=True, things are slightly different:
>>> params = TensorDict.from_modules(*modules, lazy_stack=True) >>> print(params) LazyStackedTensorDict( fields={ bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([2]), device=None, is_shared=False, stack_dim=0) >>> # example of batch execution >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> y.sum().backward() >>> assert modules[0].weight.grad is not None
- classmethod from_namedtuple(named_tuple, *, auto_batch_size: bool = False)¶
Converts a namedtuple to a TensorDict recursively.
- Keyword Arguments:
auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.
Examples
>>> from tensordict import TensorDict >>> import torch >>> data = TensorDict({ ... "a_tensor": torch.zeros((3)), ... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3]) >>> nt = data.to_namedtuple() >>> print(nt) GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!')) >>> TensorDict.from_namedtuple(nt, auto_batch_size=True) TensorDict( fields={ a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None), a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- classmethod from_pytree(pytree, *, batch_size: Optional[Size] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None)¶
Converts a pytree to a TensorDict instance.
This method is designed to keep the pytree nested structure as much as possible.
Additional non-tensor keys are added to keep track of each level’s identity, providing a built-in pytree-to-tensordict bijective transform API.
Accepted classes currently include lists, tuples, named tuples and dict.
Note
For dictionaries, non-NestedKey keys are registered separately as NonTensorData instances.
Note
Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances. Note that this transformation is not invertible: transforming the tensordict back to a pytree will not recover the original types.
Examples
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key >>> class WeirdLookingClass: ... pass ... >>> weird_key = WeirdLookingClass() >>> # Make a pytree with tuple, lists, dict and namedtuple >>> pytree = ( ... [torch.randint(10, (3,)), torch.zeros(2)], ... { ... "tensor": torch.randn( ... 2, ... ), ... "td": TensorDict({"one": 1}), ... weird_key: torch.randint(10, (2,)), ... "list": [1, 2, 3], ... }, ... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, ... ) >>> # Build a TensorDict from that pytree >>> td = TensorDict.from_pytree(pytree) >>> # Recover the pytree >>> pytree_recon = td.to_pytree() >>> # Check that the leaves match >>> def check(v1, v2): ... assert (v1 == v2).all() ... >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon) >>> assert weird_key in pytree_recon[1]
- classmethod from_struct_array(struct_array: ndarray, device: Optional[device] = None) T ¶
Converts a structured numpy array to a TensorDict.
The content of the resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict.
Examples
>>> x = np.array( ... [("Rex", 9, 81.0), ("Fido", 3, 27.0)], ... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")], ... ) >>> td = TensorDict.from_struct_array(x) >>> x_recon = td.to_struct_array() >>> assert (x_recon == x).all() >>> assert x_recon.shape == x.shape >>> # Try modifying x age field and check effect on td >>> x["age"] += 1 >>> assert (td["age"] == np.array([10, 4])).all()
- classmethod fromkeys(keys: List[NestedKey], value: Any = 0)¶
Creates a tensordict from a list of keys and a single value.
- Parameters:
keys (list of NestedKey) – An iterable specifying the keys of the new dictionary.
value (compatible type, optional) – The value for all keys. Defaults to
0
.
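The behaviour of fromkeys() on nested keys can be pictured with plain dicts: every key, with nested ones given as tuples, receives the same value, and intermediate levels are created as needed. `nested_fromkeys` is a hypothetical helper for illustration; the real method builds tensors rather than storing the raw value.

```python
# Sketch of fromkeys() semantics on plain dicts: every key, nested ones given as
# tuples, receives the same value. `nested_fromkeys` is a hypothetical helper.

def nested_fromkeys(keys, value=0):
    tree = {}
    for key in keys:
        path = key if isinstance(key, tuple) else (key,)
        node = tree
        for part in path[:-1]:
            node = node.setdefault(part, {})  # create intermediate levels
        node[path[-1]] = value
    return tree


print(nested_fromkeys(["a", "b", ("nested", "e")], 0))
# {'a': 0, 'b': 0, 'nested': {'e': 0}}
```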
- gather(dim: int, index: Tensor, out: Optional[T] = None) T ¶
Gathers values along an axis specified by dim.
- Parameters:
dim (int) – the dimension along which to collect the elements.
index (torch.Tensor) – a long tensor whose number of dimensions matches that of the tensordict, with only one dimension differing between the two (the gathering dimension). Its elements refer to the index to be gathered along the required dimension.
out (TensorDictBase, optional) – a destination tensordict. It must have the same shape as the index.
Examples
>>> td = TensorDict( ... {"a": torch.randn(3, 4, 5), ... "b": TensorDict({"c": torch.zeros(3, 4, 5)}, [3, 4, 5])}, ... [3, 4]) >>> index = torch.randint(4, (3, 2)) >>> td_gather = td.gather(dim=1, index=index) >>> print(td_gather) TensorDict( fields={ a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3, 2, 5]), device=None, is_shared=False)}, batch_size=torch.Size([3, 2]), device=None, is_shared=False)
Gather keeps the dimension names.
Examples
>>> td.names = ["a", "b"] >>> td_gather = td.gather(dim=1, index=index) >>> td_gather.names ["a", "b"]
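For a 2-D batch, gather(dim=1, index) picks, for every row `i`, the entries `index[i][j]` along the second dimension, so the output takes the index's shape. The sketch below spells this out on nested lists; `gather_dim1` is a hypothetical helper, and real leaves would additionally carry their feature dimensions along unchanged.

```python
# What gather(dim=1, index) does to each leaf, written out for 2-D nested lists:
# out[i][j] = x[i][index[i][j]]. Features beyond the batch dims (absent here)
# would be carried along unchanged. `gather_dim1` is a hypothetical helper.

def gather_dim1(x, index):
    return [[row[j] for j in idx_row] for row, idx_row in zip(x, index)]


x = [[10, 11, 12, 13],
     [20, 21, 22, 23],
     [30, 31, 32, 33]]            # batch shape (3, 4)
index = [[0, 3], [2, 2], [1, 0]]  # shape (3, 2): gathers 2 of 4 along dim 1
print(gather_dim1(x, index))      # [[10, 13], [22, 22], [31, 30]]
```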
- gather_and_stack(dst: int, group: 'dist.ProcessGroup' | None = None) T | None ¶
Gathers tensordicts from various workers and stacks them onto self in the destination worker.
- Parameters:
dst (int) – the rank of the destination worker where gather_and_stack() will be called.
group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.
Example
>>> from torch import multiprocessing as mp >>> from tensordict import TensorDict >>> import torch >>> >>> def client(): ... torch.distributed.init_process_group( ... "gloo", ... rank=1, ... world_size=2, ... init_method=f"tcp://localhost:10003", ... ) ... # Create a single tensordict to be sent to server ... td = TensorDict( ... {("a", "b"): torch.randn(2), ... "c": torch.randn(2)}, [2] ... ) ... td.gather_and_stack(0) ... >>> def server(): ... torch.distributed.init_process_group( ... "gloo", ... rank=0, ... world_size=2, ... init_method=f"tcp://localhost:10003", ... ) ... # Creates the destination tensordict on server. ... # The first dim must be equal to world_size-1 ... td = TensorDict( ... {("a", "b"): torch.zeros(2), ... "c": torch.zeros(2)}, [2] ... ).expand(1, 2).contiguous() ... td.gather_and_stack(0) ... assert (td["a", "b"] != 0).all() ... print("yuppie") ... >>> if __name__ == "__main__": ... mp.set_start_method("spawn") ... ... main_worker = mp.Process(target=server) ... secondary_worker = mp.Process(target=client) ... ... main_worker.start() ... secondary_worker.start() ... ... main_worker.join() ... secondary_worker.join()
- get(key: NestedKey, *args, **kwargs) Tensor ¶
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. A tuple of str is equivalent to chained calls of get(), e.g. get(("a", "b")) behaves like get("a").get("b").
default –
default value if the key is not found in the tensordict.
Warning
Currently, if a key is not present in the tensordict and no default is passed, a KeyError is raised. From v0.7, this behaviour will be changed and a None value will be returned instead. To adopt the new behaviour, set the environment variable with export TD_GET_DEFAULTS_TO_NONE='1' or call tensordict.set_get_defaults_to_none(True).
Examples
>>> td = TensorDict({"x": 1}, batch_size=[]) >>> td.get("x") tensor(1) >>> set_get_defaults_to_none(False) # Current default behaviour >>> td.get("y") # Raises KeyError >>> set_get_defaults_to_none(True) >>> td.get("y") None
- get_at(key: NestedKey, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], default: Tensor = _NoDefault.ZERO) Tensor ¶
Gets the value stored under key at the given index.
- Parameters:
key (str, tuple of str) – key to be retrieved.
index (int, slice, torch.Tensor, iterable) – index of the tensor.
default (torch.Tensor) – default value to return if the key is not present in the tensordict.
- Returns:
indexed tensor.
Examples
>>> td = TensorDict({"x": torch.arange(3)}, batch_size=[]) >>> td.get_at("x", index=1) tensor(1)
- get_item_shape(key: NestedKey)¶
Returns the shape of the entry, possibly without resorting to get().
- get_non_tensor(key: NestedKey, default=_NoDefault.ZERO)¶
Gets a non-tensor value, if it exists, or default if the non-tensor value is not found.
This method is robust to tensor/TensorDict values, meaning that if the value gathered is a regular tensor it will be returned too (although this method comes with some overhead and should not be used out of its natural scope).
See set_non_tensor() for more information on how to set non-tensor values in a tensordict.
- Parameters:
key (NestedKey) – the location of the NonTensorData object.
default (Any, optional) – the value to be returned if the key cannot be found.
- Returns: the content of the tensordict.tensorclass.NonTensorData, or the entry corresponding to the key if it isn’t a tensordict.tensorclass.NonTensorData (or default if the entry cannot be found).
Examples
>>> data = TensorDict({}, batch_size=[]) >>> data.set_non_tensor(("nested", "the string"), "a string!") >>> assert data.get_non_tensor(("nested", "the string")) == "a string!" >>> # regular `get` works but returns a NonTensorData object >>> data.get(("nested", "the string")) NonTensorData( data='a string!', batch_size=torch.Size([]), device=None, is_shared=False)
- property grad¶
Returns a tensordict containing the .grad attributes of the leaf tensors.
- half()¶
Casts all tensors to
torch.half
.
- int()¶
Casts all tensors to
torch.int
.
- int16()¶
Casts all tensors to
torch.int16
.
- int32()¶
Casts all tensors to
torch.int32
.
- int64()¶
Casts all tensors to
torch.int64
.
- int8()¶
Casts all tensors to
torch.int8
.
- irecv(src: int, *, group: 'dist.ProcessGroup' | None = None, return_premature: bool = False, init_tag: int = 0, pseudo_rand: bool = False) tuple[int, list[torch.Future]] | list[torch.Future] | None ¶
Receives the content of a tensordict and updates content with it asynchronously.
Check the example in the isend() method for context.
- Parameters:
src (int) – the rank of the source worker.
- Keyword Arguments:
group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.
return_premature (bool) – if True, returns a list of futures to wait upon until the tensordict is updated. Defaults to False, i.e. waits until the update is completed within the call.
init_tag (int) – the init_tag used by the source worker.
pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing to send multiple data from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. This value must match the one passed to isend(). Defaults to False.
- Returns:
- if return_premature=True, a list of futures to wait upon until the tensordict is updated.
- if
- is_consolidated()¶
Checks if a TensorDict has a consolidated storage.
- is_empty()¶
Checks whether the tensordict contains no leaf (i.e., returns True if it is empty).
- is_memmap() bool ¶
Checks if tensordict is memory-mapped.
If a TensorDict instance is memory-mapped, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all memory-mapped, this does not mean that is_memmap will return True (as a new tensor may or may not be memory-mapped). Only if one calls tensordict.memmap_() will the tensordict be considered as memory-mapped.
This is always True for tensordicts on a CUDA device.
- is_shared() bool ¶
Checks if tensordict is in shared memory.
If a TensorDict instance is in shared memory, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all in shared memory, this does not mean that is_shared will return True (as a new tensor may or may not be in shared memory). Only if one calls tensordict.share_memory_() or places the tensordict on a device where the content is shared by default (e.g., "cuda") will the tensordict be considered in shared memory.
This is always True for tensordicts on a CUDA device.
- isend(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int ¶
Sends the content of the tensordict asynchronously.
- Parameters:
dst (int) – the rank of the destination worker where the content should be sent.
- Keyword Arguments:
group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.
init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.
pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing to send multiple data from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. Defaults to False.
Example
>>> import torch >>> from tensordict import TensorDict >>> from torch import multiprocessing as mp >>> def client(): ... torch.distributed.init_process_group( ... "gloo", ... rank=1, ... world_size=2, ... init_method=f"tcp://localhost:10003", ... ) ... ... td = TensorDict( ... { ... ("a", "b"): torch.randn(2), ... "c": torch.randn(2, 3), ... "_": torch.ones(2, 1, 5), ... }, ... [2], ... ) ... td.isend(0) ... >>> >>> def server(queue, return_premature=True): ... torch.distributed.init_process_group( ... "gloo", ... rank=0, ... world_size=2, ... init_method=f"tcp://localhost:10003", ... ) ... td = TensorDict( ... { ... ("a", "b"): torch.zeros(2), ... "c": torch.zeros(2, 3), ... "_": torch.zeros(2, 1, 5), ... }, ... [2], ... ) ... out = td.irecv(1, return_premature=return_premature) ... if return_premature: ... for fut in out: ... fut.wait() ... assert (td != 0).all() ... queue.put("yuppie") ... >>> >>> if __name__ == "__main__": ... queue = mp.Queue(1) ... main_worker = mp.Process( ... target=server, ... args=(queue, ) ... ) ... secondary_worker = mp.Process(target=client) ... ... main_worker.start() ... secondary_worker.start() ... out = queue.get(timeout=10) ... assert out == "yuppie" ... main_worker.join() ... secondary_worker.join()
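Why must init_tag (and pseudo_rand) match on both sides? Each tensor travels with its own tag, starting at init_tag and incremented once per tensor, so the receiver can only pair messages correctly if it derives the same tags. The sketch below assumes a simple sequential scheme over the leaves in a fixed order; `assign_tags` is hypothetical, and tensordict's actual leaf ordering and tag generation may differ.

```python
# Illustrative sketch of tag bookkeeping for isend()/irecv(): each tensor of the
# tensordict is sent with its own tag, starting at init_tag, so sender and
# receiver must agree on init_tag (and on the order of the leaves).
# `assign_tags` is hypothetical; tensordict's actual ordering may differ.

def assign_tags(leaf_keys, init_tag=0):
    tags = {}
    tag = init_tag
    for key in leaf_keys:
        tags[key] = tag
        tag += 1  # one tag consumed per tensor
    return tags, tag  # `tag` is the next unused tag


leaves = [("a", "b"), "c", "_"]
sender, next_tag = assign_tags(leaves, init_tag=0)
receiver, _ = assign_tags(leaves, init_tag=0)
print(sender == receiver, next_tag)  # True 3
```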
- isfinite() T ¶
Returns a new tensordict with boolean elements representing if each element is finite or not.
Real values are finite when they are not NaN, negative infinity, or infinity. Complex values are finite when both their real and imaginary parts are finite.
- isnan() T ¶
Returns a new tensordict with boolean elements representing if each element of input is NaN or not.
Complex values are considered NaN when either their real and/or imaginary part is NaN.
- isneginf() T ¶
Tests if each element of input is negative infinity or not.
- isposinf() T ¶
Tests if each element of input is positive infinity or not.
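The element-wise semantics of isfinite(), isnan(), isposinf() and isneginf() match those of scalar checks available in Python's standard math and cmath modules. The snippet below uses plain Python numbers (not tensordict) to show the expected result for each special value, including the complex-number rules stated above.

```python
# Element-wise semantics of isfinite(), isnan(), isposinf() and isneginf(),
# illustrated on plain Python numbers with the standard math/cmath modules.
import cmath
import math

values = [1.0, math.inf, -math.inf, math.nan]
finite = [math.isfinite(v) for v in values]  # [True, False, False, False]
nan = [math.isnan(v) for v in values]        # [False, False, False, True]
posinf = [v == math.inf for v in values]     # [False, True, False, False]
neginf = [v == -math.inf for v in values]    # [False, False, True, False]
print(finite, nan, posinf, neginf)

# Complex values: finite only if both parts are finite,
# NaN if either part is NaN.
print(cmath.isfinite(complex(1.0, math.inf)))  # False
print(cmath.isnan(complex(math.nan, 0.0)))     # True
```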
- isreal() T ¶
Returns a new tensordict with boolean elements representing if each element of input is real-valued or not.
- items(include_nested: bool = False, leaves_only: bool = False, is_leaf: Optional[Callable[[Type], bool]] = None, *, sort: bool = False) Iterator[tuple[str, torch.Tensor]] ¶
Returns a generator of key-value pairs for the tensordict.
- Parameters:
- Keyword Arguments:
sort (bool, optional) – whether the keys should be sorted. For nested keys, the keys are sorted according to their joined name (i.e., ("a", "key") will be counted as "a.key" for sorting). Be mindful that sorting may incur significant overhead when dealing with large tensordicts. Defaults to False.
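The documented ordering for sort=True can be reproduced with an ordinary sort key that joins tuple keys with ".". `sort_key` is a hypothetical helper illustrating the rule stated above; it is not the function tensordict uses internally.

```python
# Sketch of the documented sort order: nested keys are compared by their
# joined name, e.g. ("a", "key") sorts as "a.key". `sort_key` is hypothetical.

def sort_key(key):
    return ".".join(key) if isinstance(key, tuple) else key


keys = [("b", "a"), "c", ("a", "key"), "a"]
print(sorted(keys, key=sort_key))
# ['a', ('a', 'key'), ('b', 'a'), 'c']
```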
- keys(include_nested: bool = False, leaves_only: bool = False, is_leaf: Optional[Callable[[Type], bool]] = None, *, sort: bool = False) _TensorDictKeysView ¶
Returns a generator of tensordict keys.
Warning
TensorDict keys() method returns a lazy view of the keys. If the keys are queried but not iterated over and then the tensordict is modified, iterating over the keys later will return the new configuration of the keys.
- Parameters:
- Keyword Arguments:
sort (bool, optional) – whether the keys should be sorted. For nested keys, the keys are sorted according to their joined name (i.e., ("a", "key") will be counted as "a.key" for sorting). Be mindful that sorting may incur significant overhead when dealing with large tensordicts. Defaults to False.
Examples
>>> from tensordict import TensorDict >>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[]) >>> list(data.keys()) ['0', '1'] >>> list(data.keys(leaves_only=True)) ['0'] >>> list(data.keys(include_nested=True, leaves_only=True)) ['0', ('1', '2')]
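Python's built-in dict.keys() view is lazy in the same sense, which makes it a convenient way to picture the warning for keys(): a view obtained before a modification reflects the new keys once it is iterated. The snippet uses a plain dict, not a TensorDict; materializing the view with list() is the usual way to freeze the current key set.

```python
# Python's built-in dict.keys() view is also lazy: a view created before a
# modification reflects the new keys when it is eventually iterated.
# (Plain dict used for illustration, not TensorDict.)
d = {"a": 0}
view = d.keys()        # queried, but not iterated yet
d["b"] = 1             # the mapping is modified afterwards
print(list(view))      # ['a', 'b']

snapshot = list(d.keys())  # materialize to freeze the current key set
d["c"] = 2
print(snapshot)        # ['a', 'b']
```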
- classmethod lazy_stack(input, dim=0, *, out=None, **kwargs)¶
Creates a lazy stack of tensordicts.
See
lazy_stack()
for details.
- lerp(end: tensordict.base.TensorDictBase | torch.Tensor, weight: tensordict.base.TensorDictBase | torch.Tensor | float)¶
Does a linear interpolation of two tensors start (given by self) and end based on a scalar or tensor weight.
\[\text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i)\]
The shapes of start and end must be broadcastable. If weight is a tensor, then the shapes of weight, start, and end must be broadcastable.
- Parameters:
end (TensorDict) – the tensordict with the ending points.
weight (TensorDict, tensor or float) – the weight for the interpolation formula.
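A worked instance of the interpolation formula helps check intuition: with weight 0.5 each output sits halfway between start and end, while per-element weights of 0 and 1 return start and end respectively. The sketch applies the formula to plain Python lists; `lerp` here is a local helper, not the TensorDict method, and it only broadcasts a scalar weight.

```python
# Worked example of the lerp formula out_i = start_i + weight_i * (end_i - start_i)
# on plain lists; a scalar weight is broadcast to every element.

def lerp(start, end, weight):
    weights = weight if isinstance(weight, list) else [weight] * len(start)
    return [s + w * (e - s) for s, e, w in zip(start, end, weights)]


print(lerp([0.0, 10.0], [10.0, 20.0], 0.5))         # [5.0, 15.0]
print(lerp([0.0, 10.0], [10.0, 20.0], [0.0, 1.0]))  # [0.0, 20.0]
```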
- lerp_(end: tensordict.base.TensorDictBase | float, weight: tensordict.base.TensorDictBase | float)¶
In-place version of
lerp()
.
- lgamma() T ¶
Computes the
lgamma()
value of each element of the TensorDict.
- lgamma_() T ¶
Computes the
lgamma()
value of each element of the TensorDict in-place.
- classmethod load(prefix: str | pathlib.Path, *args, **kwargs) T ¶
Loads a tensordict from disk.
This class method is a proxy to
load_memmap()
.
- load_(prefix: str | pathlib.Path, *args, **kwargs)¶
Loads a tensordict from disk within the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None) T ¶
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> import torch
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_memmap_(prefix: str | pathlib.Path)¶
Loads the content of a memory-mapped tensordict within the tensordict where load_memmap_ is called.
See load_memmap() for more info.
- load_state_dict(state_dict: OrderedDict[str, Any], strict=True, assign=False, from_flatten=False) T ¶
Loads a state-dict, formatted as in state_dict(), into the tensordict.
- Parameters:
state_dict (OrderedDict) – the state dict to be copied.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this tensordict’s state_dict() method. Default: True
assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the tensordict instead of copying them in place into the tensordict’s current tensors. When False, the properties of the tensors in the current tensordict are preserved; when True, the properties of the tensors in the state dict are preserved. Default: False
from_flatten (bool, optional) – if True, the input state_dict is assumed to be flattened. Defaults to False.
Examples
>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> sd = data.state_dict()
>>> data_zeroed.load_state_dict(sd)
>>> print(data_zeroed["3", "3"])
tensor(3)
>>> # with flattening
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True)
>>> print(data_zeroed["3", "3"])
tensor(3)
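The from_flatten option expects a state dict whose nested keys have been collapsed into single string keys. The sketch below shows that round trip on plain dicts; the "." separator and the flatten / unflatten helpers are assumptions made for illustration, so inspect the output of state_dict(flatten=True) for the exact key format.

```python
# Hypothetical sketch: collapsing nested keys into flat string keys and back.
# The "." separator is an assumption for illustration only.
def flatten(nested, sep=".", prefix=""):
    flat = {}
    for key, value in nested.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, sep, full_key))
        else:
            flat[full_key] = value
    return flat


def unflatten(flat, sep="."):
    nested = {}
    for full_key, value in flat.items():
        *parents, leaf = full_key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


sd = {"1": 1, "2": 2, "3": {"3": 3}}
flat = flatten(sd)
print(flat)                   # {'1': 1, '2': 2, '3.3': 3}
assert unflatten(flat) == sd  # round-trips back to the nested layout
```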
- lock_() T ¶
Locks a tensordict for non-in-place operations.
Functions such as set(), __setitem__(), update(), rename_key_() or other operations that add or remove entries will be blocked. In-place writes to existing entries remain allowed.
This method can also be used as a context manager.
Example
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 1, "b": 2, "c": 3}, batch_size=[])
>>> with td.lock_():
...     assert td.is_locked
...     try:
...         td.set("d", 0)  # error!
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         del td["d"]
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         td.rename_key_("a", "d")
...     except RuntimeError:
...         print("td is locked!")
...     td.set("a", 0, inplace=True)  # No storage is added, moved or removed
...     td.set_("a", 0)  # No storage is added, moved or removed
...     td.update({"a": 0}, inplace=True)  # No storage is added, moved or removed
...     td.update_({"a": 0})  # No storage is added, moved or removed
>>> assert not td.is_locked
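The locking rule (entries cannot be added or removed, but existing entries can be overwritten in place) can be mimicked with a toy container. LockableDict and its locked() context manager are hypothetical and only guard __setitem__ and __delitem__; they illustrate the semantics and do not reproduce lock_().

```python
from contextlib import contextmanager


class LockableDict(dict):
    """Hypothetical stand-in for a lockable container (not tensordict code).

    While locked, adding or removing keys raises RuntimeError, but
    overwriting the value of an existing key is still permitted.
    """

    is_locked = False

    def __setitem__(self, key, value):
        if self.is_locked and key not in self:
            raise RuntimeError("Cannot add a new entry to a locked container.")
        super().__setitem__(key, value)

    def __delitem__(self, key):
        if self.is_locked:
            raise RuntimeError("Cannot remove an entry from a locked container.")
        super().__delitem__(key)

    @contextmanager
    def locked(self):
        # Lock on entry and restore the previous state on exit.
        previous = self.is_locked
        self.is_locked = True
        try:
            yield self
        finally:
            self.is_locked = previous


d = LockableDict(a=1)
with d.locked():
    d["a"] = 0  # allowed: existing key, nothing added
    try:
        d["b"] = 2  # blocked: would add an entry
    except RuntimeError as err:
        print(err)  # Cannot add a new entry to a locked container.
print(d.is_locked, dict(d))  # False {'a': 0}
```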
- log() T ¶
Computes the
log()
value of each element of the TensorDict.
- log10() T ¶
Computes the
log10()
value of each element of the TensorDict.
- log10_() T ¶
Computes the
log10()
value of each element of the TensorDict in-place.
- log1p() T ¶
Computes the
log1p()
value of each element of the TensorDict.
- log1p_() T ¶
Computes the
log1p()
value of each element of the TensorDict in-place.
- log2() T ¶
Computes the
log2()
value of each element of the TensorDict.
- log2_() T ¶
Computes the
log2()
value of each element of the TensorDict in-place.
- log_() T ¶
Computes the
log()
value of each element of the TensorDict in-place.
- make_memmap(key: NestedKey, shape: torch.Size | torch.Tensor, *, dtype: Optional[dtype] = None) MemoryMappedTensor ¶
Creates an empty memory-mapped tensor given a shape and possibly a dtype.
Warning
This method is not lock-safe by design. A memory-mapped TensorDict instance present on multiple nodes will need to be updated using the method memmap_refresh_().
Writing an existing entry will result in an error.
- Parameters:
key (NestedKey) – the key of the new entry to write. If the key is already present in the tensordict, an exception is raised.
shape (torch.Size or equivalent, torch.Tensor for nested tensors) – the shape of the tensor to write.
- Keyword Arguments:
dtype (torch.dtype, optional) – the dtype of the new tensor.
- Returns:
A new memory mapped tensor.
- make_memmap_from_storage(key: NestedKey, storage: UntypedStorage, shape: torch.Size | torch.Tensor, *, dtype: Optional[dtype] = None) MemoryMappedTensor ¶
Creates an empty memory-mapped tensor given a storage, a shape and possibly a dtype.
Warning
This method is not lock-safe by design. A memory-mapped TensorDict instance present on multiple nodes will need to be updated using the method memmap_refresh_().
Note
If the storage has an associated filename, it must match the new filename for the file. If it has no associated filename but the tensordict has an associated path, this will result in an exception.
- Parameters:
key (NestedKey) – the key of the new entry to write. If the key is already present in the tensordict, an exception is raised.
storage (torch.UntypedStorage) – the storage to use for the new MemoryMappedTensor. Must be a physical memory storage.
shape (torch.Size or equivalent, torch.Tensor for nested tensors) – the shape of the tensor to write.
- Keyword Arguments:
dtype (torch.dtype, optional) – the dtype of the new tensor.
- Returns:
A new memory mapped tensor with the given storage.
- make_memmap_from_tensor(key: NestedKey, tensor: Tensor, *, copy_data: bool = True, existsok: bool = True) MemoryMappedTensor ¶
Creates a memory-mapped tensor given a tensor, optionally copying its content.
Warning
This method is not lock-safe by design. A memory-mapped TensorDict instance present on multiple nodes will need to be updated using the method memmap_refresh_().
This method always copies the storage content if copy_data is True (i.e., the storage is not shared).
- Parameters:
key (NestedKey) – the key of the new entry to write. If the key is already present in the tensordict, an exception is raised.
tensor (torch.Tensor) – the tensor to replicate on physical memory.
- Keyword Arguments:
copy_data (bool, optional) – if False, the new tensor will share the metadata of the input such as shape and dtype, but the content will be empty. Defaults to True.
- Returns:
A new memory mapped tensor with the given storage.
- map(fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, out: TensorDictBase | None = None, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = False, pbar: bool = False, mp_start_method: str | None = None)¶
Maps a function to splits of the tensordict across one dimension.
This method will apply a function to a tensordict instance by chunking it in tensordicts of equal size and dispatching the operations over the desired number of workers.
The function signature should be Callable[[TensorDict], Union[TensorDict, Tensor]]. The output must support the torch.cat() operation. The function must be serializable.
Note
This method is particularly useful when working with large datasets stored on disk (e.g. memory-mapped tensordicts), where chunks will be zero-copy slices of the original data that can be passed to the processes at virtually zero cost. This allows very large datasets (e.g., over 1 TB) to be processed at little cost.
- Parameters:
fn (callable) – function to apply to the tensordict. Signatures similar to Callable[[TensorDict], Union[TensorDict, Tensor]] are supported.
dim (int, optional) – the dim along which the tensordict will be chunked.
num_workers (int, optional) – the number of workers. Exclusive with
pool
. If none is provided, the number of workers will be set to the number of cpus available.
- Keyword Arguments:
out (TensorDictBase, optional) – an optional container for the output. Its batch-size along the dim provided must match self.ndim. If it is shared or memmap (is_shared() or is_memmap() returns True) it will be populated within the remote processes, avoiding data inward transfers. Otherwise, the data from the self slice will be sent to the process, collected on the current process and written in place into out.
chunksize (int, optional) – the size of each chunk of data. A chunksize of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, whereas chunksize>0 will split the tensordict and call torch.cat() on the resulting list of tensordicts. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with num_chunks.
num_chunks (int, optional) – the number of chunks to split the tensordict into. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with chunksize.
pool (mp.Pool, optional) – a multiprocess Pool instance to use to execute the job. If none is provided, a pool will be created within the map method.
generator (torch.Generator, optional) –
a generator to use for seeding. A base seed will be generated from it, and each worker of the pool will be seeded with the provided seed incremented by a unique integer from 0 to num_workers - 1. If no generator is provided, a random integer will be used as seed. To work with unseeded workers, a pool should be created separately and passed to map() directly.
Note
Caution should be taken when providing a low-valued seed, as this can cause autocorrelation between experiments. For example, if 8 workers are requested and the seed is 4, the worker seeds will range from 4 to 11. If the seed is 5, the worker seeds will range from 5 to 12. These two experiments will have an overlap of 7 seeds, which can have unexpected effects on the results.
Note
The goal of seeding the workers is to have an independent seed on each worker, and NOT to have reproducible results across calls of the map method. In other words, two experiments may, and probably will, return different results, as it is impossible to know which worker will pick which job. However, we can make sure that each worker has a different seed and that the pseudo-random operations on each will be uncorrelated.
max_tasks_per_child (int, optional) – the maximum number of jobs picked by every child process. Defaults to None, i.e., no restriction on the number of jobs.
worker_threads (int, optional) – the number of threads for the workers. Defaults to 1.
index_with_generator (bool, optional) – if True, the splitting / chunking of the tensordict will be done during the query, sparing init time. Note that chunk() and split() are much more efficient than indexing (which is used within the generator), so a gain of processing time at init time may have a negative impact on the total runtime. Defaults to False.
pbar (bool, optional) – if True, a progress bar will be displayed. Requires tqdm to be available. Defaults to False.
mp_start_method (str, optional) – the start method for multiprocessing. If not provided, the default start method will be used. Accepted strings are "fork" and "spawn". Keep in mind that "cuda" tensors cannot be shared between processes with the "fork" start method. This has no effect if the pool is passed to the map method.
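To make the chunksize / num_chunks options concrete, the hypothetical helper below computes chunk lengths for a dimension of size n. The fixed-size split for chunksize>0 and the unbind for chunksize=0 follow the description above; spreading elements as evenly as possible when only num_chunks (or the worker count) is known is an assumption, and the library's exact chunk sizes may differ.

```python
# Hypothetical helper: how a dimension of length n could be partitioned
# under the chunksize / num_chunks options. The even split used for
# num_chunks is an assumption, not tensordict's exact rule.
def chunk_lengths(n, *, chunksize=None, num_chunks=None, num_workers=4):
    if chunksize is not None and num_chunks is not None:
        raise ValueError("chunksize and num_chunks are mutually exclusive")
    if chunksize == 0:
        # chunksize=0: unbind, i.e. one element per job (restacked afterwards).
        return [1] * n
    if chunksize is not None:
        # chunksize>0: fixed-size pieces, the last one holding the remainder.
        full, rest = divmod(n, chunksize)
        return [chunksize] * full + ([rest] if rest else [])
    if num_chunks is None:
        # Neither option given: as many chunks as workers.
        num_chunks = num_workers
    base, extra = divmod(n, num_chunks)
    return [base + 1] * extra + [base] * (num_chunks - extra)


print(chunk_lengths(10, chunksize=4))   # [4, 4, 2]
print(chunk_lengths(10, num_chunks=3))  # [4, 3, 3]
print(chunk_lengths(10))                # [3, 3, 2, 2]
print(chunk_lengths(3, chunksize=0))    # [1, 1, 1]
```

Smaller chunks lower the peak memory each worker needs, at the price of more jobs to dispatch and concatenate.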
Examples
>>> import torch
>>> from tensordict import TensorDict
>>>
>>> def process_data(data):
...     data.set("y", data.get("x") + 1)
...     return data
>>> if __name__ == "__main__":
...     data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_()
...     data = data.map(process_data, dim=1)
...     print(data["y"][:, :10])
...
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
- map_iter(fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, shuffle: bool = False, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = True, pbar: bool = False, mp_start_method: str | None = None)¶
Maps a function to splits of the tensordict across one dimension iteratively.
This is the iterable version of map().
This method will apply a function to a tensordict instance by chunking it in tensordicts of equal size and dispatching the operations over the desired number of workers. It will yield the results one at a time.
The function signature should be Callable[[TensorDict], Union[TensorDict, Tensor]]. The function must be serializable.
Note
This method is particularly useful when working with large datasets stored on disk (e.g. memory-mapped tensordicts), where chunks will be zero-copy slices of the original data that can be passed to the processes at virtually zero cost. This allows very large datasets (e.g., over 1 TB) to be processed at little cost.
Note
This function can be used to represent a dataset and load from it, in a dataloader-like fashion.
- Parameters:
fn (callable) – function to apply to the tensordict. Signatures similar to Callable[[TensorDict], Union[TensorDict, Tensor]] are supported.
dim (int, optional) – the dim along which the tensordict will be chunked.
num_workers (int, optional) – the number of workers. Exclusive with
pool
. If none is provided, the number of workers will be set to the number of cpus available.
- Keyword Arguments:
shuffle (bool, optional) – whether the indices should be globally shuffled. If True, each batch will contain non-contiguous samples. If index_with_generator=False and shuffle=True, an error will be raised. Defaults to False.
chunksize (int, optional) – the size of each chunk of data. A chunksize of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, whereas chunksize>0 will split the tensordict and call torch.cat() on the resulting list of tensordicts. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with num_chunks.
num_chunks (int, optional) – the number of chunks to split the tensordict into. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with chunksize.
pool (mp.Pool, optional) – a multiprocess Pool instance to use to execute the job. If none is provided, a pool will be created within the map_iter method.
generator (torch.Generator, optional) –
a generator to use for seeding. A base seed will be generated from it, and each worker of the pool will be seeded with the provided seed incremented by a unique integer from 0 to num_workers - 1. If no generator is provided, a random integer will be used as seed. To work with unseeded workers, a pool should be created separately and passed to map_iter() directly.
Note
Caution should be taken when providing a low-valued seed, as this can cause autocorrelation between experiments. For example, if 8 workers are requested and the seed is 4, the worker seeds will range from 4 to 11. If the seed is 5, the worker seeds will range from 5 to 12. These two experiments will have an overlap of 7 seeds, which can have unexpected effects on the results.
Note
The goal of seeding the workers is to have an independent seed on each worker, and NOT to have reproducible results across calls of the map method. In other words, two experiments may, and probably will, return different results, as it is impossible to know which worker will pick which job. However, we can make sure that each worker has a different seed and that the pseudo-random operations on each will be uncorrelated.
max_tasks_per_child (int, optional) – the maximum number of jobs picked by every child process. Defaults to None, i.e., no restriction on the number of jobs.
worker_threads (int, optional) – the number of threads for the workers. Defaults to 1.
index_with_generator (bool, optional) –
if True, the splitting / chunking of the tensordict will be done during the query, sparing init time. Note that chunk() and split() are much more efficient than indexing (which is used within the generator), so a gain of processing time at init time may have a negative impact on the total runtime. Defaults to True.
Note
The default value of index_with_generator differs for map_iter and map: the former assumes that it is prohibitively expensive to store a split version of the TensorDict in memory.
pbar (bool, optional) – if True, a progress bar will be displayed. Requires tqdm to be available. Defaults to False.
mp_start_method (str, optional) – the start method for multiprocessing. If not provided, the default start method will be used. Accepted strings are "fork" and "spawn". Keep in mind that "cuda" tensors cannot be shared between processes with the "fork" start method. This has no effect if the pool is passed to the map_iter method.
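The seed-overlap caveat in the notes above is plain arithmetic. This sketch assumes, as described, that worker k receives the base seed plus k; the worker_seeds helper is hypothetical.

```python
# Hypothetical helper: worker k receives base_seed + k, for k in 0 .. num_workers - 1.
def worker_seeds(base_seed, num_workers):
    return list(range(base_seed, base_seed + num_workers))


run_a = worker_seeds(4, 8)  # seeds 4 .. 11
run_b = worker_seeds(5, 8)  # seeds 5 .. 12
shared = sorted(set(run_a) & set(run_b))
print(len(shared), shared)  # 7 [5, 6, 7, 8, 9, 10, 11]

# Spacing base seeds at least num_workers apart avoids any overlap.
print(set(worker_seeds(4, 8)) & set(worker_seeds(12, 8)))  # set()
```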
Examples
>>> import torch
>>> from tensordict import TensorDict
>>>
>>> def process_data(data):
...     data.unlock_()
...     data.set("y", data.get("x") + 1)
...     return data
>>> if __name__ == "__main__":
...     data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_()
...     for sample in data.map_iter(process_data, dim=1, chunksize=5):
...         print(sample["y"])
...         break
...
tensor([[1., 1., 1., 1., 1.]])
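With shuffle=True, indices are permuted globally before being grouped, so every chunk contains non-contiguous samples. A minimal sketch of that grouping with the standard random module follows; shuffled_chunks is hypothetical and stands in for whatever permutation tensordict uses internally, which may differ.

```python
import random


# Hypothetical sketch of shuffle=True: permute all indices of the mapped dim
# once, then cut the permutation into chunks of `chunksize`.
def shuffled_chunks(n, chunksize, seed):
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return [indices[i:i + chunksize] for i in range(0, n, chunksize)]


chunks = shuffled_chunks(12, chunksize=5, seed=0)
for chunk in chunks:
    print(chunk)
```

Because each chunk is an arbitrary set of indices rather than a contiguous slice, it has to be gathered by indexing, which is consistent with shuffle=True requiring index_with_generator=True.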
- masked_fill(mask: Tensor, value: float | bool) T ¶
Out-of-place version of masked_fill_().
- Parameters:
mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.
value – value used to fill the tensors.
- Returns:
a new TensorDict with the masked entries filled; self is left unchanged.
Examples
>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td1 = td.masked_fill(mask, 1.0)
>>> td1.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
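The batch-dimension semantics of masked_fill can be illustrated without torch: the mask has one entry per batch element and applies to every trailing element of that row. Nested lists stand in for the (3, 4) tensor of the example, and masked_fill_rows is a hypothetical helper.

```python
# Hypothetical pure-Python sketch of an out-of-place masked fill over the
# batch dimension: a True mask entry fills the whole corresponding row.
def masked_fill_rows(rows, mask, value):
    # Build new rows so that the input is left untouched.
    return [[value] * len(row) if m else list(row) for row, m in zip(rows, mask)]


a = [[0.0] * 4 for _ in range(3)]
filled = masked_fill_rows(a, [True, False, False], 1.0)
print(filled)  # [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
print(a[0])    # [0.0, 0.0, 0.0, 0.0]  (original unchanged)
```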
- masked_fill_(mask: Tensor, value: float | int | bool) T ¶
Fills the values corresponding to the mask with the desired value.
- Parameters:
mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.
value – value used to fill the tensors.
- Returns:
self
Examples
>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> _ = td.masked_fill_(mask, 1.0)
>>> td.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
- masked_select(mask: Tensor) T ¶
Masks all tensors of the TensorDict and returns a new TensorDict instance with similar keys pointing to masked values.
- Parameters:
mask (torch.Tensor) – boolean mask to be used for the tensors. Shape must match the TensorDict
batch_size
.
Examples
>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td_mask = td.masked_select(mask)
>>> td_mask.get("a")
tensor([[0., 0., 0., 0.]])
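Because the mask selects along the batch dimension, the leading size of the result equals the number of True entries. The sketch below shows this with a dict of lists standing in for a tensordict; masked_select_rows is a hypothetical helper, not part of the library.

```python
# Hypothetical sketch of masked_select on a batch-dimension mask: only rows
# whose mask entry is True are kept in every entry, and the new batch size
# is the number of True entries.
def masked_select_rows(data, mask):
    selected = {key: [row for row, m in zip(rows, mask) if m] for key, rows in data.items()}
    batch_size = sum(bool(m) for m in mask)
    return selected, batch_size


data = {"a": [[0.0] * 4 for _ in range(3)], "b": [10, 20, 30]}
selected, batch_size = masked_select_rows(data, [True, False, True])
print(batch_size)     # 2
print(selected["b"])  # [10, 30]
```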
- maximum(other: tensordict.base.TensorDictBase | torch.Tensor, *, default: Optional[Union[str, Tensor]] = None) T ¶
Computes the element-wise maximum of self and other.
- Parameters:
other (TensorDict or Tensor) – the other input tensordict or tensor.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key sets of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
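The three behaviours of default can be sketched on flat dicts of numbers. The binary_op helper is hypothetical and uses Python's built-in max in place of a tensor operation; it illustrates only how keys are matched, not how tensors are broadcast.

```python
# Hypothetical sketch of the `default` rules for binary ops such as maximum().
def binary_op(left, right, op, default=None):
    if default is None:
        # No default: both sides must have exactly the same keys.
        if left.keys() != right.keys():
            raise KeyError("key sets differ and no default was given")
        keys = list(left)
    elif isinstance(default, str) and default == "intersection":
        # Only keys present on both sides are kept.
        keys = [k for k in left if k in right]
    else:
        # Union of keys; a missing entry on either side takes `default`.
        keys = list(left) + [k for k in right if k not in left]
    return {k: op(left.get(k, default), right.get(k, default)) for k in keys}


left = {"a": 1, "b": 5}
right = {"b": 3, "c": -2}
print(binary_op(left, right, max, default="intersection"))  # {'b': 5}
print(binary_op(left, right, max, default=0))               # {'a': 1, 'b': 5, 'c': 0}
try:
    binary_op(left, right, max)  # default=None: key sets must match
except KeyError as err:
    print("KeyError:", err)
```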
- maximum_(other: tensordict.base.TensorDictBase | torch.Tensor) T ¶
In-place version of maximum().
Note
In-place maximum does not support the default keyword argument.
- classmethod maybe_dense_stack(input, dim=0, *, out=None, **kwargs)¶
Attempts to make a dense stack of tensordicts, and falls back on a lazy stack when required.
See
maybe_dense_stack()
for details.
- mean(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: Optional[dtype] = None, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
Returns the mean value of all elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
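The difference between per-leaf means and a single reduction across all values matters when leaves have different sizes. This sketch assumes that reduce=True pools every element of every leaf before averaging (rather than averaging the per-leaf means); lists of floats stand in for tensors.

```python
from statistics import fmean

# Lists of floats stand in for the leaves of a tensordict.
data = {"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0]}

# One mean per leaf (no reduction across leaves).
per_leaf = {key: fmean(values) for key, values in data.items()}
# Assumed meaning of reduce=True: a single mean over all pooled elements.
pooled = fmean([x for values in data.values() for x in values])
# For contrast: averaging the per-leaf means gives a different number.
mean_of_means = fmean(per_leaf.values())

print(per_leaf)       # {'a': 2.0, 'b': 4.5}
print(pooled)         # 3.0
print(mean_of_means)  # 3.25
```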
- memmap(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T ¶
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If
True
, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T ¶
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If
True
, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: Optional[str] = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
Creates a contentless Memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If
True
, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
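The note above is easiest to appreciate with a quick size estimate for the example: memmap_like reserves disk space for every element of the expanded batch, even though the expanded source tensordict does not allocate new memory. The figures below ignore any per-file metadata overhead.

```python
# Back-of-the-envelope size of the buffer created in the example above.
batch = 1_000_000
bytes_a = 3 * 64 * 64 * 1  # "a": uint8, 1 byte per element
bytes_b = 1 * 8            # "b": int64, 8 bytes per element

total_bytes = batch * (bytes_a + bytes_b)
print(total_bytes)                  # 12296000000
print(round(total_bytes / 1e9, 1))  # 12.3  (GB)
```

Copying roughly 12 GB of zeros with memmap_() would be wasted work, since memmap_like only needs to reserve the files.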
- memmap_refresh_()¶
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- minimum(other: tensordict.base.TensorDictBase | torch.Tensor, *, default: Optional[Union[str, Tensor]] = None) T ¶
Computes the element-wise minimum of self and other.
- Parameters:
other (TensorDict or Tensor) – the other input tensordict or tensor.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key sets of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- minimum_(other: tensordict.base.TensorDictBase | torch.Tensor) T ¶
In-place version of minimum().
Note
In-place minimum does not support the default keyword argument.
- mul(other: tensordict.base.TensorDictBase | torch.Tensor, *, default: Optional[Union[str, Tensor]] = None) T ¶
Multiplies self by other.
\[\text{out}_i = \text{input}_i \times \text{other}_i\]
Supports broadcasting, type promotion, and integer, float, and complex inputs.
- Parameters:
other (TensorDict, Tensor or Number) – the tensor or number to multiply self by.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key sets of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- mul_(other: tensordict.base.TensorDictBase | torch.Tensor) T ¶
In-place version of mul().
Note
In-place mul does not support the default keyword argument.
- named_apply(fn: Callable, *others: T, nested_keys: bool = False, batch_size: Optional[Sequence[int]] = None, device: torch.device | None = _NoDefault.ZERO, names: Optional[Sequence[str]] = _NoDefault.ZERO, inplace: bool = False, default: Any = _NoDefault.ZERO, filter_empty: Optional[bool] = None, propagate_lock: bool = False, call_on_nested: bool = False, out: Optional[TensorDictBase] = None, **constructor_kwargs) Optional[T] ¶
Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new tensordict.
The callable signature must be Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].
- Parameters:
fn (Callable) – function to be applied to the (name, tensor) pairs in the tensordict. For each leaf, only its leaf name will be used (not the full NestedKey).
*others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching the one of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.
nested_keys (bool, optional) – if True, the complete path to the leaf will be used. Defaults to False, i.e. only the last string is passed to the function.
batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.
device (torch.device, optional) – the resulting device, if any.
names (list of str, optional) – the new dimension names, in case the batch_size is modified.
inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.
default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.
filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also lowers the computational cost, as empty data structures won’t be created and destroyed. Defaults to False for backward compatibility.
propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.
call_on_nested (bool, optional) –
if True, the function will be called on first-level tensors and containers (TensorDict or tensorclass). In this scenario, fn is responsible for propagating its calls to nested levels. This allows fine-grained control when propagating the calls to nested tensordicts. If False, the function will only be called on leaves, and named_apply will take care of dispatching the function to all leaves. The example below illustrates this option with apply(), which accepts the same argument:
>>> import torch
>>> from tensordict import TensorDict, is_tensor_collection
>>> td = TensorDict({"a": {"b": torch.tensor([0.0, 1.0])}, "c": torch.tensor([1.0, 2.0])})
>>> def mean_tensor_only(val):
...     if is_tensor_collection(val):
...         raise RuntimeError("Unexpected!")
...     return val.mean()
>>> td_mean = td.apply(mean_tensor_only)
>>> def mean_any(val):
...     if is_tensor_collection(val):
...         # Recurse
...         return val.apply(mean_any, call_on_nested=True)
...     return val.mean()
>>> td_mean = td.apply(mean_any, call_on_nested=True)
out (TensorDictBase, optional) –
a tensordict in which to write the results. This can be used to avoid creating a new tensordict:
>>> td = TensorDict({"a": 0})
>>> td.apply(lambda x: x+1, out=td)
>>> assert (td==1).all()
Warning
If the operation executed on the tensordict requires multiple keys to be accessed for a single computation, providing an out argument equal to self can cause the operation to silently produce wrong results. For instance:
>>> td = TensorDict({"a": 1, "b": 1})
>>> td.apply(lambda x: x+td["a"])["b"]  # Right!
tensor(2)
>>> td.apply(lambda x: x+td["a"], out=td)["b"]  # Wrong!
tensor(3)
**constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.
- Returns:
a new tensordict with the transformed tensors.
Example
>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "nested": {"a": torch.ones(3), "b": torch.zeros(3)}},
...     batch_size=[3])
>>> def name_filter(name, tensor):
...     if name == "a":
...         return tensor
>>> td.named_apply(name_filter)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> def name_filter(name, *tensors):
...     if name == "a":
...         r = 0
...         for tensor in tensors:
...             r = r + tensor
...         return r
>>> out = td.named_apply(name_filter, td)
>>> print(out)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> print(out["a"])
tensor([-2., -2., -2.])
Note
If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:
>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def name_filter(name, tensor):
...     if name == "1":
...         return tensor
>>> td.named_apply(name_filter)
TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- property names¶
The dimension names of the tensordict.
The names can be set at construction time using the names argument.
See also refine_names() for details on how to set the names after construction.
- nanmean(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: Optional[dtype] = None, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
Returns the mean of all non-NaN elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
- nansum(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: Optional[dtype] = None, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
Returns the sum of all non-NaN elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
- property ndim: int¶
See batch_dims().
- ndimension() int ¶
See batch_dims().
- neg() T ¶
Computes the neg() value of each element of the TensorDict.
- neg_() T ¶
Computes the neg() value of each element of the TensorDict in-place.
- new_empty(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
Returns a TensorDict of size size with empty (uninitialized) tensors.
By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.
- Parameters:
size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.
device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.
requires_grad (bool, optional) – whether autograd should record operations on the returned tensors. Default: False.
layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.
pin_memory (bool, optional) – if set, the returned tensors will be allocated in pinned memory. Works only for CPU tensors. Default: False.
- new_full(size: Size, fill_value, *, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
Returns a TensorDict of size size filled with fill_value.
By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.
- Parameters:
size (sequence of int) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.
fill_value (scalar) – the number to fill the output tensor with.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.
device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.
requires_grad (bool, optional) – whether autograd should record operations on the returned tensors. Default: False.
layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.
pin_memory (bool, optional) – if set, the returned tensors will be allocated in pinned memory. Works only for CPU tensors. Default: False.
- new_ones(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
Returns a TensorDict of size size filled with 1.
By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.
- Parameters:
size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.
device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.
requires_grad (bool, optional) – whether autograd should record operations on the returned tensors. Default: False.
layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.
pin_memory (bool, optional) – if set, the returned tensors will be allocated in pinned memory. Works only for CPU tensors. Default: False.
- new_tensor(data: torch.Tensor | tensordict.base.TensorDictBase, *, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, pin_memory: Optional[bool] = None)¶
Returns a new TensorDict holding a copy of data.
By default, the returned TensorDict values have the same torch.dtype and torch.device as this tensordict.
The data can also be a tensor collection (TensorDict or tensorclass), in which case the new_tensor method iterates over the tensor pairs of self and data.
- Parameters:
data (torch.Tensor or TensorDictBase) – the data to be copied.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.
device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.
requires_grad (bool, optional) – whether autograd should record operations on the returned tensors. Default: False.
pin_memory (bool, optional) – if set, the returned tensors will be allocated in pinned memory. Works only for CPU tensors. Default: False.
- new_zeros(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)¶
Returns a TensorDict of size size filled with 0.
By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.
- Parameters:
size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.
device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.
requires_grad (bool, optional) – whether autograd should record operations on the returned tensors. Default: False.
layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.
pin_memory (bool, optional) – if set, the returned tensors will be allocated in pinned memory. Works only for CPU tensors. Default: False.
- norm(*, out=None, dtype: torch.dtype | None = None)¶
Computes the norm of each tensor in the tensordict.
- Keyword Arguments:
out (TensorDict, optional) – the output tensordict.
dtype (torch.dtype, optional) – the output dtype (torch>=2.4).
- numel() int ¶
Total number of elements in the batch.
The result is lower-bounded to 1: a stack of two tensordicts with an empty batch size has two elements, so a tensordict is considered to hold at least one element.
- numpy()¶
Converts a tensordict to a (possibly nested) dictionary of numpy arrays.
Non-tensor data is exposed as such.
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({"a": {"b": torch.zeros(()), "c": "a string!"}})
>>> print(data)
TensorDict(
    fields={
        a: TensorDict(
            fields={
                b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                c: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(data.numpy())
{'a': {'b': array(0., dtype=float32), 'c': 'a string!'}}
- param_count(*, count_duplicates: bool = True) int ¶
Counts the number of parameters (total number of indexable items), accounting for tensors only.
- Keyword Arguments:
count_duplicates (bool) – whether to count duplicated tensors as independent or not. If False, only strictly identical tensors will be discarded (views sharing a common base tensor but having different ids will still be counted twice). Defaults to True (each tensor is assumed to be a single copy).
- permute(*args, **kwargs)¶
Returns a view of a tensordict with the batch dimensions permuted according to dims.
- Parameters:
*dims_list (int) – the new ordering of the batch dims of the tensordict. Alternatively, a single iterable of integers can be provided.
dims (list of int) – alternative way of calling permute(…).
- Returns:
a new tensordict with the batch dimensions in the desired order.
Examples
>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> print(tensordict.permute([1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(1, 0))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(dims=[1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
- pin_memory(num_threads: Optional[int] = None, inplace: bool = False) T ¶
Calls pin_memory() on the stored tensors.
- Parameters:
num_threads (int or str) – if provided, the number of threads to use to call pin_memory on the leaves. Defaults to None, which sets a high number of threads in ThreadPoolExecutor(max_workers=None). To execute all the calls to pin_memory() on the main thread, pass num_threads=0.
inplace (bool, optional) – if True, the tensordict is modified in-place. Defaults to False.
- pin_memory_(num_threads: int | str = 0) T ¶
Calls pin_memory() on the stored tensors, modifying the TensorDict in-place, and returns it.
- pop(key: NestedKey, default: Any = _NoDefault.ZERO) Tensor ¶
Removes and returns a value from a tensordict.
If the value is not present and no default value is provided, a KeyError is thrown.
- Parameters:
key (str or nested key) – the entry to look for.
default (Any, optional) – the value to return if the key cannot be found.
Examples
>>> td = TensorDict({"1": 1}, [])
>>> one = td.pop("1")
>>> assert one == 1
>>> none = td.pop("1", default=None)
>>> assert none is None
- popitem() Tuple[NestedKey, Tensor] ¶
Removes the item that was last inserted into the TensorDict and returns it as a (key, value) pair.
popitem will only return non-nested values.
- pow(other: tensordict.base.TensorDictBase | torch.Tensor, *, default: Optional[Union[str, Tensor]] = None) T ¶
Takes the power of each element in self with other and returns a tensordict with the result. other can be either a single float number, a Tensor or a TensorDict.
When other is a tensor, the shapes of the entries of self and of other must be broadcastable.
- Parameters:
other (float, tensor or tensordict) – the exponent value.
- Keyword Arguments:
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key sets of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting keys will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- pow_(other: tensordict.base.TensorDictBase | torch.Tensor) T ¶
In-place version of pow().
Note
In-place pow does not support the default keyword argument.
- prod(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: Optional[dtype] = None, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
Returns the product of the values of all elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the prod value of all leaves (if this can be computed). If integer or tuple of integers, prod is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
- qint32()¶
Casts all tensors to torch.qint32.
- qint8()¶
Casts all tensors to torch.qint8.
- quint4x2()¶
Casts all tensors to torch.quint4x2.
- quint8()¶
Casts all tensors to torch.quint8.
- reciprocal() T ¶
Computes the reciprocal() value of each element of the TensorDict.
- reciprocal_() T ¶
Computes the reciprocal() value of each element of the TensorDict in-place.
- record_stream(stream: Stream)¶
Marks the tensordict as having been used by this stream.
When the tensordict is deallocated, this ensures that the tensor memory is not reused for other tensors until all work queued on stream at the time of deallocation is complete.
See torch.Tensor.record_stream() for more information.
- recv(src: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int ¶
Receives the content of a tensordict and updates the local content with it.
Check the example in the send() method for context.
- Parameters:
src (int) – the rank of the source worker.
- Keyword Arguments:
group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.
init_tag (int) – the init_tag used by the source worker.
pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. This value must match the one passed to send(). Defaults to False.
- reduce(dst, op=None, async_op=False, return_premature=False, group=None)¶
Reduces the tensordict across all machines.
Only the process with rank dst is going to receive the final result.
- refine_names(*names) T ¶
Refines the dimension names of self according to names.
Refining is a special case of renaming that “lifts” unnamed dimensions. A None dim can be refined to have any name; a named dim can only be refined to have the same name.
Because named tensors can coexist with unnamed tensors, refining names gives a nice way to write named-tensor-aware code that works with both named and unnamed tensors.
names may contain up to one Ellipsis (…). The Ellipsis is expanded greedily; it is expanded in-place to fill names to the same length as self.dim() using names from the corresponding indices of self.names.
Returns: the same tensordict with dimensions named according to the input.
Examples
>>> td = TensorDict({}, batch_size=[3, 4, 5, 6])
>>> tdr = td.refine_names(None, None, None, "d")
>>> assert tdr.names == [None, None, None, "d"]
>>> tdr = td.refine_names("a", None, None, "d")
>>> assert tdr.names == ["a", None, None, "d"]
- rename(*names, **rename_map)¶
Returns a clone of the tensordict with dimensions renamed.
Examples
>>> td = TensorDict({}, batch_size=[1, 2, 3, 4])
>>> td.names = list("abcd")
>>> td_rename = td.rename(c="g")
>>> assert td_rename.names == list("abgd")
- rename_(*names, **rename_map)¶
Same as rename(), but executes the renaming in-place.
Examples
>>> td = TensorDict({}, batch_size=[1, 2, 3, 4])
>>> td.names = list("abcd")
>>> td.rename_(c="g")
>>> assert td.names == list("abgd")
- rename_key_(old_key: NestedKey, new_key: NestedKey, safe: bool = False) T ¶
Renames a key with a new string and returns the same tensordict with the updated key name.
- replace(*args, **kwargs)¶
Creates a shallow copy of the tensordict where entries have been replaced.
Accepts one unnamed argument which must be a dictionary or a TensorDictBase subclass. Additionally, first-level entries can be updated with the named keyword arguments.
- Returns:
a copy of self with updated entries if the input is non-empty. If an empty dict or no dict is provided and the kwargs are empty, self is returned.
- requires_grad_(requires_grad=True) T ¶
Changes whether autograd should record operations on this tensordict: sets the requires_grad attribute of its tensors in-place.
Returns this tensordict.
- Parameters:
requires_grad (bool, optional) – whether or not autograd should record operations on this tensordict. Defaults to True.
- reshape(*args, **kwargs) T ¶
Returns a contiguous, reshaped tensordict of the desired shape.
- Parameters:
*shape (int) – new shape of the resulting tensordict.
- Returns:
A TensorDict with reshaped entries.
Examples
>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td = td.reshape(12)
>>> print(td['x'])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
- round() T ¶
Computes the round() value of each element of the TensorDict.
- round_() T ¶
Computes the round() value of each element of the TensorDict in-place.
- save(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
Saves the tensordict to disk.
This function is a proxy to memmap().
- property saved_path¶
Returns the path where a memmap saved TensorDict is being stored.
This property is no longer available as soon as is_memmap() returns False (e.g., when the tensordict is unlocked).
- select(*keys: NestedKey, inplace: bool = False, strict: bool = True) T ¶
Selects the keys of the tensordict and returns a new tensordict with only the selected keys.
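The values are not copied: an in-place modification of a tensor in either the original or the new tensordict will be reflected in both tensordicts.
- Parameters:
*keys (str or nested key) – the keys to select.
inplace (bool, optional) – if True, the tensordict is pruned in place and returned. Defaults to False.
strict (bool, optional) – whether selecting a key that is not present raises an error. Defaults to True.
- Returns:
A new tensordict (or the same if inplace=True) with the selected keys only.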
Note
To select keys in a tensordict and return a version of this tensordict deprived of these keys, see the split_keys() method.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.select("a", ("b", "c"))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("a", "b")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("this key does not exist", strict=False)
TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- send(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) None ¶
Sends the content of a tensordict to a distant worker.
- Parameters:
dst (int) – the rank of the destination worker where the content should be sent.
- Keyword Arguments:
group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.
init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.
pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. Defaults to False.
Example
>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>>
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method="tcp://localhost:10003",
...     )
...
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.send(0)
...
>>>
>>> def server(queue):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method="tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     td.recv(1)
...     assert (td != 0).all()
...     queue.put("yuppie")
...
>>>
>>> if __name__ == "__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(target=server, args=(queue,))
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
- set(key: NestedKey, item: Tensor, inplace: bool = False, *, non_blocking: bool = False, **kwargs: Any) T ¶
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set.
item (torch.Tensor or equivalent, TensorDictBase instance) – value to be stored in the tensordict.
inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If inplace is True and the entry cannot be found, it will be added. For a more restrictive in-place operation, use set_() instead. Defaults to False.
- Keyword Arguments:
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
- Returns:
self
Examples
>>> td = TensorDict({}, batch_size=[3, 4])
>>> td.set("x", torch.randn(3, 4))
>>> y = torch.randn(3, 4, 5)
>>> td.set("y", y, inplace=True)  # works, even if 'y' is not present yet
>>> td.set("y", torch.zeros_like(y), inplace=True)
>>> assert (y == 0).all()  # y values are overwritten
>>> td.set("y", torch.ones(5), inplace=True)  # raises an exception as shapes mismatch
- set_(key: NestedKey, item: Tensor, *, non_blocking: bool = False) T ¶
Sets a value to an existing key while keeping the original storage.
- Parameters:
key (str) – name of the value
item (torch.Tensor or compatible type, TensorDictBase) – value to be stored in the tensordict
- Keyword Arguments:
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
- Returns:
self
Examples
>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_("x", torch.zeros_like(x))
>>> assert (x == 0).all()
- set_at_(key: NestedKey, value: Tensor, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], *, non_blocking: bool = False) T ¶
Sets the values in-place at the index indicated by index.
- Parameters:
key (str, tuple of str) – key to be modified.
value (torch.Tensor) – value to be written at index.
index (int, tensor or tuple) – index where to write the values.
- Keyword Arguments:
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
- Returns:
self
Examples
>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_at_("x", value=torch.ones(1, 4), index=slice(1))
>>> assert (x[0] == 1).all()
- set_non_tensor(key: NestedKey, value: Any)¶
Registers a non-tensor value in the tensordict using tensordict.tensorclass.NonTensorData.
The value can be retrieved using TensorDictBase.get_non_tensor() or directly using get, which will return the tensordict.tensorclass.NonTensorData object.
- Returns:
self
Examples
>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
NonTensorData(
    data='a string!',
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- setdefault(key: NestedKey, default: Tensor, inplace: bool = False) Tensor ¶
Inserts the key entry with a value of default if key is not in the tensordict.
Returns the value for key if key is in the tensordict, else default.
- Parameters:
key (str or nested key) – the name of the value.
default (torch.Tensor or compatible type, TensorDictBase) – value to be stored in the tensordict if the key is not already present.
- Returns:
The value of key in the tensordict. Will be default if the key was not previously set.
Examples
>>> td = TensorDict({}, batch_size=[3, 4])
>>> val = td.setdefault("a", torch.zeros(3, 4))
>>> assert (val == 0).all()
>>> val = td.setdefault("a", torch.ones(3, 4))
>>> assert (val == 0).all()  # output is still 0
- property shape: Size¶
See batch_size.
- share_memory_() T ¶
Places all the tensors in shared memory.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Conversely, once the tensordict is unlocked, the share_memory attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
self
- sigmoid() T ¶
Computes the sigmoid() value of each element of the TensorDict.
- sigmoid_() T ¶
Computes the sigmoid() value of each element of the TensorDict in-place.
- sign() T ¶
Computes the sign() value of each element of the TensorDict.
- sign_() T ¶
Computes the sign() value of each element of the TensorDict in-place.
- sin() T ¶
Computes the sin() value of each element of the TensorDict.
- sin_() T ¶
Computes the sin() value of each element of the TensorDict in-place.
- sinh() T ¶
Computes the sinh() value of each element of the TensorDict.
- sinh_() T ¶
Computes the sinh() value of each element of the TensorDict in-place.
- size(dim: Optional[int] = None) torch.Size | int ¶
Returns the size of the dimension indicated by dim.
If dim is not specified, returns the batch_size attribute of the TensorDict.
- property sorted_keys: list[tensordict._nestedkey.NestedKey]¶
Returns the keys sorted in alphabetical order.
Does not support extra arguments.
If the TensorDict is locked, the keys are cached until the tensordict is unlocked for faster execution.
- split(split_size: int | list[int], dim: int = 0) list[tensordict.base.TensorDictBase] ¶
Splits each tensor in the TensorDict with the specified size in the given dimension, like torch.split.
Returns a list of TensorDict instances with the view of split chunks of items.
- Parameters:
split_size (int or list of int) – size of a single chunk, or list of sizes for each chunk.
dim (int) – dimension along which to split the tensordict. Defaults to 0.
- Returns:
A list of TensorDict with specified size in given dimension.
Examples
>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1 = td.split([1, 2], dim=0)
>>> print(td0['x'])
tensor([[0, 1, 2, 3]])
- split_keys(*key_sets, inplace=False, strict: bool = True, reproduce_struct: bool = False)¶
Splits the tensordict in subsets given one or more sets of keys.
The method will return N+1 tensordicts, where N is the number of the arguments provided.
- Parameters:
inplace (bool, optional) – if True, the keys are removed from self in-place. Defaults to False.
strict (bool, optional) – if True, an exception is raised when a key is missing. Defaults to True.
reproduce_struct (bool, optional) – if True, all tensordicts returned have the same tree structure as self, even if some sub-tensordicts contain no leaves.
Note
None non-tensor values will be ignored and not returned.
Note
The method does not check for duplicates in the provided lists.
Examples
>>> td = TensorDict(
...     a=0,
...     b=0,
...     c=0,
...     d=0,
... )
>>> td_a, td_bc, td_d = td.split_keys(["a"], ["b", "c"])
>>> print(td_bc)
- sqrt()¶
Computes the element-wise square root of self.
- squeeze(*args, **kwargs)¶
Squeezes all tensors for a dimension in between -self.batch_dims+1 and self.batch_dims-1 and returns them in a new tensordict.
- Parameters:
dim (Optional[int]) – dimension along which to squeeze. If dim is None, all singleton dimensions will be squeezed. Defaults to None.
Examples
>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> td = td.squeeze()
>>> td.shape
torch.Size([3, 4])
>>> td.get("x").shape
torch.Size([3, 4, 2])
This operation can be used as a context manager too. Changes to the original tensordict will occur out-place, i.e. the content of the original tensors will not be altered. This also assumes that the tensordict is not locked (otherwise, unlocking the tensordict is necessary). This functionality is not compatible with implicit squeezing.
>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> with td.squeeze(1) as tds:
...     tds.set("y", torch.zeros(3, 4))
>>> assert td.get("y").shape == [3, 1, 4]
- classmethod stack(input, dim=0, *, out=None)¶
Stacks tensordicts into a single tensordict along the given dimension.
This call is equivalent to calling torch.stack() but is compatible with torch.compile.
- stack_from_tensordict(dim: int = 0, *, sorted: Optional[Union[bool, List[NestedKey]]] = None, out: Optional[Tensor] = None) Tensor ¶
Stacks all entries of a tensordict in a single tensor.
- Parameters:
dim (int, optional) – the dimension along which the entries should be stacked.
- Keyword Arguments:
sorted (bool or list of NestedKeys) – if True, the entries will be stacked in alphabetical order. If False (default), the dict order will be used. Alternatively, a list of key names can be provided and the tensors will be stacked accordingly. This incurs some overhead as the list of keys will be checked against the list of leaf names in the tensordict.
out (torch.Tensor, optional) – an optional destination tensor for the stack operation.
- stack_tensors(*keys: NestedKey, out_key: NestedKey, dim: int = 0, keep_entries: bool = False) T ¶
Stacks entries into a new entry and possibly removes the original values.
- Parameters:
keys (sequence of NestedKey) – entries to stack.
- Keyword Arguments:
out_key (NestedKey) – new key name for the stacked inputs.
keep_entries (bool, optional) – if False, entries in keys will be deleted. Defaults to False.
dim (int, optional) – the dimension along which the stack must occur. Defaults to 0.
- Returns:
self
Examples
>>> td = TensorDict(a=torch.zeros(()), b=torch.ones(()))
>>> td.stack_tensors("a", "b", out_key="c")
>>> assert "a" not in td
>>> assert (td["c"] == torch.tensor([0, 1])).all()
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) OrderedDict[str, Any] ¶
Produces a state_dict from the tensordict.
The structure of the state-dict will still be nested, unless flatten is set to True.
A tensordict state-dict contains all the tensors and meta-data needed to rebuild the tensordict (names are currently not supported).
- Parameters:
destination (dict, optional) – If provided, the state of the tensordict will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to tensor names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the torch.Tensor items returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
flatten (bool, optional) – whether the structure should be flattened with the "." character or not. Defaults to False.
Examples
>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> sd = data.state_dict()
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)])
>>> sd = data.state_dict(flatten=True)
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])
- std(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, correction: int = 1, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
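Returns the standard deviation value of all elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the standard deviation of all leaves (if this can be computed). If integer or tuple of integers, std is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.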
- sub(other: tensordict.base.TensorDictBase | float, *, alpha: Optional[float] = None, default: Optional[Union[str, Tensor]] = None)¶
Subtracts other, scaled by alpha, from self.
\[\text{out}_i = \text{input}_i - \text{alpha} \times \text{other}_i\]
Supports broadcasting, type promotion, and integer, float, and complex inputs.
- Parameters:
other (TensorDict, Tensor or Number) – the tensor or number to subtract from self.
- Keyword Arguments:
alpha (Number) – the multiplier for other.
default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
- sub_(other: tensordict.base.TensorDictBase | float, alpha: Optional[float] = None)¶
In-place version of sub().
Note
In-place sub does not support the default keyword argument.
- sum(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: Optional[dtype] = None, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
Returns the sum value of all elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
- tan() T ¶
Computes the tan() value of each element of the TensorDict.
- tan_() T ¶
Computes the tan() value of each element of the TensorDict in-place.
- tanh() T ¶
Computes the tanh() value of each element of the TensorDict.
- tanh_() T ¶
Computes the tanh() value of each element of the TensorDict in-place.
- to(*args, **kwargs) T ¶
Maps a TensorDictBase subclass either on another device, dtype or to another TensorDictBase subclass (if permitted).
Casting tensors to a new dtype is not allowed, as tensordicts are not bound to contain a single tensor dtype.
- Parameters:
device (torch.device, optional) – the desired device of the tensordict.
dtype (torch.dtype, optional) – the desired floating point or complex dtype of the tensordict.
tensor (torch.Tensor, optional) – Tensor whose dtype and device are the desired dtype and device for all tensors in this TensorDict.
- Keyword Arguments:
non_blocking (bool, optional) – whether the operations should be blocking.
memory_format (torch.memory_format, optional) – the desired memory format for 4D parameters and buffers in this tensordict.
batch_size (torch.Size, optional) – resulting batch-size of the output tensordict.
other (TensorDictBase, optional) –
TensorDict instance whose dtype and device are the desired dtype and device for all tensors in this TensorDict.
Note
Since TensorDictBase instances do not have a dtype, the dtype is gathered from the example leaves. If there is more than one dtype, then no dtype casting is undertaken.
non_blocking_pin (bool, optional) – if True, the tensors are pinned before being sent to device. This will be done asynchronously but can be controlled via the num_threads argument.
Note
Calling tensordict.pin_memory().to("cuda") will usually be much slower than tensordict.to("cuda", non_blocking_pin=True) as the pin_memory is called asynchronously in the second case. Multithreaded pin_memory will usually be beneficial if the tensors are large and numerous: when there are too few tensors to be sent, the overhead of spawning threads and collecting data outweighs the benefits of multithreading, and if the tensors are small the overhead of iterating over a long list is also prohibitively large.
num_threads (int or None, optional) – if non_blocking_pin=True, the number of threads to be used for pin_memory. By default, max(1, torch.get_num_threads()) threads will be spawned. num_threads=0 will cancel any multithreading for the pin_memory() calls.
- Returns:
a new tensordict instance if the device differs from the tensordict device and/or if the dtype is passed. The same tensordict otherwise. Modifications that only affect batch_size are done in-place.
Note
If the TensorDict is consolidated, the resulting TensorDict will be consolidated too. Each new tensor will be a view on the consolidated storage cast to the desired device.
Examples
>>> data = TensorDict({"a": 1.0}, [], device=None)
>>> data_cuda = data.to("cuda:0")  # casts to cuda
>>> data_int = data.to(torch.int)  # casts to int
>>> data_cuda_int = data.to("cuda:0", torch.int)  # multiple casting
>>> data_cuda = data.to(torch.randn(3, device="cuda:0"))  # using an example tensor
>>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0"))  # using a tensordict example
- to_dict(*, retain_none: bool = True) dict[str, Any] ¶
Returns a dictionary with key-value pairs matching those of the tensordict.
- Parameters:
retain_none (bool) – if True, the None values from tensorclass instances will be written in the dictionary. Otherwise, they will be discarded. Default: True.
- to_h5(filename, **kwargs)¶
Converts a tensordict to a PersistentTensorDict with the h5 backend.
- Parameters:
filename (str or path) – path to the h5 file.
device (torch.device or compatible, optional) – the device where the tensors are expected once they are returned. Defaults to None (on cpu by default).
**kwargs – kwargs to be passed to h5py.File.create_dataset().
- Returns:
A PersistentTensorDict instance linked to the newly created file.
Examples
>>> import tempfile
>>> import torch
>>>
>>> from tensordict import TensorDict, MemoryMappedTensor
>>> td = TensorDict({
...     "a": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000)),
...     "b": {"c": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000, 3))},
... }, [1_000_000])
>>>
>>> file = tempfile.NamedTemporaryFile()
>>> td_h5 = td.to_h5(file.name, compression="gzip", compression_opts=9)
>>> print(td_h5)
PersistentTensorDict(
    fields={
        a: Tensor(shape=torch.Size([1000000]), device=cpu, dtype=torch.float32, is_shared=False),
        b: PersistentTensorDict(
            fields={
                c: Tensor(shape=torch.Size([1000000, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1000000]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1000000]),
    device=None,
    is_shared=False)
- to_module(module: Module, *, inplace: bool | None = None, return_swap: bool = True, swap_dest=None, use_state_dict: bool = False, non_blocking: bool = False, memo=None)¶
Writes the content of a TensorDictBase instance onto the attributes of a given nn.Module, recursively.
- Parameters:
module (nn.Module) – a module to write the parameters into.
- Keyword Arguments:
inplace (bool, optional) – if True, the parameters or tensors in the module are updated in-place. Defaults to False.
return_swap (bool, optional) – if True, the old parameter configuration will be returned. Defaults to True.
swap_dest (TensorDictBase, optional) – if return_swap is True, the tensordict where the swap should be written.
use_state_dict (bool, optional) – if True, state-dict API will be used to load the parameters (including the state-dict hooks). Defaults to False.
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
Examples
>>> from torch import nn
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> params.zero_()
>>> params.to_module(module)
>>> assert (module.layers[0].linear1.weight == 0).all()
- to_namedtuple(dest_cls: Optional[type] = None)¶
Converts a tensordict to a namedtuple.
- Parameters:
dest_cls (Type, optional) – an optional namedtuple class to use.
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "a_tensor": torch.zeros((3)),
...     "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> data.to_namedtuple()
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
- to_padded_tensor(padding=0.0, mask_key: Optional[NestedKey] = None)¶
Converts all nested tensors to a padded version and adapts the batch-size accordingly.
- Parameters:
padding (float) – the padding value for the tensors in the tensordict. Defaults to 0.0.
mask_key (NestedKey, optional) – if provided, the key where a mask for valid values will be written. Will result in an error if the heterogeneous dimension isn’t part of the tensordict batch-size. Defaults to None.
- to_pytree()¶
Converts a tensordict to a PyTree.
If the tensordict was not created from a pytree, this method just returns self without modification.
See from_pytree() for more information and examples.
- to_struct_array()¶
Converts a tensordict to a numpy structured array.
In a from_struct_array() - to_struct_array() loop, the content of the input and output arrays should match. However, to_struct_array will not keep the memory content of the original arrays.
See from_struct_array() for more information.
- to_tensordict(*, retain_none: Optional[bool] = None) T ¶
Returns a regular TensorDict instance from the TensorDictBase.
- Parameters:
retain_none (bool) – if True, the None values from tensorclass instances will be written in the tensordict. Otherwise they will be discarded. Default: True.
Note
From v0.8, the default value will be switched to False.
- Returns:
a new TensorDict object containing the same values.
- transpose(dim0, dim1)¶
Returns a tensordict that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.
In-place or out-place modifications of the transposed tensordict will impact the original tensordict too as the memory is shared and the operations are mapped back on the original tensordict.
Examples
>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> tensordict_transpose = tensordict.transpose(0, 1)
>>> print(tensordict_transpose.shape)
torch.Size([4, 3])
>>> tensordict_transpose.set("b", torch.randn(4, 3))
>>> print(tensordict.get("b").shape)
torch.Size([3, 4])
- trunc() T ¶
Computes the trunc() value of each element of the TensorDict.
- trunc_() T ¶
Computes the trunc() value of each element of the TensorDict in-place.
- type(dst_type)¶
Casts all tensors to dst_type.
- Parameters:
dst_type (type or string) – the desired type.
- uint16()¶
Casts all tensors to torch.uint16.
- uint32()¶
Casts all tensors to torch.uint32.
- uint64()¶
Casts all tensors to torch.uint64.
- uint8()¶
Casts all tensors to torch.uint8.
- unbind(dim: int) tuple[T, ...] ¶
Returns a tuple of indexed tensordicts, unbound along the indicated dimension.
Examples
>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1, td2 = td.unbind(0)
>>> td0['x']
tensor([0, 1, 2, 3])
>>> td1['x']
tensor([4, 5, 6, 7])
- unflatten(dim, unflattened_size)¶
Unflattens a tensordict dim expanding it to a desired shape.
- Parameters:
dim (int) – specifies the dimension of the input tensor to be unflattened.
unflattened_size (shape) – is the new shape of the unflattened dimension of the tensordict.
Examples
>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)},
...     batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_unflat = td_flat.unflatten(0, [3, 4])
>>> assert (td == td_unflat).all()
- unflatten_keys(separator: str = '.', inplace: bool = False) T ¶
Converts a flat tensordict into a nested one, recursively.
The TensorDict type will be lost and the result will be a simple TensorDict instance. The metadata of the nested tensordicts will be inferred from the root: all instances across the data tree will share the same batch-size, dimension names and device.
- Parameters:
separator (str, optional) – the separator between the nested items. Defaults to '.'.
inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.
Examples
>>> data = TensorDict({"a": 1, "b - c": 2, "e - f - g": 3}, batch_size=[])
>>> data.unflatten_keys(separator=" - ")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        e: TensorDict(
            fields={
                f: TensorDict(
                    fields={
                        g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This method and flatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.
Examples
>>> model = torch.nn.Sequential(torch.nn.Linear(3, 4))
>>> ddp_model = torch.ao.quantization.QuantWrapper(model)
>>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
- unlock_() T ¶
Unlocks a tensordict for non in-place operations.
Can be used as a decorator.
See lock_() for more details.
- unsqueeze(*args, **kwargs)¶
Unsqueezes all tensors for a dimension comprised in between -td.batch_dims and td.batch_dims and returns them in a new tensordict.
- Parameters:
dim (int) – dimension along which to unsqueeze.
Examples
>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td = td.unsqueeze(-2)
>>> td.shape
torch.Size([3, 1, 4])
>>> td.get("x").shape
torch.Size([3, 1, 4, 2])
This operation can be used as a context manager too. Changes to the original tensordict will occur out-place, i.e. the content of the original tensors will not be altered. This also assumes that the tensordict is not locked (otherwise, unlocking the tensordict is necessary).
>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> with td.unsqueeze(-2) as tds:
...     tds.set("y", torch.zeros(3, 1, 4))
>>> assert td.get("y").shape == [3, 4]
- update(input_dict_or_td: Union[dict[str, torch.Tensor], T], clone: bool = False, inplace: bool = False, *, non_blocking: bool = False, keys_to_update: Optional[Sequence[NestedKey]] = None, is_leaf: Optional[Callable[[Type], bool]] = None) T ¶
Updates the TensorDict with values from either a dictionary or another TensorDict.
- Parameters:
input_dict_or_td (TensorDictBase or dict) – input data to be written in self.
clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.
inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If the entry cannot be found, it will be added. Defaults to False.
- Keyword Arguments:
keys_to_update (sequence of NestedKeys, optional) – if provided, only the list of keys in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update(data_src.select(*keys_to_update)).
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
is_leaf (Callable[[Type], bool], optional) – a callable that indicates whether an object type is to be considered a leaf and swapped or a tensor collection.
- Returns:
self
Examples
>>> td = TensorDict({}, batch_size=[3])
>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> other_td = TensorDict({"a": a, "b": b}, batch_size=[])
>>> td.update(other_td, inplace=True)  # writes "a" and "b" even though they can't be found
>>> assert td['a'] is other_td['a']
>>> other_td = other_td.clone().zero_()
>>> td.update(other_td)
>>> assert td['a'] is not other_td['a']
- update_(input_dict_or_td: Union[dict[str, torch.Tensor], T], clone: bool = False, *, non_blocking: bool = False, keys_to_update: Optional[Sequence[NestedKey]] = None) T ¶
Updates the TensorDict in-place with values from either a dictionary or another TensorDict.
Unlike update(), this function will throw an error if the key is unknown to self.
- Parameters:
input_dict_or_td (TensorDictBase or dict) – input data to be written in self.
clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.
- Keyword Arguments:
keys_to_update (sequence of NestedKeys, optional) – if provided, only the list of keys in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update_(data_src.select(*keys_to_update)).
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
- Returns:
self
Examples
>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> td = TensorDict({"a": a, "b": b}, batch_size=[3])
>>> other_td = TensorDict({"a": a*0, "b": b*0}, batch_size=[])
>>> td.update_(other_td)
>>> assert td['a'] is not other_td['a']
>>> assert (td['a'] == other_td['a']).all()
>>> assert (td['a'] == 0).all()
- update_at_(input_dict_or_td: Union[dict[str, torch.Tensor], T], idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], clone: bool = False, *, non_blocking: bool = False, keys_to_update: Optional[Sequence[NestedKey]] = None) T ¶
Updates the TensorDict in-place at the specified index with values from either a dictionary or another TensorDict.
Unlike TensorDict.update, this function will throw an error if the key is unknown to the TensorDict.
- Parameters:
input_dict_or_td (TensorDictBase or dict) – input data to be written in self.
idx (int, torch.Tensor, iterable, slice) – index of the tensordict where the update should occur.
clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Default is False.
- Keyword Arguments:
keys_to_update (sequence of NestedKeys, optional) – if provided, only the list of keys in keys_to_update will be updated.
non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.
- Returns:
self
Examples
>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td.update_at_(
...     TensorDict({
...         'a': torch.ones(1, 4, 5),
...         'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]),
...     slice(1, 2))
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 10]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)
>>> assert (td[1] == 1).all()
- values(include_nested: bool = False, leaves_only: bool = False, is_leaf: Optional[Callable[[Type], bool]] = None, *, sort: bool = False) Iterator[tuple[str, torch.Tensor]] ¶
Returns a generator representing the values for the tensordict.
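- Parameters:
include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.
leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.
is_leaf (callable, optional) – a callable on a class type returning a bool indicating whether this class has to be considered a leaf.
- Keyword Arguments:
sort (bool, optional) – whether the keys should be sorted. For nested keys, the keys are sorted according to their joined name (i.e., ("a", "key") will be counted as "a.key" for sorting). Be mindful that sorting may incur significant overhead when dealing with large tensordicts. Defaults to False.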
- var(dim: Union[int, Tuple[int]] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, correction: int = 1, reduce: Optional[bool] = None) tensordict.base.TensorDictBase | torch.Tensor ¶
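Returns the variance value of all elements in the input tensordict.
- Parameters:
dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the variance of all leaves (if this can be computed). If integer or tuple of integers, var is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.
keepdim (bool) – whether the output tensor has dim retained or not.
- Keyword Arguments:
correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.
reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.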
- view(*shape: int, size: list | tuple | torch.Size | None = None, batch_size: torch.Size | None = None)¶
Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size.
Alternatively, a dtype can be provided as a first unnamed argument. In that case, all tensors will be viewed with the according dtype. Note that this assumes that the new shapes will be compatible with the provided dtype. See view() for more information on dtype views.
- Parameters:
*shape (int) – new shape of the resulting tensordict.
dtype (torch.dtype) – alternatively, a dtype to use to represent the tensor content.
size (iterable, optional) – the new shape, passed as a single iterable instead of *shape.
- Keyword Arguments:
batch_size (torch.Size, optional) – if a dtype is provided, the batch-size can be reset using this keyword argument. If view is called with a shape, this is without effect.
- Returns:
a new tensordict with the desired batch_size.
Examples
>>> td = TensorDict(source={'a': torch.zeros(3,4,5),
...    'b': torch.zeros(3,4,10,1)}, batch_size=torch.Size([3, 4]))
>>> td_view = td.view(12)
>>> print(td_view.get("a").shape)  # torch.Size([12, 5])
>>> print(td_view.get("b").shape)  # torch.Size([12, 10, 1])
>>> td_view = td.view(-1, 4, 3)
>>> print(td_view.get("a").shape)  # torch.Size([1, 4, 3, 5])
>>> print(td_view.get("b").shape)  # torch.Size([1, 4, 3, 10, 1])
- where(condition, other, *, out=None, pad=None)¶
Returns a TensorDict of elements selected from either self or other, depending on condition.
- Parameters:
condition (BoolTensor) – When True (nonzero), yields self, otherwise yields other.
other (TensorDictBase or Scalar) – value (if other is a scalar) or values selected at indices where condition is False.
- Keyword Arguments:
out (TensorDictBase, optional) – the output TensorDictBase instance.
pad (scalar, optional) – if provided, missing keys from the source or destination tensordict will be written as torch.where(mask, self, pad) or torch.where(mask, pad, other). Defaults to None, i.e. missing keys are not tolerated.
- zero_() T ¶
Zeros all tensors in the tensordict in-place.