

class tensordict.PersistentTensorDict(*, batch_size=None, filename=None, group=None, mode='r', backend='h5', device=None, **kwargs)

Persistent TensorDict implementation.

PersistentTensorDict instances provide an interface to data stored on disk, making that data easy to access while still taking advantage of the fast access provided by the backend.

Like other TensorDictBase subclasses, PersistentTensorDict has a device attribute. This does not mean that the data is being stored on that device, but rather that when loaded, the data will be cast onto the desired device.

Keyword Arguments:
  • batch_size (torch.Size or compatible) – the tensordict batch size. Defaults to torch.Size(()).

  • filename (str, optional) – the path to the h5 file. Exclusive with group.

  • group (h5py.Group, optional) – a file or a group that contains data. Exclusive with filename.

  • mode (str, optional) – Reading mode. Defaults to "r".

  • backend (str, optional) – storage backend. Currently only "h5" is supported.

  • device (torch.device or compatible, optional) – device of the tensordict. Defaults to None (i.e., the default PyTorch device).

  • **kwargs – kwargs to be passed to h5py.File.create_dataset().


Note: Currently, PersistentTensorDict instances are not closed when they go out of scope. It is therefore the user's responsibility to close them when necessary.


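>>> import tempfile
>>> import torch
>>> from tensordict import PersistentTensorDict
>>> with tempfile.NamedTemporaryFile(suffix=".h5") as f:
...     data = PersistentTensorDict(filename=f.name, batch_size=[3], mode="w")
...     data["a", "b"] = torch.randn(3, 4)
...     print(data)
...     data.close()  # instances are not closed automatically
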
abs() T

Computes the absolute value of each element of the TensorDict.

abs_() T

Computes the absolute value of each element of the TensorDict in-place.

acos() T

Computes the acos() value of each element of the TensorDict.

acos_() T

Computes the acos() value of each element of the TensorDict in-place.

add(other: TensorDictBase | torch.Tensor, *, alpha: float | None = None, default: str | CompatibleType | None = None) TensorDictBase

Adds other, scaled by alpha, to self.

\[\text{out}_i = \text{input}_i + \text{alpha} \times \text{other}_i\]

other (TensorDictBase or torch.Tensor) – the tensor or TensorDict to add to self.

Keyword Arguments:
  • alpha (Number, optional) – the multiplier for other.

  • default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
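
The three default modes can be illustrated with plain dictionaries. The sketch below is a minimal pure-Python stand-in, assuming flat dicts of numbers in place of tensordicts; add_with_default is a hypothetical helper (not part of the tensordict API) that follows the out = input + alpha × other rule of add().

```python
from typing import Dict, Optional, Union

Number = Union[int, float]


def add_with_default(
    left: Dict[str, Number],
    right: Dict[str, Number],
    alpha: Number = 1,
    default: Optional[Union[str, Number]] = None,
) -> Dict[str, Number]:
    """Hypothetical helper: left + alpha * right over two flat dicts."""
    if default is None:
        # No default: the two key sets must match exactly.
        if set(left) != set(right):
            raise KeyError(f"mismatched keys: {sorted(set(left) ^ set(right))}")
        keys = list(left)
    elif default == "intersection":
        # Only keys present on both sides are combined; the rest are dropped.
        keys = [k for k in left if k in right]
    else:
        # Any other value fills in missing entries on either side.
        keys = list(left) + [k for k in right if k not in left]
    return {k: left.get(k, default) + alpha * right.get(k, default) for k in keys}
```

For example, add_with_default({"a": 1, "b": 2}, {"a": 10, "c": 5}, default=0) keeps all three keys, while default="intersection" keeps only "a".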

add_(other: TensorDictBase | float, *, alpha: float | None = None)

In-place version of add().


Note: the in-place add_() does not support the default keyword argument.

addcdiv(other1: TensorDictBase | torch.Tensor, other2: TensorDictBase | torch.Tensor, value: float | None = 1)

Performs the element-wise division of other1 by other2, multiplies the result by the scalar value and adds it to self.

\[\text{out}_i = \text{input}_i + \text{value} \times \frac{\text{other1}_i}{\text{other2}_i}\]

The shapes of the elements of self, other1, and other2 must be broadcastable.

For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.

  • other1 (TensorDict or Tensor) – the numerator tensordict (or tensor)

  • other2 (TensorDict or Tensor) – the denominator tensordict (or tensor)

Keyword Arguments:

value (Number, optional) – multiplier for \(\text{other1} / \text{other2}\)

addcdiv_(other1, other2, *, value: float | None = 1)

The in-place version of addcdiv().

addcmul(other1, other2, *, value: float | None = 1)

Performs the element-wise multiplication of other1 by other2, multiplies the result by the scalar value and adds it to self.

\[\text{out}_i = \text{input}_i + \text{value} \times \text{other1}_i \times \text{other2}_i\]

The shapes of self, other1, and other2 must be broadcastable.

For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.

  • other1 (TensorDict or Tensor) – the tensordict or tensor to be multiplied

  • other2 (TensorDict or Tensor) – the tensordict or tensor to be multiplied

Keyword Arguments:

value (Number, optional) – multiplier for \(\text{other1} \times \text{other2}\)

addcmul_(other1, other2, *, value: float | None = 1)

The in-place version of addcmul().

all(dim: int = None) bool | TensorDictBase

Checks if all values are True/non-null in the tensordict.


dim (int, optional) – if None, returns a boolean indicating whether tensor.all() == True for every tensor. If an integer, all() is called along the specified dimension if and only if this dimension is compatible with the tensordict shape.

any(dim: int = None) bool | TensorDictBase

Checks if any value is True/non-null in the tensordict.


dim (int, optional) – if None, returns a boolean indicating whether tensor.any() == True for at least one tensor. If an integer, any() is called along the specified dimension if and only if this dimension is compatible with the tensordict shape.
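
For dim=None, both reductions amount to combining per-leaf all()/any() results. A minimal pure-Python sketch, assuming a nested dict whose leaves are lists of numbers standing in for tensors; the helper names are hypothetical and the dim argument is not modeled.

```python
from typing import Any, Dict, Iterator, List


def iter_leaves(tree: Dict[str, Any]) -> Iterator[List[float]]:
    """Yield every non-dict value of a nested dict, depth first."""
    for value in tree.values():
        if isinstance(value, dict):
            yield from iter_leaves(value)
        else:
            yield value


def tree_all(tree: Dict[str, Any]) -> bool:
    # True only if every element of every leaf is truthy (non-zero).
    return all(all(leaf) for leaf in iter_leaves(tree))


def tree_any(tree: Dict[str, Any]) -> bool:
    # True if at least one element of at least one leaf is truthy.
    return any(any(leaf) for leaf in iter_leaves(tree))
```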

apply(fn: Callable, *others: T, batch_size: Sequence[int] | None = None, device: torch.device | None = _NoDefault.ZERO, names: Sequence[str] | None = _NoDefault.ZERO, inplace: bool = False, default: Any = _NoDefault.ZERO, filter_empty: bool | None = None, propagate_lock: bool = False, call_on_nested: bool = False, out: TensorDictBase | None = None, **constructor_kwargs) T | None

Applies a callable to all values stored in the tensordict and sets them in a new tensordict.

The callable signature must be Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].

  • fn (Callable) – function to be applied to the tensors in the tensordict.

  • *others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching that of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.

Keyword Arguments:
  • batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.

  • device (torch.device, optional) – the resulting device, if any.

  • names (list of str, optional) – the new dimension names, in case the batch_size is modified.

  • inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.

  • default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.

  • filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost as empty data structures won’t be created and destroyed. Non-tensor data is considered as a leaf and thereby will be kept in the tensordict even if left untouched by the function. Defaults to False for backward compatibility.

  • propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.

  • call_on_nested (bool, optional) –

    if True, the function will be called on first-level tensors and containers (TensorDict or tensorclass). In this scenario, fn is responsible for propagating its calls to nested levels. This allows fine-grained behaviour when propagating the calls to nested tensordicts. If False, the function will only be called on leaves, and apply will take care of dispatching the function to all leaves.

    >>> from tensordict import TensorDict, is_tensor_collection
    >>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]})
    >>> def mean_tensor_only(val):
    ...     if is_tensor_collection(val):
    ...         raise RuntimeError("Unexpected!")
    ...     return val.mean()
    >>> td_mean = td.apply(mean_tensor_only)
    >>> def mean_any(val):
    ...     if is_tensor_collection(val):
    ...         # Recurse
    ...         return val.apply(mean_any, call_on_nested=True)
    ...     return val.mean()
    >>> td_mean = td.apply(mean_any, call_on_nested=True)

  • out (TensorDictBase, optional) –

    a tensordict in which to write the results. This can be used to avoid creating a new tensordict:

    >>> td = TensorDict({"a": 0})
    >>> td.apply(lambda x: x+1, out=td)
    >>> assert (td==1).all()


    Warning: if the operation executed on the tensordict requires multiple keys to be accessed for a single computation, providing an out argument equal to self can cause the operation to silently produce wrong results. For instance:

    >>> td = TensorDict({"a": 1, "b": 1})
    >>> td.apply(lambda x: x+td["a"])["b"] # Right!
    >>> td.apply(lambda x: x+td["a"], out=td)["b"] # Wrong!

  • **constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.


Returns: a new tensordict with the transformed tensors.


>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "b": {"c": torch.ones(3)}},
...     batch_size=[3])
>>> td_1 = td.apply(lambda x: x+1)
>>> assert (td_1["a"] == 0).all()
>>> assert (td_1["b", "c"] == 2).all()
>>> td_2 = td.apply(lambda x, y: x+y, td)
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()


If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def filter(tensor):
...     if tensor == 1:
...         return tensor
>>> td.apply(filter)
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},


The apply method always returns a TensorDict instance, regardless of the input type. To keep the same type, one can execute:

>>> out = td.clone(False).update(td.apply(...))

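
The leaf-wise dispatch and the None-filtering rule described above can be sketched with nested dictionaries. This is a pure-Python illustration, assuming plain dicts and scalars in place of TensorDict and tensors; apply_nested is a hypothetical name, and batch-size, device, others and out handling are deliberately left out.

```python
from typing import Any, Callable, Dict, Optional


def apply_nested(
    tree: Dict[str, Any],
    fn: Callable[[Any], Optional[Any]],
    filter_empty: bool = False,
) -> Dict[str, Any]:
    """Hypothetical sketch: call fn on every leaf and rebuild the nesting."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            sub = apply_nested(value, fn, filter_empty)
            if filter_empty and not sub:
                continue  # drop sub-trees that ended up empty
            out[key] = sub
        else:
            result = fn(value)
            if result is not None:  # None means "skip this entry"
                out[key] = result
    return out
```

With the filtering example above, keeping only values equal to 1 yields {"1": 1, "b": {"1": 1}}.
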
apply_(fn: Callable, *others, **kwargs) T

Applies a callable to all values stored in the tensordict and re-writes them in-place.

  • fn (Callable) – function to be applied to the tensors in the tensordict.

  • *others (sequence of TensorDictBase, optional) – the other tensordicts to be used.

Keyword Args: See apply().


Returns: self or a copy of self with the function applied.

asin() T

Computes the asin() value of each element of the TensorDict.

asin_() T

Computes the asin() value of each element of the TensorDict in-place.

atan() T

Computes the atan() value of each element of the TensorDict.

atan_() T

Computes the atan() value of each element of the TensorDict in-place.

auto_batch_size_(batch_dims: int | None = None) T

Sets the maximum batch-size for the tensordict, up to an optional batch_dims.


batch_dims (int, optional) – if provided, the batch-size will be at most batch_dims long.




>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[])
>>> td.auto_batch_size_()
>>> print(td.batch_size)
torch.Size([3, 4])
>>> td.auto_batch_size_(batch_dims=1)
>>> print(td.batch_size)
torch.Size([3])

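
The rule behind auto_batch_size_() is to keep the longest run of leading dimensions shared by every leaf, optionally capped by batch_dims. A minimal sketch, assuming a nested dict of shape tuples in place of a tensordict; max_batch_size is hypothetical and is not how tensordict implements the method.

```python
from typing import Any, Dict, Iterator, Optional, Tuple


def iter_shapes(tree: Dict[str, Any]) -> Iterator[Tuple[int, ...]]:
    """Yield the shape tuple stored at every leaf of a nested dict."""
    for value in tree.values():
        if isinstance(value, dict):
            yield from iter_shapes(value)
        else:
            yield value


def max_batch_size(tree: Dict[str, Any], batch_dims: Optional[int] = None) -> Tuple[int, ...]:
    shapes = list(iter_shapes(tree))
    if not shapes:
        return ()
    batch = []
    for dim in range(min(len(shape) for shape in shapes)):
        if batch_dims is not None and dim >= batch_dims:
            break  # respect the optional upper bound
        sizes = {shape[dim] for shape in shapes}
        if len(sizes) != 1:
            break  # leaves disagree: this dim cannot be a batch dim
        batch.append(sizes.pop())
    return tuple(batch)
```

Using the shapes from the example above, {"a": (3, 4, 5), "b": {"c": (3, 4, 6)}} gives (3, 4), and (3,) with batch_dims=1.
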
property batch_dims: int

Length of the tensordict batch size.


Returns: an int describing the number of dimensions of the tensordict.

property batch_size

Shape (or batch_size) of a TensorDict.

The shape of a tensordict corresponds to the common first N dimensions of the tensors it contains, where N is an arbitrary number. The batch-size contrasts with the “feature size”, which represents the semantically relevant shapes of a tensor. For instance, a batch of videos may have shape [B, T, C, W, H], where [B, T] is the batch-size (batch and time dimensions) and [C, W, H] are the feature dimensions (channels and spatial dimensions).

The TensorDict shape is controlled by the user upon initialization (i.e., it is not inferred from the tensor shapes).

The batch_size can be edited dynamically if the new size is compatible with the TensorDict content. For instance, setting the batch size to an empty value is always allowed.


Returns: a Size object describing the TensorDict batch size.


>>> data = TensorDict({
...     "key 0": torch.randn(3, 4),
...     "key 1": torch.randn(3, 5),
...     "nested": TensorDict({"key 0": torch.randn(3, 4)}, batch_size=[3, 4])},
...     batch_size=[3])
>>> data.batch_size = () # resets the batch-size to an empty value
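
Compatibility here means that the new batch size is a prefix of every entry's shape, which is also why an empty batch size is always allowed. A pure-Python sketch, assuming leaf shapes are given as tuples (a nested tensordict being represented by its own batch size); is_valid_batch_size is a hypothetical helper.

```python
from typing import Sequence


def is_valid_batch_size(batch_size: Sequence[int], leaf_shapes: Sequence[Sequence[int]]) -> bool:
    """A batch size is valid iff it is a prefix of every leaf shape."""
    n = len(batch_size)
    return all(tuple(shape[:n]) == tuple(batch_size) for shape in leaf_shapes)
```

For the example above ("key 0": (3, 4), "key 1": (3, 5), "nested": batch (3, 4)), both () and (3,) are valid while (3, 4) is not.
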

Casts all tensors to torch.bfloat16.


Casts all tensors to torch.bool.

classmethod cat(input, dim=0, *, out=None)

Concatenates tensordicts into a single tensordict along the given dimension.

This call is equivalent to calling but is compatible with torch.compile.

ceil() T

Computes the ceil() value of each element of the TensorDict.

ceil_() T

Computes the ceil() value of each element of the TensorDict in-place.

chunk(chunks: int, dim: int = 0) tuple[TensorDictBase, ...]

Splits a tensordict into the specified number of chunks, if possible.

Each chunk is a view of the input tensordict.

  • chunks (int) – number of chunks to return

  • dim (int, optional) – dimension along which to split the tensordict. Default is 0.


>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td0, td1 = td.chunk(dim=-1, chunks=2)
>>> td0['x']
tensor([[[ 0,  1],
         [ 2,  3]],
        [[ 8,  9],
         [10, 11]],
        [[16, 17],
         [18, 19]]])

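
The "if possible" above refers to the split-size arithmetic. Assuming chunk() follows the same rule as torch.chunk(), every chunk has ceil(size / chunks) elements except possibly the last, so fewer chunks than requested may be returned. chunk_sizes is a hypothetical helper that only computes those sizes.

```python
from typing import List


def chunk_sizes(dim_size: int, chunks: int) -> List[int]:
    """Sizes of the pieces produced when splitting dim_size into chunks."""
    if chunks < 1:
        raise ValueError("chunks must be a positive integer")
    step = -(-dim_size // chunks)  # ceil(dim_size / chunks)
    sizes = []
    remaining = dim_size
    while remaining > 0:
        sizes.append(min(step, remaining))
        remaining -= step
    return sizes
```

In the example above, the batch dimension of size 4 split into 2 chunks gives [2, 2]; a dimension of size 6 split into 4 chunks would give only three pieces, [2, 2, 2].
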
clamp_max(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T

Clamps the elements of self to other wherever they are greater than that value.


other (TensorDict or Tensor) – the other input tensordict or tensor.

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.

clamp_max_(other: TensorDictBase | torch.Tensor) T

In-place version of clamp_max().


Note: the in-place clamp_max_() does not support the default keyword argument.

clamp_min(other: TensorDictBase | torch.Tensor, default: str | CompatibleType | None = None) T

Clamps the elements of self to other wherever they are less than that value.


other (TensorDict or Tensor) – the other input tensordict or tensor.

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.

clamp_min_(other: TensorDictBase | torch.Tensor) T

In-place version of clamp_min().


Note: the in-place clamp_min_() does not support the default keyword argument.

clear() T

Erases the content of the tensordict.

clear_device_() T

Clears the device of the tensordict.

Returns: self

clone(recurse: bool = True, **kwargs) T

Clones a TensorDictBase subclass instance onto a new TensorDictBase subclass of the same type.

To create a TensorDict instance from any other TensorDictBase subtype, call the to_tensordict() method instead.


recurse (bool, optional) – if True, each tensor contained in the TensorDict will be copied too. Otherwise only the TensorDict tree structure will be copied. Defaults to True.


Note: unlike many other ops (pointwise arithmetic, shape operations, …), clone does not inherit the original lock attribute. This design choice was made so that a clone can be created in order to be modified, which is the most frequent use case.


Closes the persistent tensordict.


Casts all tensors to torch.complex128.


Casts all tensors to torch.complex32.


Casts all tensors to torch.complex64.

consolidate(filename: Path | str | None = None, *, num_threads=0, device: torch.device | None = None, non_blocking: bool = False, inplace: bool = False, return_early: bool = False, use_buffer: bool = False, share_memory: bool = False, pin_memory: bool = False, metadata: bool = False) None

Consolidates the tensordict content in a single storage for fast serialization.


filename (Path, optional) – an optional file path for a memory-mapped tensor to use as a storage for the tensordict.

Keyword Arguments:
  • num_threads (integer, optional) – the number of threads to use for populating the storage.

  • device (torch.device, optional) – an optional device where the storage must be instantiated.

  • non_blocking (bool, optional) – non_blocking argument passed to copy_().

  • inplace (bool, optional) – if True, the resulting tensordict is the same as self with updated values. Defaults to False.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().

  • use_buffer (bool, optional) – if True and a filename is passed, an intermediate local buffer will be created in shared memory, and the data will be copied at the storage location as a last step. This may be faster than writing directly to a distant physical memory (e.g., NFS). Defaults to False.

  • share_memory (bool, optional) – if True, the storage will be placed in shared memory. Defaults to False.

  • pin_memory (bool, optional) – whether the consolidated data should be placed in pinned memory. Defaults to False.

  • metadata (bool, optional) – if True, the metadata will be stored alongside the common storage. If a filename is provided, this has no effect. Storing the metadata can be useful when one wants to control how serialization is achieved, as TensorDict handles the pickling/unpickling of consolidated TDs differently depending on whether the metadata is available.


Note: if the tensordict is already consolidated, all arguments are ignored and self is returned. Call contiguous() to re-consolidate.


>>> import pickle
>>> import torch
>>> from torch.utils.benchmark import Timer
>>> from tensordict import TensorDict
>>> data = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> data_consolidated = data.consolidate()
>>> # check that the data has a single data_ptr()
>>> assert torch.tensor([
...     v.untyped_storage().data_ptr() for v in data_consolidated.values(True, True)
... ]).unique().numel() == 1
>>> # Serializing the tensordict will be faster with data_consolidated
>>> with open("data.pickle", "wb") as f:
...     print("regular", Timer("pickle.dump(data, f)", globals=globals()).adaptive_autorange())
>>> with open("data_c.pickle", "wb") as f:
...     print("consolidated", Timer("pickle.dump(data_consolidated, f)", globals=globals()).adaptive_autorange())
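
The idea behind consolidation (one contiguous storage plus per-entry offsets) can be shown without torch. The sketch below is an illustration under stated assumptions, not tensordict's storage format: it packs lists of floats into a single bytearray with struct and records hypothetical (offset, count) metadata per flattened key.

```python
import struct
from typing import Dict, List, Tuple

Metadata = Dict[str, Tuple[int, int]]


def consolidate_leaves(leaves: Dict[str, List[float]]) -> Tuple[bytearray, Metadata]:
    """Pack every leaf into one buffer, remembering where each one starts."""
    storage = bytearray()
    metadata: Metadata = {}
    for key, values in leaves.items():
        offset = len(storage)
        storage += struct.pack(f"<{len(values)}d", *values)  # 8 bytes per float64
        metadata[key] = (offset, len(values))
    return storage, metadata


def read_leaf(storage: bytearray, metadata: Metadata, key: str) -> List[float]:
    """Recover one leaf as a view-like read from the shared buffer."""
    offset, count = metadata[key]
    return list(struct.unpack_from(f"<{count}d", storage, offset))
```

Serializing then amounts to writing one buffer plus a small metadata dict, which is why a consolidated tensordict pickles faster.
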

Materializes a PersistentTensorDict on a regular TensorDict.


Returns a shallow copy of the tensordict (i.e., copies the structure but not the data).

Equivalent to TensorDictBase.clone(recurse=False).
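
A shallow copy duplicates the nesting but shares the leaves, so writing into a leaf is visible through both objects while adding or removing keys is not. A pure-Python sketch, assuming dicts stand in for tensordicts and lists for tensors; shallow_copy is a hypothetical helper.

```python
from typing import Any, Dict


def shallow_copy(tree: Dict[str, Any]) -> Dict[str, Any]:
    # New dict objects at every level, but leaf objects are reused as-is.
    return {k: shallow_copy(v) if isinstance(v, dict) else v for k, v in tree.items()}
```
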

copy_(tensordict: T, non_blocking: bool = False) T

See TensorDictBase.update_.

The non_blocking argument is ignored and is only present for compatibility with torch.Tensor.copy_().

copy_at_(tensordict: T, idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], non_blocking: bool = False) T

See TensorDictBase.update_at_.

cos() T

Computes the cos() value of each element of the TensorDict.

cos_() T

Computes the cos() value of each element of the TensorDict in-place.

cosh() T

Computes the cosh() value of each element of the TensorDict.

cosh_() T

Computes the cosh() value of each element of the TensorDict in-place.

cpu(**kwargs) T

Casts a tensordict to CPU.

This function also supports all the keyword arguments of to().


Creates a nested tensordict of the same shape, device and dim names as the current tensordict.

If the value already exists, it will be overwritten by this operation. This operation is blocked in locked tensordicts.


>>> data = TensorDict({}, [3, 4, 5])
>>> data.create_nested("root")
>>> data.create_nested(("some", "nested", "value"))
>>> print(data)
        root: TensorDict(
            batch_size=torch.Size([3, 4, 5]),
        some: TensorDict(
                nested: TensorDict(
                        value: TensorDict(
                            batch_size=torch.Size([3, 4, 5]),
                    batch_size=torch.Size([3, 4, 5]),
            batch_size=torch.Size([3, 4, 5]),
    batch_size=torch.Size([3, 4, 5]),

cuda(device: Optional[int] = None, **kwargs) T

Casts a tensordict to a cuda device (if not already on it).


device (int, optional) – if provided, the cuda device onto which the tensors should be cast.

This function also supports all the keyword arguments of to().

property data

Returns a tensordict containing the .data attributes of the leaf tensors.


Deletes a key of the tensordict.


key (NestedKey) – key to be deleted



property depth: int

Returns the depth of a tensordict, i.e., its maximum number of nesting levels.

The minimum depth is 0 (no nested tensordict).
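
Following the convention above (0 means no nested tensordict), depth can be computed recursively. A sketch over plain nested dicts, assumed to stand in for tensordicts; depth here is a standalone hypothetical function, not the property itself.

```python
from typing import Any, Dict


def depth(tree: Dict[str, Any]) -> int:
    """0 when no value is a nested dict, else 1 + depth of the deepest child."""
    children = [v for v in tree.values() if isinstance(v, dict)]
    return 1 + max(depth(c) for c in children) if children else 0
```
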

detach() T

Detaches the tensors in the tensordict.


Returns: a new tensordict with no tensor requiring gradient.


Detaches the tensors in the tensordict in-place.



property device

Device of a TensorDict.

If the TensorDict has a specified device, all its tensors (incl. nested ones) must live on the same device. If the TensorDict device is None, different values can be located on different devices.


Returns: a torch.device object indicating the device where the tensors are placed, or None if the TensorDict does not have a device.


>>> td = TensorDict({
...     "cpu": torch.randn(3, device='cpu'),
...     "cuda": torch.randn(3, device='cuda'),
... }, batch_size=[], device=None)
>>> td['cpu'].device
>>> td['cuda'].device
>>> td = TensorDict({
...     "x": torch.randn(3, device='cpu'),
...     "y": torch.randn(3, device='cuda'),
... }, batch_size=[], device='cuda')
>>> td['x'].device
>>> td['y'].device
>>> td = TensorDict({
...     "x": torch.randn(3, device='cpu'),
...     "y": TensorDict({'z': torch.randn(3, device='cpu')}, batch_size=[], device=None),
... }, batch_size=[], device='cuda')
>>> td['x'].device
>>> td['y'].device # nested tensordicts are also mapped onto the appropriate device.
>>> td['y', 'z'].device

dim() int

See batch_dims().

div(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T

Divides each element of the input self by the corresponding element of other.

\[\text{out}_i = \frac{\text{input}_i}{\text{other}_i}\]

Supports broadcasting, type promotion and integer, float, tensordict or tensor inputs. Always promotes integer types to the default scalar type.


other (TensorDict, Tensor or Number) – the divisor.

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.

div_(other: TensorDictBase | torch.Tensor) T

In-place version of div().


Note: the in-place div_() does not support the default keyword argument.


Casts all tensors to torch.bool.

property dtype

Returns the dtype of the values in the tensordict, if it is unique.

dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Saves the tensordict to disk.

This function is a proxy to memmap().

empty(recurse=False, *, batch_size=None, device=_NoDefault.ZERO, names=None) T

Returns a new, empty tensordict with the same device and batch size.


recurse (bool, optional) – if True, the entire structure of the TensorDict will be reproduced without content. Otherwise, only the root will be duplicated. Defaults to False.

Keyword Arguments:
  • batch_size (torch.Size, optional) – a new batch-size for the tensordict.

  • device (torch.device, optional) – a new device.

  • names (list of str, optional) – dimension names.

entry_class(key: NestedKey) type

Returns the class of an entry, possibly avoiding a call to isinstance(td.get(key), type).

This method should be preferred to type(tensordict.get(key)) whenever get() can be expensive to execute.

erf() T

Computes the erf() value of each element of the TensorDict.

erf_() T

Computes the erf() value of each element of the TensorDict in-place.

erfc() T

Computes the erfc() value of each element of the TensorDict.

erfc_() T

Computes the erfc() value of each element of the TensorDict in-place.

exclude(*keys: NestedKey, inplace: bool = False) T

Excludes the keys of the tensordict and returns a new tensordict without these entries.

The values are not copied: in-place modifications of a tensor in either the original or the new tensordict will be reflected in both tensordicts.

  • *keys (str) – keys to exclude.

  • inplace (bool) – if True, the tensordict is pruned in place. Default is False.


Returns: a new tensordict (or the same one if inplace=True) without the excluded entries.


>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.exclude("a", ("b", "c"))
        b: TensorDict(
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
>>> td.exclude("a", "b")

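
Nested keys are given as tuples, and excluded entries are dropped without copying the remaining leaves. A pure-Python sketch, assuming nested dicts in place of tensordicts; this exclude is a hypothetical standalone function, not the TensorDict method.

```python
from typing import Any, Dict, Tuple, Union

Key = Union[str, Tuple[str, ...]]


def exclude(tree: Dict[str, Any], *keys: Key) -> Dict[str, Any]:
    """Return a new nested dict without the given (possibly nested) keys."""
    paths = [k if isinstance(k, tuple) else (k,) for k in keys]
    out: Dict[str, Any] = {}
    for name, value in tree.items():
        if (name,) in paths:
            continue  # excluded at this level
        if isinstance(value, dict):
            # Forward the remainder of any path that starts with this key.
            rest = [p[1:] for p in paths if len(p) > 1 and p[0] == name]
            out[name] = exclude(value, *rest)
        else:
            out[name] = value  # the leaf object is shared, not copied
    return out
```

Mirroring the example above, exclude(td, "a", ("b", "c")) keeps only ("b", "d"), and exclude(td, "a", "b") returns an empty dict.
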
exp() T

Computes the exp() value of each element of the TensorDict.

exp_() T

Computes the exp() value of each element of the TensorDict in-place.

expand(*args, **kwargs) T

Expands each tensor of the tensordict according to the expand() function, ignoring the feature dimensions.

Supports iterables to specify the shape.


>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td_expand = td.expand(10, 3, 4)
>>> assert td_expand.shape == torch.Size([10, 3, 4])
>>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5])

expand_as(other: TensorDictBase | torch.Tensor) TensorDictBase

Broadcasts the shape of the tensordict to the shape of other and expands it accordingly.

If the input is a tensor collection (tensordict or tensorclass), the leaves will be expanded on a one-to-one basis.


>>> from tensordict import TensorDict
>>> import torch
>>> td0 = TensorDict({
...     "a": torch.ones(3, 1, 4),
...     "b": {"c": torch.ones(3, 2, 1, 4)}},
...     batch_size=[3],
... )
>>> td1 = TensorDict({
...     "a": torch.zeros(2, 3, 5, 4),
...     "b": {"c": torch.zeros(2, 3, 2, 6, 4)}},
...     batch_size=[2, 3],
... )
>>> expanded = td0.expand_as(td1)
>>> assert (expanded==1).all()
>>> print(expanded)
        a: Tensor(shape=torch.Size([2, 3, 5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
                c: Tensor(shape=torch.Size([2, 3, 2, 6, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
    batch_size=torch.Size([2, 3]),

expm1() T

Computes the expm1() value of each element of the TensorDict.

expm1_() T

Computes the expm1() value of each element of the TensorDict in-place.

fill_(key: NestedKey, value: float | bool) TensorDictBase

Fills the tensor pointed to by the key with a given value.

  • key (NestedKey) – key of the entry to be filled

  • value (Number, bool) – value to use for the filling




Filters out all empty tensordicts in-place.

filter_non_tensor_data() T

Filters out all non-tensor-data.

flatten(start_dim=0, end_dim=-1)

Flattens all the tensors of a tensordict.

  • start_dim (int) – the first dim to flatten

  • end_dim (int) – the last dim to flatten


>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)}, batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_flat.batch_size
torch.Size([12])
>>> td_flat["a"]
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]])
>>> td_flat["b"]
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

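
In the example above, flatten(0, 1) merges the two batch dimensions, so the batch size becomes [12] and each leaf keeps its feature dimensions ("a" goes from (3, 4, 5) to (12, 5)). The shape arithmetic is sketched below; flatten_shape is a hypothetical helper operating on shape tuples only (Python 3.8+ for math.prod).

```python
from math import prod
from typing import Sequence, Tuple


def _normalize_dim(dim: int, ndim: int) -> int:
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for {ndim} dims")
    return dim % ndim


def flatten_shape(shape: Sequence[int], start_dim: int = 0, end_dim: int = -1) -> Tuple[int, ...]:
    """Merge dims start_dim..end_dim (inclusive) into a single dim."""
    start = _normalize_dim(start_dim, len(shape))
    end = _normalize_dim(end_dim, len(shape))
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])
```
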
flatten_keys(separator: str = '.', inplace: bool = False) T

Converts a nested tensordict into a flat one, recursively.

The TensorDict type will be lost and the result will be a simple TensorDict instance.

  • separator (str, optional) – the separator between the nested items.

  • inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.

  • is_leaf (callable, optional) – a callable over a class type returning a bool indicating if this class has to be considered as a leaf.


>>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[])
>>> data.flatten_keys(separator=" - ")
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        e - f - g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},

This method and unflatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.


>>> model = torch.nn.Sequential(torch.nn.Linear(3, 4))
>>> ddp_model =
>>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
        module: TensorDict(
                0: TensorDict(
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
        0: TensorDict(
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
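
The round trip between flat, separator-joined keys and a nested structure can be sketched with plain dicts. This is an illustration only, assuming string keys that do not themselves contain the separator; flatten_keys and unflatten_keys below are hypothetical stand-ins for the TensorDict methods, and string leaves stand in for tensors.

```python
from typing import Any, Dict


def flatten_keys(tree: Dict[str, Any], separator: str = ".") -> Dict[str, Any]:
    """Join nested keys with separator, e.g. {"b": {"c": x}} -> {"b.c": x}."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            for sub_key, leaf in flatten_keys(value, separator).items():
                flat[f"{key}{separator}{sub_key}"] = leaf
        else:
            flat[key] = value
    return flat


def unflatten_keys(flat: Dict[str, Any], separator: str = ".") -> Dict[str, Any]:
    """Split keys on separator and rebuild the nesting."""
    tree: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, last = key.split(separator)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[last] = value
    return tree
```

As in the state-dict example, unflattening {"module.0.weight": ..., "module.0.bias": ...} and then flattening the "module" sub-tree yields keys "0.weight" and "0.bias", which match the unwrapped model.
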

Casts all tensors to torch.float.


Casts all tensors to torch.float16.


Casts all tensors to torch.float32.


Casts all tensors to torch.float64.

floor() T

Computes the floor() value of each element of the TensorDict.

floor_() T

Computes the floor() value of each element of the TensorDict in-place.

frac() T

Computes the frac() value of each element of the TensorDict.

frac_() T

Computes the frac() value of each element of the TensorDict in-place.

classmethod from_dict(input_dict, filename, batch_size=None, device=None, **kwargs)

Converts a dictionary or a TensorDict to a h5 file.

  • input_dict (dict, TensorDict or compatible) – data to be stored as h5.

  • filename (str or path) – path to the h5 file.

  • batch_size (tensordict batch-size, optional) – if provided, batch size of the tensordict. If not, the batch size will be gathered from the input structure (if present) or determined automatically.

  • device (torch.device or compatible, optional) – the device on which the tensors are expected once they are returned. Defaults to None (tensors are on CPU by default).

  • **kwargs – kwargs to be passed to h5py.File.create_dataset().


Returns: a PersistentTensorDict instance linked to the newly created file.

from_dict_instance(input_dict, batch_size=None, device=None, batch_dims=None, names=None)

Instance method version of from_dict().

Unlike from_dict(), this method will attempt to keep the tensordict types within the existing tree (for any existing leaf).


>>> from tensordict import TensorDict, tensorclass
>>> import torch
>>> @tensorclass
... class MyClass:
...     x: torch.Tensor
...     y: int
>>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)})
>>> print(td.from_dict_instance(td.to_dict()))
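TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: MyClass(
            x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
            y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)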
>>> print(td.from_dict(td.to_dict()))
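TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)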
classmethod from_h5(filename, mode='r')

Creates a PersistentTensorDict from an h5 file.

This function will automatically determine the batch-size for each nested tensordict.

  • filename (str) – the path to the h5 file.

  • mode (str, optional) – reading mode. Defaults to "r".

classmethod from_module(module, as_module: bool = False, lock: bool = True, use_state_dict: bool = False)

Copies the params and buffers of a module into a tensordict.

  • module (nn.Module) – the module to get the parameters from.

  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.

  • use_state_dict (bool, optional) –

    if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note: this is particularly useful when state-dict hooks have to be used.


>>> from torch import nn
>>> from tensordict import TensorDict
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> print(params["layers", "0", "linear1"])
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2048]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2048, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
classmethod from_modules(*modules, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)

Retrieves the parameters of several modules for ensemble learning / mixture-of-experts applications through vmap.


modules (sequence of nn.Module) – the modules to get the parameters from. If the modules differ in their structure, a lazy stack is needed (see the lazy_stack argument below).

Keyword Arguments:
  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.

  • use_state_dict (bool, optional) –

    if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note: this is particularly useful when state-dict hooks have to be used.

  • lazy_stack (bool, optional) –

    whether parameters should be densely or lazily stacked. Defaults to False (dense stack).


    lazy_stack and as_module are exclusive features.


    There is a crucial difference between lazy and non-lazy outputs: non-lazy outputs reinstantiate the parameters with the desired batch size, while lazy_stack=True merely represents the original parameters as lazily stacked. This means that the original parameters can safely be passed to an optimizer when lazy_stack=True, whereas the new parameters need to be passed when it is set to False.


    Whilst it can be tempting to use a lazy stack to keep the original parameter references, remember that a lazy stack performs a stack operation each time get() is called. This requires memory (N times the size of the parameters, more if a graph is built) and time. It also means that the optimizer(s) will contain more parameters, and operations like step() or zero_grad() will take longer to execute. In general, lazy_stack should be reserved for very few use cases.

  • expand_identical (bool, optional) – if True and the same parameter (same identity) is being stacked to itself, an expanded version of this parameter will be returned instead. This argument is ignored when lazy_stack=True.


>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> empty_module = nn.Linear(3, 4, device="meta")
>>> n_models = 2
>>> modules = [nn.Linear(3, 4) for _ in range(n_models)]
>>> params = TensorDict.from_modules(*modules)
>>> print(params)
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> # example of batch execution
>>> def exec_module(params, x):
...     with params.to_module(empty_module):
...         return empty_module(x)
>>> x = torch.randn(3)
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> # since lazy_stack = False, backprop leaves the original params untouched
>>> y.sum().backward()
>>> assert params["weight"].grad.norm() > 0
>>> assert modules[0].weight.grad is None

With lazy_stack=True, things are slightly different:

>>> params = TensorDict.from_modules(*modules, lazy_stack=True)
>>> print(params)
        bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
>>> # example of batch execution
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> y.sum().backward()
>>> assert modules[0].weight.grad is not None
classmethod from_namedtuple(named_tuple, *, auto_batch_size: bool = False)

Converts a namedtuple to a TensorDict recursively.

Keyword Arguments:

auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.


>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "a_tensor": torch.zeros((3)),
...     "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> nt = data.to_namedtuple()
>>> print(nt)
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
>>> TensorDict.from_namedtuple(nt, auto_batch_size=True)
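TensorDict(
    fields={
        a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None),
                a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)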
classmethod from_pytree(pytree, *, batch_size: torch.Size | None = None, auto_batch_size: bool = False, batch_dims: int | None = None)

Converts a pytree to a TensorDict instance.

This method is designed to keep the pytree nested structure as much as possible.

Additional non-tensor keys are added to keep track of each level’s identity, providing a built-in pytree-to-tensordict bijective transform API.

Accepted classes currently include lists, tuples, named tuples and dict.


For dictionaries, non-NestedKey keys are registered separately as NonTensorData instances.


Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances. Note that this transformation is lossy: transforming the tensordict back to a pytree will not recover the original types.


>>> import torch
>>> from tensordict import TensorDict
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
>>> class WeirdLookingClass:
...     pass
>>> weird_key = WeirdLookingClass()
>>> # Make a pytree with tuple, lists, dict and namedtuple
>>> pytree = (
...     [torch.randint(10, (3,)), torch.zeros(2)],
...     {
...         "tensor": torch.randn(
...             2,
...         ),
...         "td": TensorDict({"one": 1}),
...         weird_key: torch.randint(10, (2,)),
...         "list": [1, 2, 3],
...     },
...     {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
... )
>>> # Build a TensorDict from that pytree
>>> td = TensorDict.from_pytree(pytree)
>>> # Recover the pytree
>>> pytree_recon = td.to_pytree()
>>> # Check that the leaves match
>>> def check(v1, v2):
...     assert (v1 == v2).all()
>>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
>>> assert weird_key in pytree_recon[1]
classmethod fromkeys(keys: List[NestedKey], value: Any = 0)

Creates a tensordict from a list of keys and a single value.

  • keys (list of NestedKey) – An iterable specifying the keys of the new tensordict.

  • value (compatible type, optional) – The value for all keys. Defaults to 0.
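The way nested keys shape the result can be illustrated with plain Python dictionaries: a tuple key such as ("nested", "e") creates the intermediate level "nested". The helper nested_fromkeys below is hypothetical and only models the key layout; the real fromkeys() returns a TensorDict and stores the value as a tensor.

```python
from typing import Any, Dict, Iterable, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def nested_fromkeys(keys: Iterable[NestedKey], value: Any = 0) -> Dict[str, Any]:
    """Hypothetical sketch: build a nested dict where every key maps to `value`."""
    out: Dict[str, Any] = {}
    for key in keys:
        # Normalize a plain string key into a one-element path.
        path = (key,) if isinstance(key, str) else tuple(key)
        node = out
        # Create (or reuse) the intermediate levels of the path.
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return out


layout = nested_fromkeys(["a", "b", ("nested", "e")], 0)
print(layout)  # {'a': 0, 'b': 0, 'nested': {'e': 0}}
```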

gather(dim: int, index: Tensor, out: T | None = None) T

Gathers values along an axis specified by dim.

  • dim (int) – the dimension along which to collect the elements.

  • index (torch.Tensor) – a long tensor whose number of dimensions matches that of the tensordict, with only one dimension differing between the two (the gathering dimension). Its elements are the indices to be gathered along the required dimension.

  • out (TensorDictBase, optional) – a destination tensordict. It must have the same shape as the index.
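The indexing rule can be sketched in plain Python for the 2-D case used in the example that follows (gathering along dim=1): output entry [i][j] is taken from position index[i][j] of row i, and whatever trailing (non-batch) content the element holds travels with it. The function gather_dim1 is a hypothetical illustration of these semantics, not the library's implementation.

```python
from typing import Any, List, Sequence


def gather_dim1(rows: Sequence[Sequence[Any]], index: Sequence[Sequence[int]]) -> List[List[Any]]:
    """Hypothetical sketch: out[i][j] = rows[i][index[i][j]] (gather along dim=1)."""
    if len(rows) != len(index):
        # Only the gathering dimension (dim=1) may differ between input and index.
        raise ValueError("index and input must match on every dimension except dim=1")
    return [[row[k] for k in idx_row] for row, idx_row in zip(rows, index)]


# A 3 x 4 "batch"; each string stands in for the trailing feature dimensions.
batch = [[f"r{i}c{j}" for j in range(4)] for i in range(3)]
index = [[0, 3], [1, 1], [2, 0]]  # shape 3 x 2, like torch.randint(4, (3, 2))
gathered = gather_dim1(batch, index)
print(gathered)  # [['r0c0', 'r0c3'], ['r1c1', 'r1c1'], ['r2c2', 'r2c0']]
```

As in the example, the result takes the shape of the index (3 x 2) on the batch dimensions.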


>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     {"a": torch.randn(3, 4, 5),
...      "b": TensorDict({"c": torch.zeros(3, 4, 5)}, [3, 4, 5])},
...     [3, 4])
>>> index = torch.randint(4, (3, 2))
>>> td_gather = td.gather(dim=1, index=index)
>>> print(td_gather)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 2, 5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

Gather keeps the dimension names.


>>> td.names = ["a", "b"]
>>> td_gather = td.gather(dim=1, index=index)
>>> td_gather.names
["a", "b"]
gather_and_stack(dst: int, group: 'dist.ProcessGroup' | None = None) T | None

Gathers tensordicts from various workers and stacks them onto self in the destination worker.

  • dst (int) – the rank of the destination worker where gather_and_stack() will be called.

  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.


>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     # Create a single tensordict to be sent to server
...     td = TensorDict(
...         {("a", "b"): torch.randn(2),
...          "c": torch.randn(2)}, [2]
...     )
...     td.gather_and_stack(0)
>>> def server():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     # Creates the destination tensordict on server.
...     # The first dim must be equal to world_size-1
...     td = TensorDict(
...         {("a", "b"): torch.zeros(2),
...          "c": torch.zeros(2)}, [2]
...     ).expand(1, 2).contiguous()
...     td.gather_and_stack(0)
...     assert (td["a", "b"] != 0).all()
...     print("yuppie")
>>> if __name__ == "__main__":
...     mp.set_start_method("spawn")
...     main_worker = mp.Process(target=server)
...     secondary_worker = mp.Process(target=client)
...     main_worker.start()
...     secondary_worker.start()
...     main_worker.join()
...     secondary_worker.join()
get(key, default=_NoDefault.ZERO)

Gets the value stored with the input key.

  • key (str, tuple of str) – key to be queried. If a tuple of str, it is equivalent to chained calls of get().

  • default – default value if the key is not found in the tensordict.


>>> td = TensorDict({"x": 1}, batch_size=[])
>>> td.get("x")
tensor(1)
>>> td.get("y", default=None)
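The equivalence between a tuple key and chained get() calls can be illustrated with nested Python dictionaries. Both get_nested and the _MISSING sentinel are hypothetical stand-ins (the latter plays the role of _NoDefault.ZERO); a missing key raises unless a default is given.

```python
from typing import Any, Mapping, Tuple, Union

_MISSING = object()  # hypothetical sentinel meaning "no default was given"


def get_nested(data: Mapping[str, Any], key: Union[str, Tuple[str, ...]], default: Any = _MISSING) -> Any:
    """Hypothetical sketch: a tuple key walks the nested levels one get() at a time."""
    path = (key,) if isinstance(key, str) else key
    node: Any = data
    for part in path:
        if not isinstance(node, Mapping) or part not in node:
            if default is _MISSING:
                raise KeyError(key)
            return default
        node = node[part]
    return node


data = {"a": {"b": 1}}
# ("a", "b") gives the same result as getting "a" and then "b".
assert get_nested(data, ("a", "b")) == get_nested(get_nested(data, "a"), "b") == 1
print(get_nested(data, "y", default=None))  # None
```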
get_at(key: NestedKey, idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], default: Tensor = _NoDefault.ZERO) Tensor

Gets the value stored under key at index idx.

  • key (str, tuple of str) – key to be retrieved.

  • idx (int, slice, torch.Tensor, iterable) – index of the tensor.

  • default (torch.Tensor) – default value to return if the key is not present in the tensordict.


Returns: the indexed tensor.


>>> td = TensorDict({"x": torch.arange(3)}, batch_size=[])
>>> td.get_at("x", 1)
tensor(1)
get_item_shape(key: NestedKey)

Returns the shape of the entry, possibly without resorting to get().

get_non_tensor(key: NestedKey, default=_NoDefault.ZERO)

Gets a non-tensor value, if it exists, or default if the non-tensor value is not found.

This method is robust to tensor/TensorDict values, meaning that if the value gathered is a regular tensor it will be returned too (although this method comes with some overhead and should not be used out of its natural scope).

See set_non_tensor() for more information on how to set non-tensor values in a tensordict.

  • key (NestedKey) – the location of the NonTensorData object.

  • default (Any, optional) – the value to be returned if the key cannot be found.

Returns: the content of the tensordict.tensorclass.NonTensorData, or the entry corresponding to the key if it isn’t a tensordict.tensorclass.NonTensorData (or default if the entry cannot be found).


>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
    data='a string!',
property grad

Returns a tensordict containing the .grad attributes of the leaf tensors.


Casts all tensors to torch.half.


Casts all tensors to


Casts all tensors to torch.int16.


Casts all tensors to torch.int32.


Casts all tensors to torch.int64.


Casts all tensors to torch.int8.

irecv(src: int, *, group: 'dist.ProcessGroup' | None = None, return_premature: bool = False, init_tag: int = 0, pseudo_rand: bool = False) tuple[int, list[torch.Future]] | list[torch.Future] | None

Receives the content of a tensordict and updates content with it asynchronously.

Check the example in the isend() method for context.


src (int) – the rank of the source worker.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • return_premature (bool) – if True, returns a list of futures to wait upon until the tensordict is updated. Defaults to False, i.e., waits until the update is completed within the call.

  • init_tag (int) – the init_tag used by the source worker.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. This value must match the one passed to isend(). Defaults to False.


Returns: if return_premature=True, a list of futures to wait upon until the tensordict is updated.


Checks if a TensorDict has a consolidated storage.


Returns a boolean indicating if all the tensors are contiguous.

is_empty() bool

Checks whether the tensordict is empty, i.e., whether it contains no leaf.
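The leaf-based definition can be sketched with nested dictionaries, assuming that a nested dict plays the role of a sub-tensordict (never a leaf) and any other value is a leaf. Under that assumption, a structure that only holds empty nested levels still counts as empty. is_empty_tree is a hypothetical helper, not part of tensordict.

```python
from typing import Any, Mapping


def is_empty_tree(data: Mapping[str, Any]) -> bool:
    """Hypothetical sketch: True when no leaf (non-mapping value) can be reached."""
    for value in data.values():
        if not isinstance(value, Mapping):
            return False  # found a leaf at this level
        if not is_empty_tree(value):
            return False  # a nested level contains a leaf
    return True


print(is_empty_tree({}))                 # True
print(is_empty_tree({"a": {}}))          # True: only an empty nested level
print(is_empty_tree({"a": {"b": 1.0}}))  # False
```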

is_memmap() bool

Checks if tensordict is memory-mapped.

If a TensorDict instance is memory-mapped, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all memory-mapped, this does not mean that is_memmap will return True (as a new tensor may or may not be memory-mapped). Only if one calls tensordict.memmap_() will the tensordict be considered as memory-mapped.

This is always True for tensordicts on a CUDA device.

is_shared() bool

Checks if tensordict is in shared memory.

If a TensorDict instance is in shared memory, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all in shared memory, this does not mean that is_shared will return True (as a new tensor may or may not be in shared memory). Only if one calls tensordict.share_memory_() or places the tensordict on a device where the content is shared by default (e.g., "cuda") will the tensordict be considered in shared memory.

This is always True for tensordicts on a CUDA device.

isend(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int

Sends the content of the tensordict asynchronously.


dst (int) – the rank of the destination worker where the content should be sent.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. Defaults to False.
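The note on init_tag can be made concrete. Assuming one tag per tensor and consecutive tags when pseudo_rand=False, a tensordict holding n tensors sent with init_tag=t uses the tags t, t+1, ..., t+n-1, so two concurrent senders need non-overlapping ranges, and the receiver must use the same init_tag. The helpers tag_range and ranges_overlap are hypothetical and only model this bookkeeping; they do not call torch.distributed.

```python
def tag_range(init_tag: int, num_tensors: int) -> range:
    """Hypothetical: tags used by one isend() call, assuming one consecutive tag per tensor."""
    return range(init_tag, init_tag + num_tensors)


def ranges_overlap(a: range, b: range) -> bool:
    """True if two tag ranges share at least one tag."""
    return max(a.start, b.start) < min(a.stop, b.stop)


# The example that follows sends 3 tensors: ("a", "b"), "c" and "_".
first = tag_range(init_tag=0, num_tensors=3)   # tags 0, 1, 2
second = tag_range(init_tag=3, num_tensors=3)  # tags 3, 4, 5
clash = tag_range(init_tag=2, num_tensors=3)   # tags 2, 3, 4
print(ranges_overlap(first, second), ranges_overlap(first, clash))  # False True
```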


>>> import torch
>>> from tensordict import TensorDict
>>> from torch import multiprocessing as mp
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.isend(0)
>>> def server(queue, return_premature=True):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     out = td.irecv(1, return_premature=return_premature)
...     if return_premature:
...         for fut in out:
...             fut.wait()
...     assert (td != 0).all()
...     queue.put("yuppie")
>>> if __name__ == "__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(
...         target=server,
...         args=(queue, )
...         )
...     secondary_worker = mp.Process(target=client)
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
isfinite() T

Returns a new tensordict with boolean elements representing if each element is finite or not.

Real values are finite when they are not NaN, negative infinity, or infinity. Complex values are finite when both their real and imaginary parts are finite.

isnan() T

Returns a new tensordict with boolean elements representing if each element of input is NaN or not.

Complex values are considered NaN when either their real and/or imaginary part is NaN.

isneginf() T

Tests if each element of input is negative infinity or not.

isposinf() T

Tests if each element of input is positive infinity or not.
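The four predicates above follow the usual IEEE-754 classification. The sketch applies Python's math module element-wise to a flat dict of floats standing in for tensordict leaves; classify is a hypothetical helper, only real values are modelled (not the complex-number rules), and the real methods return tensordicts of booleans rather than nested dicts.

```python
import math
from typing import Dict


def classify(values: Dict[str, float]) -> Dict[str, Dict[str, bool]]:
    """Hypothetical sketch: element-wise isfinite / isnan / isneginf / isposinf."""
    return {
        key: {
            "isfinite": math.isfinite(v),
            "isnan": math.isnan(v),
            "isneginf": math.isinf(v) and v < 0,
            "isposinf": math.isinf(v) and v > 0,
        }
        for key, v in values.items()
    }


flags = classify({"x": 1.5, "nan": float("nan"), "-inf": float("-inf"), "+inf": float("inf")})
for name, result in flags.items():
    print(name, result)
```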

isreal() T

Returns a new tensordict with boolean elements representing if each element of input is real-valued or not.

items(include_nested: bool = False, leaves_only: bool = False, is_leaf=None) Iterator[tuple[str, CompatibleType]]

Returns a generator of key-value pairs for the tensordict.

  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.
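How include_nested and leaves_only combine can be modelled with nested dictionaries, where a nested dict plays the role of a sub-tensordict (never a leaf) and any other value is a leaf. iter_keys is a hypothetical helper; the ordering it produces (each entry before its children) is an assumption of this sketch. The same traversal applies to keys().

```python
from typing import Any, Iterator, Mapping, Tuple, Union

Key = Union[str, Tuple[str, ...]]


def iter_keys(data: Mapping[str, Any], include_nested: bool = False,
              leaves_only: bool = False, _prefix: Tuple[str, ...] = ()) -> Iterator[Key]:
    """Hypothetical sketch of the key traversal behind keys() / items()."""
    for name, value in data.items():
        path = _prefix + (name,)
        # Top-level keys are plain strings; nested keys are tuples.
        key = path[0] if len(path) == 1 else path
        leaf = not isinstance(value, Mapping)
        if leaf or not leaves_only:
            yield key
        if not leaf and include_nested:
            yield from iter_keys(value, include_nested, leaves_only, path)


data = {"0": 0, "1": {"2": 2}}
print(list(iter_keys(data)))                                          # ['0', '1']
print(list(iter_keys(data, leaves_only=True)))                        # ['0']
print(list(iter_keys(data, include_nested=True)))                     # ['0', '1', ('1', '2')]
print(list(iter_keys(data, include_nested=True, leaves_only=True)))   # ['0', ('1', '2')]
```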

keys(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) _PersistentTDKeysView

Returns a generator of tensordict keys.

  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.


>>> from tensordict import TensorDict
>>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[])
>>> list(data.keys())
['0', '1']
>>> list(data.keys(leaves_only=True))
['0']
>>> list(data.keys(include_nested=True, leaves_only=True))
['0', ('1', '2')]
classmethod lazy_stack(input, dim=0, *, out=None, **kwargs)

Creates a lazy stack of tensordicts.

See lazy_stack() for details.

lerp(end: TensorDictBase | torch.Tensor, weight: TensorDictBase | torch.Tensor | float)

Does a linear interpolation of two tensors start (given by self) and end based on a scalar or tensor weight.

\[\text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i)\]

The shapes of start and end must be broadcastable. If weight is a tensor, then the shapes of weight, start, and end must be broadcastable.

  • end (TensorDict) – the tensordict with the ending points.

  • weight (TensorDict, tensor or float) – the weight for the interpolation formula.
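The formula above is applied leaf by leaf. The sketch below does this on nested dictionaries of floats with a scalar weight; lerp_tree is a hypothetical helper, and the broadcasting of tensor or tensordict weights is not modelled.

```python
from typing import Any, Mapping


def lerp(start: float, end: float, weight: float) -> float:
    """out = start + weight * (end - start)."""
    return start + weight * (end - start)


def lerp_tree(start: Mapping[str, Any], end: Mapping[str, Any], weight: float) -> dict:
    """Hypothetical sketch: apply lerp() to every matching leaf of two nested dicts."""
    out = {}
    for key, value in start.items():
        if isinstance(value, Mapping):
            out[key] = lerp_tree(value, end[key], weight)  # recurse into nested levels
        else:
            out[key] = lerp(value, end[key], weight)
    return out


start = {"a": 0.0, "b": {"c": 2.0}}
end = {"a": 4.0, "b": {"c": 6.0}}
print(lerp_tree(start, end, 0.5))  # {'a': 2.0, 'b': {'c': 4.0}}
```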

lerp_(end: TensorDictBase | float, weight: TensorDictBase | float)

In-place version of lerp().

lgamma() T

Computes the lgamma() value of each element of the TensorDict.

lgamma_() T

Computes the lgamma() value of each element of the TensorDict in-place.

classmethod load(prefix: str | Path, *args, **kwargs) T

Loads a tensordict from disk.

This class method is a proxy to load_memmap().

load_(prefix: str | Path, *args, **kwargs)

Loads a tensordict from disk within the current tensordict.

This method is a proxy to load_memmap_().

classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T

Loads a memory-mapped tensordict from disk.

  • prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.

  • device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.

  • non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.

  • out (TensorDictBase, optional) – optional tensordict where the data should be written.


>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

This method also allows loading nested tensordicts.

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor:

>>> import tempfile
>>> import torch
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.memmap(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
fake: TensorDict(
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
load_memmap_(prefix: str | Path)

Loads the content of a memory-mapped tensordict within the tensordict where load_memmap_ is called.

See load_memmap() for more info.

load_state_dict(state_dict: OrderedDict[str, Any], strict=True, assign=False, from_flatten=False) T

Loads a state-dict, formatted as in state_dict(), into the tensordict.

  • state_dict (OrderedDict) – the state_dict to be copied.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this tensordict’s state_dict() method. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the tensordict instead of copying them in place into the tensordict’s current tensors. When False, the properties of the tensors in the current tensordict are preserved, while when True, the properties of the tensors in the state dict are preserved. Default: False

  • from_flatten (bool, optional) – if True, the input state_dict is assumed to be flattened. Defaults to False.
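The from_flatten flag concerns the layout of the keys. As a sketch, assuming that a flattened state-dict joins nested keys with "." (the separator used with flatten_keys(".") in the state_dict() example) and ignoring any metadata entries the real state_dict() may add, the two layouts relate as shown below. flatten and unflatten are hypothetical helpers operating on plain dicts.

```python
from typing import Any, Dict, Mapping


def flatten(data: Mapping[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical: {"3": {"3": 3}} -> {"3.3": 3}."""
    out: Dict[str, Any] = {}
    for key, value in data.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, Mapping):
            out.update(flatten(value, sep, full_key))  # descend into nested levels
        else:
            out[full_key] = value
    return out


def unflatten(data: Mapping[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Hypothetical inverse of flatten(); assumes no original key contains `sep`."""
    out: Dict[str, Any] = {}
    for flat_key, value in data.items():
        *parents, leaf = flat_key.split(sep)
        node = out
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out


nested = {"1": 1, "2": 2, "3": {"3": 3}}
flat = flatten(nested)
print(flat)  # {'1': 1, '2': 2, '3.3': 3}
assert unflatten(flat) == nested
```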


>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> sd = data.state_dict()
>>> data_zeroed.load_state_dict(sd)
>>> print(data_zeroed["3", "3"])
tensor(3)
>>> # with flattening
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True)
>>> print(data_zeroed["3", "3"])
tensor(3)
lock_() T

Locks a tensordict for non in-place operations.

Functions such as set(), __setitem__(), update(), rename_key_() or other operations that add or remove entries will be blocked.

This method can be used as a decorator.


>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 1, "b": 2, "c": 3}, batch_size=[])
>>> with td.lock_():
...     assert td.is_locked
...     try:
...         td.set("d", 0) # error!
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         del td["d"]
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         td.rename_key_("a", "d")
...     except RuntimeError:
...         print("td is locked!")
...     td.set("a", 0, inplace=True)  # No storage is added, moved or removed
...     td.set_("a", 0) # No storage is added, moved or removed
...     td.update({"a": 0}, inplace=True)  # No storage is added, moved or removed
...     td.update_({"a": 0})  # No storage is added, moved or removed
>>> assert not td.is_locked
log() T

Computes the log() value of each element of the TensorDict.

log10() T

Computes the log10() value of each element of the TensorDict.

log10_() T

Computes the log10() value of each element of the TensorDict in-place.

log1p() T

Computes the log1p() value of each element of the TensorDict.

log1p_() T

Computes the log1p() value of each element of the TensorDict in-place.

log2() T

Computes the log2() value of each element of the TensorDict.

log2_() T

Computes the log2() value of each element of the TensorDict in-place.

log_() T

Computes the log() value of each element of the TensorDict in-place.

make_memmap(key: NestedKey, shape: torch.Size | torch.Tensor, *, dtype: torch.dtype | None = None) MemoryMappedTensor

Creates an empty memory-mapped tensor given a shape and possibly a dtype.


This method is not lock-safe by design. A memory-mapped TensorDict instance present on multiple nodes will need to be updated using the method memmap_refresh_().

Writing an existing entry will result in an error.

  • key (NestedKey) – the key of the new entry to write. If the key is already present in the tensordict, an exception is raised.

  • shape (torch.Size or equivalent, torch.Tensor for nested tensors) – the shape of the tensor to write.

Keyword Arguments:

dtype (torch.dtype, optional) – the dtype of the new tensor.


Returns: a new memory-mapped tensor.
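Conceptually, an empty memory-mapped tensor is a zero-filled file of (number of elements × item size) bytes mapped into memory. The standard-library sketch below reproduces two properties described above: the size is derived from the shape and dtype, and an existing entry is never overwritten (exclusive file creation raises instead). make_empty_memmap and the 4-byte float32 item size are assumptions of this illustration, not the MemoryMappedTensor implementation.

```python
import math
import mmap
import os
import tempfile


def make_empty_memmap(path: str, shape: tuple, itemsize: int = 4) -> mmap.mmap:
    """Hypothetical: create a zero-filled file sized for `shape` and map it into memory."""
    nbytes = math.prod(shape) * itemsize
    # Mode "xb" raises FileExistsError if the entry already exists.
    with open(path, "xb") as f:
        f.truncate(nbytes)  # extending a file with truncate() fills it with zero bytes
    with open(path, "r+b") as f:
        return mmap.mmap(f.fileno(), nbytes)


with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "a.memmap")
    buf = make_empty_memmap(path, (3, 4))  # 12 float32 items
    size = len(buf)                        # 3 * 4 * 4 = 48 bytes
    all_zero = buf[:] == bytes(size)
    buf.close()
    try:
        make_empty_memmap(path, (3, 4))    # writing an existing entry
        duplicate_rejected = False
    except FileExistsError:
        duplicate_rejected = True

print(size, all_zero, duplicate_rejected)  # 48 True True
```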

make_memmap_from_storage(key: NestedKey, storage: torch.UntypedStorage, shape: torch.Size | torch.Tensor, *, dtype: torch.dtype | None = None) MemoryMappedTensor

Creates an empty memory-mapped tensor given a storage, a shape and possibly a dtype.


This method is not lock-safe by design. A memory-mapped TensorDict instance present on multiple nodes will need to be updated using the method memmap_refresh_().


If the storage has an associated filename, it must match the new filename for the file. If it has no associated filename but the tensordict has an associated path, an exception is raised.

  • key (NestedKey) – the key of the new entry to write. If the key is already present in the tensordict, an exception is raised.

  • storage (torch.UntypedStorage) – the storage to use for the new MemoryMappedTensor. Must be a physical memory storage.

  • shape (torch.Size or equivalent, torch.Tensor for nested tensors) – the shape of the tensor to write.

Keyword Arguments:

dtype (torch.dtype, optional) – the dtype of the new tensor.


Returns: a new memory-mapped tensor with the given storage.

make_memmap_from_tensor(key: NestedKey, tensor: Tensor, *, copy_data: bool = True) MemoryMappedTensor

Creates a memory-mapped tensor from a given tensor.


This method is not lock-safe by design. A memory-mapped TensorDict instance present on multiple nodes will need to be updated using the method memmap_refresh_().

This method always copies the storage content if copy_data is True (i.e., the storage is not shared).

  • key (NestedKey) – the key of the new entry to write. If the key is already present in the tensordict, an exception is raised.

  • tensor (torch.Tensor) – the tensor to replicate on physical memory.

Keyword Arguments:

copy_data (bool, optional) – if False, the new tensor will share the metadata of the input such as shape and dtype, but the content will be empty. Defaults to True.


Returns: a new memory-mapped tensor.

map(fn: Callable, dim: int = 0, num_workers: int = None, *, out: TensorDictBase = None, chunksize: int = None, num_chunks: int = None, pool: mp.Pool = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = False, pbar: bool = False, mp_start_method: str | None = None)

Maps a function to splits of the tensordict across one dimension.

This method will apply a function to a tensordict instance by chunking it into tensordicts of equal size and dispatching the operations over the desired number of workers.

The function signature should be Callable[[TensorDict], Union[TensorDict, Tensor]]. The output must support the operation. The function must be serializable.
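The split, apply, recombine pattern described above can be sketched sequentially on a Python list: the data is split into at most num_chunks contiguous pieces of (nearly) equal size, fn is applied to each piece, and the per-chunk results are recombined. split_into_chunks and chunked_map are hypothetical; they run in a single process, whereas map() dispatches chunks to worker processes, and both the ceil-sized split and recombination by list concatenation are assumptions of this sketch.

```python
import math
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def split_into_chunks(data: Sequence[T], num_chunks: int) -> List[Sequence[T]]:
    """Hypothetical: contiguous chunks of ceil(len / num_chunks) items (last may be shorter)."""
    chunk_size = max(1, math.ceil(len(data) / num_chunks))
    return [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]


def chunked_map(fn: Callable[[Sequence[T]], List[U]], data: Sequence[T], num_chunks: int) -> List[U]:
    """Hypothetical single-process sketch of split -> apply -> recombine."""
    results: List[U] = []
    for chunk in split_into_chunks(data, num_chunks):
        # In map() each chunk would go to a worker; here it runs inline.
        results.extend(fn(chunk))  # recombine by concatenation (assumption of this sketch)
    return results


def double(chunk):
    return [2 * x for x in chunk]


print([len(c) for c in split_into_chunks(list(range(10)), 3)])  # [4, 4, 2]
print(chunked_map(double, list(range(10)), num_chunks=3))
```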

  • fn (callable) – function to apply to the tensordict. Signatures similar to Callable[[TensorDict], Union[TensorDict, Tensor]] are supported.

  • dim (int, optional) – the dim along which the tensordict will be chunked.

  • num_workers (int, optional) – the number of workers. Exclusive with pool. If none is provided, the number of workers will be set to the number of cpus available.

Keyword Arguments:
  • out (TensorDictBase, optional) – an optional container for the output. Its batch-size along the dim provided must match self.ndim. If it is shared or memmap (is_shared() or is_memmap() returns True), it will be populated within the remote processes, avoiding inward data transfers. Otherwise, the data from the self slice will be sent to the process, collected on the current process and written in-place into out.

  • chunksize (int, optional) – the size of each chunk of data. A chunksize of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, whereas chunksize>0 will split the tensordict and call torch.cat() on the resulting list of tensordicts. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory, in which case more chunks may be needed to make the operation practically doable. This argument is exclusive with num_chunks.

  • num_chunks (int, optional) – the number of chunks to split the tensordict into. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory, in which case more chunks may be needed to make the operation practically doable. This argument is exclusive with chunksize.

  • pool (mp.Pool, optional) – a multiprocess Pool instance to use to execute the job. If none is provided, a pool will be created within the map method.

  • generator (torch.Generator, optional) –

    a generator to use for seeding. A base seed will be generated from it, and each worker of the pool will be seeded with the provided seed incremented by a unique integer from 0 to num_workers - 1. If no generator is provided, a random integer will be used as seed. To work with unseeded workers, a pool should be created separately and passed to map() directly.

    Note: caution should be taken when providing a low-valued seed, as this can cause autocorrelation between experiments. For example, if 8 workers are requested and the seed is 4, the worker seeds will range from 4 to 11; if the seed is 5, they will range from 5 to 12. These two experiments will share 7 seeds, which can have unexpected effects on the results.

    The goal of seeding the workers is to have an independent seed on each worker, NOT to have reproducible results across calls of the map method. In other words, two experiments may, and probably will, return different results, as it is impossible to know which worker will pick which job. However, we can make sure that each worker has a different seed and that the pseudo-random operations on each will be uncorrelated.

  • max_tasks_per_child (int, optional) – the maximum number of jobs picked by every child process. Defaults to None, i.e., no restriction on the number of jobs.

  • worker_threads (int, optional) – the number of threads for the workers. Defaults to 1.

  • index_with_generator (bool, optional) – if True, the splitting / chunking of the tensordict will be done during the query, sparing init time. Note that chunk() and split() are much more efficient than indexing (which is used within the generator) so a gain of processing time at init time may have a negative impact on the total runtime. Defaults to False.

  • pbar (bool, optional) – if True, a progress bar will be displayed. Requires tqdm to be available. Defaults to False.

  • mp_start_method (str, optional) – the start method for multiprocessing. If not provided, the default start method will be used. Accepted strings are "fork" and "spawn". Keep in mind that "cuda" tensors cannot be shared between processes with the "fork" start method. This has no effect if a pool is passed to the map method.
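The seed overlap described in the generator note can be checked with a few lines of plain Python. The sketch below only models the documented scheme (worker i receives base_seed + i); the helper names are hypothetical and nothing here calls tensordict.

```python
def worker_seeds(base_seed: int, num_workers: int) -> range:
    """Seeds under the documented scheme: base_seed + 0, ..., base_seed + num_workers - 1."""
    return range(base_seed, base_seed + num_workers)


def seed_overlap(seed_a: int, seed_b: int, num_workers: int) -> int:
    """How many worker seeds two runs with different base seeds have in common."""
    shared = set(worker_seeds(seed_a, num_workers)) & set(worker_seeds(seed_b, num_workers))
    return len(shared)


print(list(worker_seeds(4, 8)))    # [4, 5, 6, 7, 8, 9, 10, 11]
print(seed_overlap(4, 5, 8))       # 7: the two runs share seeds 5..11
print(seed_overlap(4, 10_000, 8))  # 0: well-separated base seeds do not collide
```

Drawing base seeds from a wide range, rather than picking small consecutive integers by hand, makes such collisions unlikely.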


>>> import torch
>>> from tensordict import TensorDict
>>> def process_data(data):
...     data.set("y", data.get("x") + 1)
...     return data
>>> if __name__ == "__main__":
...     data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_()
...     data = data.map(process_data, dim=1)
...     print(data["y"][:, :10])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])


This method is particularly useful when working with large datasets stored on disk (e.g. memory-mapped tensordicts), where chunks are zero-copy slices of the original data that can be passed to the processes at virtually zero cost. This allows very large datasets (e.g. over 1 TB) to be processed at little cost.
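How chunksize and num_chunks partition the mapped dimension can be sketched with index ranges alone. The helper plan_chunks is hypothetical: it assumes num_chunks yields ceil(n / num_chunks)-sized pieces (the convention torch.chunk follows) and only mirrors the rules stated in the argument list, not tensordict's internal code.

```python
import math


def plan_chunks(n, num_workers, chunksize=None, num_chunks=None):
    """Hypothetical helper: the (start, stop) slices a map-style call could dispatch.

    chunksize=0 -> unbind (one element per task)
    chunksize>0 -> fixed-size pieces, the last one may be shorter
    num_chunks  -> ceil(n / num_chunks)-sized pieces (assumed, as in torch.chunk)
    neither     -> as many chunks as workers
    """
    if chunksize is not None and num_chunks is not None:
        raise ValueError("chunksize and num_chunks are mutually exclusive")
    if chunksize is None:
        num_chunks = num_workers if num_chunks is None else num_chunks
        chunksize = math.ceil(n / num_chunks)
    step = max(chunksize, 1)  # chunksize=0 means "unbind": one element per task
    return [(start, min(start + step, n)) for start in range(0, n, step)]


print(plan_chunks(10, num_workers=4))                   # [(0, 3), (3, 6), (6, 9), (9, 10)]
print(plan_chunks(10, num_workers=4, chunksize=0)[:3])  # [(0, 1), (1, 2), (2, 3)]
print(plan_chunks(10, num_workers=4, num_chunks=2))     # [(0, 5), (5, 10)]
```

Smaller chunks mean more tasks (and more dispatch overhead) but a lower peak memory per task, which is the trade-off the chunksize and num_chunks descriptions point at.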

map_iter(fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, shuffle: bool = False, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = True, pbar: bool = False, mp_start_method: str | None = None)

Maps a function to splits of the tensordict across one dimension iteratively.

This is the iterable version of map().

This method will apply a function to a tensordict instance by chunking it into tensordicts of equal size and dispatching the operations over the desired number of workers. It will yield the results one at a time.

The function signature should be Callable[[TensorDict], Union[TensorDict, Tensor]]. The function must be serializable.

  • fn (callable) – function to apply to the tensordict. Signatures similar to Callable[[TensorDict], Union[TensorDict, Tensor]] are supported.

  • dim (int, optional) – the dim along which the tensordict will be chunked.

  • num_workers (int, optional) – the number of workers. Exclusive with pool. If none is provided, the number of workers will be set to the number of cpus available.

Keyword Arguments:
  • shuffle (bool, optional) – whether the indices should be globally shuffled. If True, each batch will contain non-contiguous samples. If index_with_generator=False and shuffle=True, an error will be raised. Defaults to False.

  • chunksize (int, optional) – the size of each chunk of data. A chunksize of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, whereas chunksize>0 will split the tensordict and call torch.cat() on the resulting list of tensordicts. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory, in which case more chunks may be needed to make the operation practically doable. This argument is exclusive with num_chunks.

  • num_chunks (int, optional) – the number of chunks to split the tensordict into. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory, in which case more chunks may be needed to make the operation practically doable. This argument is exclusive with chunksize.

  • pool (mp.Pool, optional) – a multiprocess Pool instance to use to execute the job. If none is provided, a pool will be created within the map method.

  • generator (torch.Generator, optional) –

    a generator to use for seeding. A base seed will be generated from it, and each worker of the pool will be seeded with the provided seed incremented by a unique integer from 0 to num_workers - 1. If no generator is provided, a random integer will be used as seed. To work with unseeded workers, a pool should be created separately and passed to map() directly.

  • max_tasks_per_child (int, optional) – the maximum number of jobs picked by every child process. Defaults to None, i.e., no restriction on the number of jobs.

  • worker_threads (int, optional) – the number of threads for the workers. Defaults to 1.

  • index_with_generator (bool, optional) –

    if True, the splitting / chunking of the tensordict will be done during the query, sparing init time. Note that chunk() and split() are much more efficient than indexing (which is used within the generator) so a gain of processing time at init time may have a negative impact on the total runtime. Defaults to True.


    Note: the default value of index_with_generator differs between map_iter and map, as the former assumes that it is prohibitively expensive to store a split version of the TensorDict in memory.

  • pbar (bool, optional) – if True, a progress bar will be displayed. Requires tqdm to be available. Defaults to False.

  • mp_start_method (str, optional) – the start method for multiprocessing. If not provided, the default start method will be used. Accepted strings are "fork" and "spawn". Keep in mind that "cuda" tensors cannot be shared between processes with the "fork" start method. This has no effect if a pool is passed to the map method.


>>> import torch
>>> from tensordict import TensorDict
>>> def process_data(data):
...     data.unlock_()
...     data.set("y", data.get("x") + 1)
...     return data
>>> if __name__ == "__main__":
...     data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_()
...     for sample in data.map_iter(process_data, dim=1, chunksize=5):
...         print(sample["y"])
...         break
tensor([[1., 1., 1., 1., 1.]])


This method is particularly useful when working with large datasets stored on disk (e.g. memory-mapped tensordicts), where chunks are zero-copy slices of the original data that can be passed to the processes at virtually zero cost. This allows very large datasets (e.g. over 1 TB) to be processed at little cost.


This function can be used to represent a dataset and load from it, in a dataloader-like fashion.
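The interplay of shuffle and index_with_generator can be modelled with plain index lists. In the hypothetical sketch below, each yielded list stands for the indices one worker task would receive; the lazy path slices them only when requested, and shuffling is rejected on the eager path, mirroring the error described for shuffle=True with index_with_generator=False. It is an illustration, not tensordict's implementation.

```python
import random


def iter_index_batches(n, chunksize, *, shuffle=False, index_with_generator=True, seed=None):
    """Hypothetical sketch: the index batches a map_iter-style loop could dispatch."""
    if shuffle and not index_with_generator:
        raise ValueError("shuffle=True requires index_with_generator=True")
    order = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(order)  # one global permutation
    starts = range(0, n, chunksize)
    if index_with_generator:
        # Lazy: each batch is sliced only when the consumer asks for it.
        return (order[s:s + chunksize] for s in starts)
    # Eager: every batch is built up front, before the first one is consumed.
    return iter([order[s:s + chunksize] for s in starts])


batches = iter_index_batches(10, 4)
print(next(batches))  # [0, 1, 2, 3]
shuffled = list(iter_index_batches(10, 4, shuffle=True, seed=0))
print([len(b) for b in shuffled])  # [4, 4, 2]
```

Consuming the iterator one batch at a time is what gives the dataloader-like behaviour: only the current batch needs to be materialised.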

masked_fill(mask, value)

Out-of-place version of masked_fill.

  • mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.

  • value – the value used to fill the tensors.




>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td1 = td.masked_fill(mask, 1.0)
>>> td1.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

masked_fill_(mask, value)

Fills the values corresponding to the mask with the desired value.

  • mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.

  • value – the value used to fill the tensors.




>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td.masked_fill_(mask, 1.0)
>>> td.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
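The difference between the out-of-place and in-place variants can be sketched on a nested dict of Python lists, where each leaf is a list of rows along the batch dimension. The functions below are hypothetical stand-ins for masked_fill() and masked_fill_(), written only to show the semantics: rows selected by the mask are overwritten in every leaf, nested entries included, and only the in-place version mutates its input.

```python
import copy


def masked_fill_(td, mask, value):
    """In-place: overwrite batch row i of every leaf wherever mask[i] is True."""
    for leaf in td.values():
        if isinstance(leaf, dict):
            masked_fill_(leaf, mask, value)  # recurse into nested entries
            continue
        if len(leaf) != len(mask):
            raise ValueError("mask must match the batch size")
        for i, selected in enumerate(mask):
            if selected:
                leaf[i] = [value] * len(leaf[i])
    return td


def masked_fill(td, mask, value):
    """Out-of-place: same result on a deep copy, so `td` is left untouched."""
    return masked_fill_(copy.deepcopy(td), mask, value)


td = {"a": [[0.0] * 4 for _ in range(3)], "nested": {"b": [[0.0] * 2 for _ in range(3)]}}
mask = [True, False, False]
td1 = masked_fill(td, mask, 1.0)
print(td1["a"][0], td["a"][0])  # [1.0, 1.0, 1.0, 1.0] [0.0, 0.0, 0.0, 0.0]
masked_fill_(td, mask, 1.0)
print(td["nested"]["b"])        # [[1.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
```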

masked_select(mask: Tensor) T

Masks all tensors of the TensorDict and returns a new TensorDict instance with similar keys pointing to masked values.


mask (torch.Tensor) – boolean mask to be used for the tensors. Shape must match the TensorDict batch_size.


>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...    batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td_mask = td.masked_select(mask)
>>> td_mask.get("a")
tensor([[0., 0., 0., 0.]])
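Likewise, masked_select() can be pictured as keeping only the batch rows where the mask is True, so the resulting batch size equals the number of True entries. The pure-Python function below is a hypothetical sketch on nested lists, not tensordict code.

```python
def masked_select(td, mask):
    """Keep batch row i of every leaf only where mask[i] is True."""
    out = {}
    for key, leaf in td.items():
        if isinstance(leaf, dict):
            out[key] = masked_select(leaf, mask)  # nested entries are masked too
        elif len(leaf) != len(mask):
            raise ValueError("mask must match the batch size")
        else:
            out[key] = [row for row, keep in zip(leaf, mask) if keep]
    return out


td = {"a": [[0.0] * 4 for _ in range(3)], "nested": {"b": [10, 20, 30]}}
selected = masked_select(td, [True, False, True])
print(selected["nested"]["b"])  # [10, 30]
print(len(selected["a"]))       # 2, i.e. the new batch size is sum(mask)
```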

maximum(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T

Computes the element-wise maximum of self and other.


other (TensorDict or Tensor) – the other input tensordict or tensor.

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
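The three cases of the default rule (shared by maximum(), minimum() and mul()) can be modelled on flat dicts of numbers, with numbers standing in for tensors. The binary_op helper is hypothetical and only restates the rule above.

```python
def binary_op(a, b, op, default=None):
    """Hypothetical model of the `default` rule, on flat dicts of numbers."""
    keys_a, keys_b = set(a), set(b)
    if default is None:
        if keys_a != keys_b:  # no default: key sets must match exactly
            raise KeyError(f"mismatched keys: {sorted(keys_a ^ keys_b)}")
        keys = keys_a
    elif default == "intersection":
        keys = keys_a & keys_b  # exclusive keys are ignored
    else:
        keys = keys_a | keys_b  # exclusive keys are paired with `default`
    return {k: op(a.get(k, default), b.get(k, default)) for k in sorted(keys)}


a = {"x": 1.0, "y": 5.0}
b = {"x": 3.0, "z": -2.0}
print(binary_op(a, b, max, default="intersection"))  # {'x': 3.0}
print(binary_op(a, b, max, default=0.0))             # {'x': 3.0, 'y': 5.0, 'z': 0.0}
```

Note that with a numeric default the missing side takes part in the operation: "z" becomes max(0.0, -2.0) = 0.0, not -2.0.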

maximum_(other: TensorDictBase | torch.Tensor) T

In-place version of maximum().


In-place maximum_() does not support the default keyword argument.

classmethod maybe_dense_stack(input, dim=0, *, out=None, **kwargs)

Attempts to make a dense stack of tensordicts, and falls back on a lazy stack when required.

See maybe_dense_stack() for details.

mean(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the mean value of all elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired data type of the returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
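The difference between a per-leaf mean and reduce=True can be sketched on a nested dict of lists. This is a hypothetical pure-Python model: it assumes reduce=True pools every element of every leaf before averaging (so large leaves weigh more than small ones), and it does not model dim, keepdim or dtype.

```python
from statistics import fmean


def leaves(td):
    """Yield every leaf (a list of numbers) of a nested dict, depth first."""
    for value in td.values():
        if isinstance(value, dict):
            yield from leaves(value)
        else:
            yield value


def td_mean(td, reduce=False):
    """Hypothetical mean(): per-leaf means, or one mean pooled over all elements."""
    if reduce:
        # Assumption: every element of every leaf is weighted equally.
        return fmean(x for leaf in leaves(td) for x in leaf)
    return {k: td_mean(v) if isinstance(v, dict) else fmean(v) for k, v in td.items()}


td = {"a": [1.0, 2.0, 3.0], "nested": {"b": [10.0]}}
print(td_mean(td))               # {'a': 2.0, 'nested': {'b': 10.0}}
print(td_mean(td, reduce=True))  # 4.0 (16 / 4), not the mean of the leaf means (6.0)
```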

memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.

  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.


A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.


Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
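The on-disk layout ("the directory tree structure will mimic the tensordict’s") and the num_threads / return_early behaviour can be sketched with the standard library. write_tree is hypothetical: leaves are raw bytes, each one is written to its own file, nested dicts become subdirectories, and with return_early=True and num_threads>0 a concurrent.futures.Future is returned instead of a TensorDictFuture. It does not memory-map anything; it only models the layout and the asynchronous return.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def write_tree(tree, prefix, num_threads=0, return_early=False):
    """Hypothetical sketch: one file per leaf, directories mirroring the nested keys."""
    jobs = []

    def collect(node, path):
        for key, value in node.items():
            if isinstance(value, dict):
                collect(value, os.path.join(path, key))  # nested dict -> subdirectory
            else:
                jobs.append((os.path.join(path, key + ".bin"), value))

    collect(tree, prefix)

    def write(job):
        filename, payload = job
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "wb") as f:
            f.write(payload)
        return filename

    def write_all():
        if num_threads == 0:
            return [write(job) for job in jobs]  # sequential writes
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            return list(pool.map(write, jobs))  # parallel writes

    if return_early and num_threads > 0:
        # Hand back a future immediately; the writes continue in the background.
        background = ThreadPoolExecutor(max_workers=1)
        future = background.submit(write_all)
        background.shutdown(wait=False)
        return future
    return write_all()


with tempfile.TemporaryDirectory() as prefix:
    tree = {"obs": bytes(16), "nested": {"reward": b"\x01" * 4}}
    future = write_tree(tree, prefix, num_threads=2, return_early=True)
    paths = future.result()  # block until every leaf is on disk
    rel = sorted(os.path.relpath(p, prefix).replace(os.sep, "/") for p in paths)
    sizes = sorted(os.path.getsize(p) for p in paths)
print(rel)    # ['nested/reward.bin', 'obs.bin']
print(sizes)  # [4, 16]
```

Walking the tree and opening one file per leaf is also why serialising deeply nested structures can be slow, as the note above warns.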

memmap_(prefix: str | None = None, copy_existing: bool = False, num_threads: int = 0) PersistentTensorDict

Writes all tensors onto a corresponding memory-mapped Tensor, in-place.

  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.


self if return_early=False, otherwise a TensorDictFuture instance.


Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_like(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Creates a contentless memory-mapped tensordict with the same shapes as the original one.

  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set or remove an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.


A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.


This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.


>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
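Why memmap_like() is cheap can be illustrated with the standard library: only the byte size of each entry is needed to reserve a zero-filled, file-backed buffer, and no source content is read. The memmap_like function below is a hypothetical stdlib sketch (not tensordict's implementation), and the sizes mirror the example above scaled down to 1_000 samples.

```python
import mmap
import os
import tempfile


def memmap_like(nbytes_per_key, prefix):
    """Hypothetical sketch: one zero-filled, file-backed buffer per key.

    Only the sizes are used: no source content is read or copied, which is what
    keeps this cheap for large (e.g. expanded) inputs.
    """
    buffers = {}
    for key, nbytes in nbytes_per_key.items():
        path = os.path.join(prefix, key + ".memmap")
        with open(path, "w+b") as f:
            f.truncate(nbytes)  # reserve space; typically no data is written
            buffers[key] = mmap.mmap(f.fileno(), nbytes)
    return buffers


n = 1_000  # the example above uses 1_000_000; scaled down here
with tempfile.TemporaryDirectory() as prefix:
    # "a": n x 3 x 64 x 64 uint8 values, "b": n x 1 int64 values (8 bytes each)
    buffers = memmap_like({"a": n * 3 * 64 * 64, "b": n * 8}, prefix)
    buffers["a"][:4] = b"\x07" * 4  # fill the buffer later, in place
    sizes = {key: len(buf) for key, buf in buffers.items()}
    head = buffers["a"][:5]
    for buf in buffers.values():
        buf.close()  # release the mappings before the directory is removed
print(sizes)  # {'a': 12288000, 'b': 8000}
print(head)   # b'\x07\x07\x07\x07\x00'
```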

Refreshes the content of the memory-mapped tensordict if it has a saved_path.

This method will raise an exception if no path is associated with it.

minimum(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T

Computes the element-wise minimum of self and other.


other (TensorDict or Tensor) – the other input tensordict or tensor.

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.

minimum_(other: TensorDictBase | torch.Tensor) T

In-place version of minimum().


In-place minimum_() does not support the default keyword argument.

mul(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T

Multiplies self by other.

\[\text{out}_i = \text{input}_i \times \text{other}_i\]

Supports broadcasting, type promotion, and integer, float, and complex inputs.


other (TensorDict, Tensor or Number) – the tensor or number to multiply self by.

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.

mul_(other: TensorDictBase | torch.Tensor) T

In-place version of mul().


In-place mul_() does not support the default keyword argument.

named_apply(fn: Callable, *others: T, nested_keys: bool = False, batch_size: Sequence[int] | None = None, device: torch.device | None = _NoDefault.ZERO, names: Sequence[str] | None = _NoDefault.ZERO, inplace: bool = False, default: Any = _NoDefault.ZERO, filter_empty: bool | None = None, propagate_lock: bool = False, call_on_nested: bool = False, out: TensorDictBase | None = None, **constructor_kwargs) T | None

Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new tensordict.

The callable signature must be Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].

  • fn (Callable) – function to be applied to the (name, tensor) pairs in the tensordict. For each leaf, only its leaf name will be used (not the full NestedKey).

  • *others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching the one of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.

  • nested_keys (bool, optional) – if True, the complete path to the leaf will be used. Defaults to False, i.e. only the last string is passed to the function.

  • batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.

  • device (torch.device, optional) – the resulting device, if any.

  • names (list of str, optional) – the new dimension names, in case the batch_size is modified.

  • inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.

  • default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.

  • filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost as empty data structures won’t be created and destroyed. Defaults to False for backward compatibility.

  • propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.

  • call_on_nested (bool, optional) –

    if True, the function will be called on first-level tensors and containers (TensorDict or tensorclass). In this scenario, fn is responsible for propagating its calls to nested levels. This allows a fine-grained behaviour when propagating the calls to nested tensordicts. If False, the function will only be called on leaves, and apply will take care of dispatching the function to all leaves.

    >>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]})
    >>> def mean_tensor_only(val):
    ...     if is_tensor_collection(val):
    ...         raise RuntimeError("Unexpected!")
    ...     return val.mean()
    >>> td_mean = td.apply(mean_tensor_only)
    >>> def mean_any(val):
    ...     if is_tensor_collection(val):
    ...         # Recurse
    ...         return val.apply(mean_any, call_on_nested=True)
    ...     return val.mean()
    >>> td_mean = td.apply(mean_any, call_on_nested=True)

  • out (TensorDictBase, optional) –

    a tensordict in which to write the results. This can be used to avoid creating a new tensordict:

    >>> td = TensorDict({"a": 0})
    >>> td.apply(lambda x: x+1, out=td)
    >>> assert (td==1).all()


    If the operation executed on the tensordict requires multiple keys to be accessed for a single computation, providing an out argument equal to self can silently produce wrong results. For instance:

    >>> td = TensorDict({"a": 1, "b": 1})
    >>> td.apply(lambda x: x+td["a"])["b"] # Right!
    >>> td.apply(lambda x: x+td["a"], out=td)["b"] # Wrong!

  • **constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.


a new tensordict with the transformed tensors.
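The out=self hazard comes from reading a key after it has already been overwritten. A hypothetical, dict-based apply makes the ordering visible; it assumes entries are visited in insertion order ("a" before "b"), which is a property of Python dicts, not a guarantee about tensordict's traversal.

```python
def apply(td, fn, out=None):
    """Hypothetical leaf-wise apply over a flat dict, writing results into `out`."""
    target = {} if out is None else out
    for key in list(td):  # visits keys in insertion order: "a", then "b"
        target[key] = fn(td[key])
    return target


td = {"a": 1, "b": 1}
right = apply(td, lambda x: x + td["a"])          # td is untouched while iterating
print(right["b"])  # 2
wrong = apply(td, lambda x: x + td["a"], out=td)  # td["a"] becomes 2 before "b" is computed
print(wrong["b"])  # 3
```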


>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "nested": {"a": torch.ones(3), "b": torch.zeros(3)}},
...     batch_size=[3])
>>> def name_filter(name, tensor):
...     if name == "a":
...         return tensor
>>> td.named_apply(name_filter)
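TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)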
>>> def name_filter(name, *tensors):
...     if name == "a":
...         r = 0
...         for tensor in tensors:
...             r = r + tensor
...         return r
>>> out = td.named_apply(name_filter, td)
>>> print(out)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> print(out["a"])
tensor([-2., -2., -2.])


If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def name_filter(name, tensor):
...     if name == "1":
...         return tensor
>>> td.named_apply(name_filter)
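TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)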
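The filtering behaviour of named_apply() (None drops an entry, nested_keys switches between the leaf name and its full path, filter_empty drops containers left empty) can be sketched on nested dicts. This is a hypothetical pure-Python model; in particular, representing the full path as a tuple of strings is an assumption of the sketch.

```python
def named_apply(td, fn, nested_keys=False, filter_empty=False, _prefix=()):
    """Hypothetical named_apply(): fn(name, leaf) per leaf; returning None drops the entry."""
    out = {}
    for key, value in td.items():
        path = _prefix + (key,)
        if isinstance(value, dict):
            sub = named_apply(value, fn, nested_keys, filter_empty, path)
            if sub or not filter_empty:  # keep empty sub-dicts unless filtering
                out[key] = sub
            continue
        # Leaf name only by default; the full path (a tuple here) if nested_keys=True.
        result = fn(path if nested_keys else key, value)
        if result is not None:
            out[key] = result
    return out


td = {"1": 1, "2": 2, "b": {"2": 2, "1": 1}}
print(named_apply(td, lambda name, x: x if name == "1" else None))
# {'1': 1, 'b': {'1': 1}}
print(named_apply(td, lambda name, x: x if name == ("b", "1") else None, nested_keys=True))
# {'b': {'1': 1}}
print(named_apply(td, lambda name, x: None, filter_empty=True))
# {}
```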

property names

The dimension names of the tensordict.

The names can be set at construction time using the names argument.

See also refine_names() for details on how to set the names after construction.

nanmean(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the mean of all non-NaN elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired data type of the returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.

nansum(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the sum of all non-NaN elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired data type of the returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
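What "non-NaN" means for nansum() and nanmean() can be shown per leaf with plain floats. The two helpers are a hypothetical sketch that follows the usual torch.nansum / torch.nanmean conventions for an all-NaN input (the sum is 0, the mean is NaN).

```python
import math


def nansum(values):
    """Sum of the non-NaN entries; an all-NaN input sums to 0.0."""
    return math.fsum(v for v in values if not math.isnan(v))


def nanmean(values):
    """Mean of the non-NaN entries; NaN when every entry is NaN."""
    kept = [v for v in values if not math.isnan(v)]
    return math.fsum(kept) / len(kept) if kept else math.nan


td = {"a": [1.0, math.nan, 3.0], "b": [math.nan, math.nan]}
print({key: nansum(leaf) for key, leaf in td.items()})  # {'a': 4.0, 'b': 0.0}
print(nanmean(td["a"]))                                 # 2.0
print(math.isnan(nanmean(td["b"])))                     # True
```

The mean divides by the number of kept entries (2 for "a"), not by the leaf length (3).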

property ndim: int

See batch_dims().

ndimension() int

See batch_dims().

neg() T

Computes the neg() value of each element of the TensorDict.

neg_() T

Computes the neg() value of each element of the TensorDict in-place.

new_empty(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)

Returns a TensorDict of size size with empty tensors.

By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.


size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.

  • device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.

  • requires_grad (bool, optional) – If autograd should record operations on the returned tensors. Default: False.

  • layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.

  • pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.

new_full(size: Size, fill_value, *, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)

Returns a TensorDict of size size filled with fill_value.

By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.

  • size (sequence of int) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

  • fill_value (scalar) – the number to fill the output tensor with.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.

  • device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.

  • requires_grad (bool, optional) – If autograd should record operations on the returned tensors. Default: False.

  • layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.

  • pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.

new_ones(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)

Returns a TensorDict of size size filled with 1.

By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.


size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.

  • device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.

  • requires_grad (bool, optional) – If autograd should record operations on the returned tensors. Default: False.

  • layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.

  • pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.

new_tensor(data: torch.Tensor | TensorDictBase, *, dtype: torch.dtype = None, device: DeviceType = _NoDefault.ZERO, requires_grad: bool = False, pin_memory: bool | None = None)

Returns a new TensorDict with data as the tensor data.

By default, the returned TensorDict values have the same torch.dtype and torch.device as this tensordict.

The data can also be a tensor collection (TensorDict or tensorclass), in which case the new_tensor method iterates over the tensor pairs of self and data.


data (torch.Tensor or TensorDictBase) – the data to be copied.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.

  • device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.

  • requires_grad (bool, optional) – If autograd should record operations on the returned tensors. Default: False.

  • pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.

new_zeros(*size: Size, dtype: Optional[dtype] = None, device: Union[device, str, int] = _NoDefault.ZERO, requires_grad: bool = False, layout: layout = torch.strided, pin_memory: Optional[bool] = None)

Returns a TensorDict of size size filled with 0.

By default, the returned TensorDict has the same torch.dtype and torch.device as this tensordict.


size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned tensordict. Default: if None, the torch.dtype will be unchanged.

  • device (torch.device, optional) – the desired device of returned tensordict. Default: if None, the torch.device will be unchanged.

  • requires_grad (bool, optional) – If autograd should record operations on the returned tensors. Default: False.

  • layout (torch.layout, optional) – the desired layout of returned TensorDict values. Default: torch.strided.

  • pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.

non_tensor_items(include_nested: bool = False)

Returns all non-tensor leaves, recursively if include_nested=True.

norm(*, out=None, dtype: torch.dtype | None = None)

Computes the norm of each tensor in the tensordict.

Keyword Arguments:
  • out (TensorDict, optional) – the output tensordict.

  • dtype (torch.dtype, optional) – the output dtype (torch>=2.4).

numel() int

Total number of elements in the batch.

Lower-bounded to 1: a stack of two tensordicts with empty shape has two elements, so a tensordict is considered to be at least 1-element big.
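The lower bound can be modelled in a few lines of plain Python. This is a sketch of the stated rule, not the library's code: the helper name `batch_numel` is hypothetical, and it assumes the count is the product of the batch dimensions, clamped to at least 1.

```python
from math import prod


def batch_numel(batch_size):
    """Hypothetical model of numel(): number of batch elements, at least 1."""
    # prod(()) is 1, so an empty batch size already counts as one element;
    # max() also clamps batch sizes containing a 0 (assumption: the lower
    # bound is applied literally).
    return max(1, prod(batch_size))


print(batch_numel(()))      # 1: empty batch size
print(batch_numel((3, 4)))  # 12
```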


numpy()

Converts a tensordict to a (possibly nested) dictionary of numpy arrays.

Non-tensor data is exposed as such.


>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({"a": {"b": torch.zeros(()), "c": "a string!"}})
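>>> print(data)
TensorDict(
    fields={
        a: TensorDict(
            fields={
                b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                c: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)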
>>> print(data.numpy())
{'a': {'b': array(0., dtype=float32), 'c': 'a string!'}}
permute(*args, **kwargs)

Returns a view of a tensordict with the batch dimensions permuted according to dims.

  • *dims_list (int) – the new ordering of the batch dims of the tensordict. Alternatively, a single iterable of integers can be provided.

  • dims (list of int) – alternative way of calling permute(…).


Returns: a new tensordict with the batch dimensions in the desired order.


>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> print(tensordict.permute([1, 0]))
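PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(1, 0))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(dims=[1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))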
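permute only reorders the leading batch dimensions; the trailing dimensions of each leaf keep their position. The shape arithmetic can be sketched in plain Python as follows. The helper `permuted_leaf_shape` is hypothetical, not part of tensordict, and negative dims are not handled.

```python
def permuted_leaf_shape(leaf_shape, batch_dims, dims):
    """Hypothetical helper: leaf shape after permuting only its batch dims."""
    if sorted(dims) != list(range(batch_dims)):
        raise ValueError(f"dims must be a permutation of range({batch_dims}), got {dims}")
    batch_part = [leaf_shape[d] for d in dims]    # reorder the batch dims
    feature_part = list(leaf_shape[batch_dims:])  # trailing dims are untouched
    return tuple(batch_part + feature_part)


# Leaf "a" from the example: shape (3, 4, 5) with batch_size [3, 4]
print(permuted_leaf_shape((3, 4, 5), batch_dims=2, dims=[1, 0]))  # (4, 3, 5)
```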
pin_memory(*args, **kwargs)

Calls pin_memory() on the stored tensors.

  • num_threads (int or str) – if provided, the number of threads to use to call pin_memory on the leaves. Defaults to None, which sets a high number of threads in ThreadPoolExecutor(max_workers=None). To execute all the calls to pin_memory() on the main thread, pass num_threads=0.

  • inplace (bool, optional) – if True, the tensordict is modified in-place. Defaults to False.
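The num_threads convention above (None for a default-sized pool, 0 for the main thread) follows the standard concurrent.futures pattern. Below is a generic sketch with a hypothetical helper; a stand-in function replaces pin_memory(), since pinning needs torch tensors, and this is not the library's implementation.

```python
from concurrent.futures import ThreadPoolExecutor


def apply_to_leaves(fn, leaves, num_threads=None):
    """Hypothetical helper: call fn on every leaf, optionally in a thread pool.

    num_threads=0 runs everything on the calling thread; None lets
    ThreadPoolExecutor pick its default worker count (max_workers=None).
    """
    if num_threads == 0:
        return [fn(leaf) for leaf in leaves]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Executor.map yields results in the order of the inputs
        return list(pool.map(fn, leaves))


# Stand-in for tensor.pin_memory(): any per-leaf function works here.
print(apply_to_leaves(lambda x: x * 2, [1, 2, 3], num_threads=0))  # [2, 4, 6]
print(apply_to_leaves(lambda x: x * 2, [1, 2, 3], num_threads=2))  # [2, 4, 6]
```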

pin_memory_(num_threads: int | str = 0) T

Calls pin_memory() on the stored tensors, modifying the TensorDict in-place, and returns it.


num_threads (int or str) – if provided, the number of threads to use to call pin_memory on the leaves. If "auto" is passed, the number of threads is automatically determined.

pop(key: NestedKey, default: Any = _NoDefault.ZERO) Tensor

Removes and returns a value from a tensordict.

If the value is not present and no default value is provided, a KeyError is thrown.

  • key (str or nested key) – the entry to look for.

  • default (Any, optional) – the value to return if the key cannot be found.


>>> td = TensorDict({"1": 1}, [])
>>> one = td.pop("1")
>>> assert one == 1
>>> none = td.pop("1", default=None)
>>> assert none is None
popitem() Tuple[NestedKey, Tensor]

Removes the item that was last inserted into the TensorDict.

popitem will only return non-nested values.

pow(other: TensorDictBase | torch.Tensor, *, default: str | CompatibleType | None = None) T

Takes the power of each element in self with other and returns a tensordict with the result.

other can be either a single float number, a Tensor or a TensorDict.

When other is a tensor, the shapes of input and other must be broadcastable.


other (float, tensor or tensordict) – the exponent value

Keyword Arguments:

default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
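The three cases for default can be modelled with flat dicts of numbers standing in for tensordicts. The helper `binary_op` is hypothetical and only mirrors the rules stated above; it is not how tensordict implements them.

```python
import operator


def binary_op(lhs, rhs, op, default=None):
    """Hypothetical model of the `default` keyword for a key-wise binary op."""
    if default is None:
        # no default: both key sets must match exactly
        if lhs.keys() != rhs.keys():
            raise KeyError(f"mismatched keys: {sorted(lhs.keys() ^ rhs.keys())}")
        keys = lhs.keys()
    elif isinstance(default, str) and default == "intersection":
        # only keys present on both sides are kept
        keys = lhs.keys() & rhs.keys()
    else:
        # any other value fills in for entries missing on either side
        keys = lhs.keys() | rhs.keys()
    return {k: op(lhs.get(k, default), rhs.get(k, default)) for k in keys}


base = {"a": 2.0, "b": 3.0}
exponent = {"a": 3.0}
print(binary_op(base, exponent, operator.pow, default="intersection"))  # {'a': 8.0}
print(binary_op(base, exponent, operator.pow, default=1.0))  # "b" is raised to 1.0
```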

pow_(other: TensorDictBase | torch.Tensor) T

In-place version of pow().


Note: in-place pow does not support the default keyword argument.

prod(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the product of the values of all elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the prod value of all leaves (if this can be computed). If integer or tuple of integers, prod is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.


Casts all tensors to torch.qint32.


Casts all tensors to torch.qint8.


Casts all tensors to torch.quint4x2.


Casts all tensors to torch.quint8.

reciprocal() T

Computes the reciprocal() value of each element of the TensorDict.

reciprocal_() T

Computes the reciprocal() value of each element of the TensorDict in-place.

recv(src: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int

Receives the content of a tensordict and updates this tensordict's content with it.

Check the example in the send method for context.


src (int) – the rank of the source worker.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the init_tag used by the source worker.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing to send multiple data from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. This value must match the one passed to send(). Defaults to False.

reduce(dst, op=None, async_op=False, return_premature=False, group=None)

Reduces the tensordict across all machines.

Only the process with rank dst is going to receive the final result.

refine_names(*names) T

Refines the dimension names of self according to names.

Refining is a special case of renaming that “lifts” unnamed dimensions. A None dim can be refined to have any name; a named dim can only be refined to have the same name.

Because named tensors can coexist with unnamed tensors, refining names gives a nice way to write named-tensor-aware code that works with both named and unnamed tensors.

names may contain up to one Ellipsis (…). The Ellipsis is expanded greedily; it is expanded in-place to fill names to the same length as self.dim() using names from the corresponding indices of self.names.

Returns: the same tensordict with dimensions named according to the input.


>>> td = TensorDict({}, batch_size=[3, 4, 5, 6])
>>> tdr = td.refine_names(None, None, None, "d")
>>> assert tdr.names == [None, None, None, "d"]
>>> tdr = td.refine_names("a", None, None, "d")
>>> assert tdr.names == ["a", None, None, "d"]
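The refinement rules (None can take any name, a named dim keeps its name, at most one Ellipsis expanded from the existing names) can be modelled on plain lists. This is a hypothetical sketch of those rules, not tensordict's implementation.

```python
def refine_names(current, *names):
    """Hypothetical model of the refine_names rules on a list of dim names."""
    names = list(names)
    if names.count(Ellipsis) > 1:
        raise ValueError("at most one Ellipsis is allowed")
    if Ellipsis in names:
        i = names.index(Ellipsis)
        n_fill = len(current) - (len(names) - 1)
        if n_fill < 0:
            raise ValueError("too many names")
        # expand the Ellipsis with the existing names at the same positions
        names = names[:i] + list(current[i:i + n_fill]) + names[i + 1:]
    if len(names) != len(current):
        raise ValueError("number of names must match the number of dims")
    refined = []
    for old, new in zip(current, names):
        if old is None:
            refined.append(new)  # an unnamed dim can take any name
        elif old == new:
            refined.append(old)  # a named dim can only keep its name
        else:
            raise ValueError(f"cannot refine dim named {old!r} to {new!r}")
    return refined


print(refine_names([None, None, None, None], "a", ..., "d"))  # ['a', None, None, 'd']
```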
rename(*names, **rename_map)

Returns a clone of the tensordict with dimensions renamed.


>>> td = TensorDict({}, batch_size=[1, 2, 3, 4])
>>> td.names = list("abcd")
>>> td_rename = td.rename(c="g")
>>> assert td_rename.names == list("abgd")
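The two calling forms of rename (positional names, or an old-to-new mapping) can be modelled on a list of names. The helper below is hypothetical; rejecting a mix of both forms is a simplification of this sketch, and the clone-versus-in-place distinction is not modelled.

```python
def rename_dims(current, *names, **rename_map):
    """Hypothetical model of rename(): positional names or an old -> new mapping."""
    if names and rename_map:
        # simplification of this sketch: the two forms are not combined
        raise ValueError("pass either positional names or keyword renames")
    if names:
        if len(names) != len(current):
            raise ValueError("one name per batch dimension is required")
        return list(names)
    unknown = set(rename_map) - set(current)
    if unknown:
        raise ValueError(f"unknown dimension names: {sorted(unknown)}")
    # dims not mentioned in rename_map keep their current name
    return [rename_map.get(name, name) for name in current]


print(rename_dims(list("abcd"), c="g"))  # ['a', 'b', 'g', 'd']
```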
rename_(*names, **rename_map)

Same as rename(), but executes the renaming in-place.


>>> td = TensorDict({}, batch_size=[1, 2, 3, 4])
>>> td.names = list("abcd")
>>> _ = td.rename_(c="g")  # renames the dimension in-place
>>> assert td.names == list("abgd")
rename_key_(old_key: NestedKey, new_key: NestedKey, safe: bool = False) PersistentTensorDict

Renames a key with a new string and returns the same tensordict with the updated key name.

  • old_key (str or nested key) – key to be renamed.

  • new_key (str or nested key) – new name of the entry.

  • safe (bool, optional) – if True, an error is thrown when the new key is already present in the TensorDict.



replace(*args, **kwargs)

Creates a shallow copy of the tensordict where entries have been replaced.

Accepts one unnamed argument which must be a dictionary or a TensorDictBase subclass. Additionally, first-level entries can be updated with the named keyword arguments.


Returns: a copy of self with updated entries if the input is non-empty. If an empty dict or no dict is provided and the kwargs are empty, self is returned.
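The shallow-copy behaviour of replace can be illustrated with a plain dict standing in for the tensordict. The helper `replace_entries` is hypothetical; it only mirrors the rules described above.

```python
def replace_entries(entries, new_entries=None, **kwargs):
    """Hypothetical dict-based model of replace()."""
    if not new_entries and not kwargs:
        # nothing to replace: the very same object is returned
        return entries
    out = dict(entries)  # shallow copy: values are shared, not cloned
    out.update(new_entries or {})
    out.update(kwargs)   # first-level entries given as keyword arguments
    return out


base = {"obs": [0.0], "reward": [1.0]}
updated = replace_entries(base, reward=[2.0])
print(updated["obs"] is base["obs"])  # True: untouched values are shared
print(base["reward"])                 # [1.0]: the original is unchanged
```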

requires_grad_(requires_grad=True) T

Changes whether autograd should record operations on this tensordict: sets the requires_grad attribute of its tensors in-place.

Returns this tensordict.


requires_grad (bool, optional) – whether or not autograd should record operations on this tensordict. Defaults to True.

reshape(*args, **kwargs) T

Returns a contiguous, reshaped tensordict of the desired shape.


*shape (int) – new shape of the resulting tensordict.


Returns: a TensorDict with reshaped entries.


>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td = td.reshape(12)
>>> print(td['x'])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
round() T

Computes the round() value of each element of the TensorDict.

round_() T

Computes the round() value of each element of the TensorDict in-place.

save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Saves the tensordict to disk.

This function is a proxy to memmap().

property saved_path

Returns the path where a memmap saved TensorDict is being stored.

This property vanishes as soon as is_memmap() returns False (e.g., when the tensordict is unlocked).

select(*keys: NestedKey, inplace: bool = False, strict: bool = True) T

Selects the keys of the tensordict and returns a new tensordict with only the selected keys.

The values are not copied: in-place modifications of a tensor in either the original or the new tensordict will result in a change in both tensordicts.

  • *keys (str) – keys to select

  • inplace (bool) – if True, the tensordict is pruned in place. Default is False.

  • strict (bool, optional) – whether selecting a key that is not present will return an error or not. Default: True.


Returns: a new tensordict (or the same if inplace=True) with the selected keys only.


To select keys in a tensordict and return a version of this tensordict deprived of these keys, see the split_keys() method.


>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
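>>> td.select("a", ("b", "c"))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("a", "b")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("this key does not exist", strict=False)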
send(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) None

Sends the content of a tensordict to a distant worker.


dst (int) – the rank of the destination worker where the content should be sent.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing to send multiple data from different nodes without overlap. Notice that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. Defaults to False.


>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.send(0)
>>> def server(queue):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     td.recv(1)
...     assert (td != 0).all()
...     queue.put("yuppie")
>>> if __name__=="__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(target=server, args=(queue,))
...     secondary_worker = mp.Process(target=client)
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
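Because init_tag is incremented once per tensor, sender and receiver must agree on one tag per leaf. The sketch below assigns consecutive tags; the helper name and the sorted key order are assumptions of this sketch (not the library's rule), and pseudo_rand is not modelled.

```python
def assign_tags(keys, init_tag=0):
    """Hypothetical sketch: one consecutive tag per leaf, starting at init_tag.

    Returns the key -> tag mapping and the next free tag. Both workers must
    visit the leaves in the same order; sorting the keys is this sketch's
    way of guaranteeing that.
    """
    ordered = sorted(keys, key=str)  # str() makes tuple and str keys comparable
    tags = {key: init_tag + i for i, key in enumerate(ordered)}
    return tags, init_tag + len(ordered)


# Leaves of the example above; both workers compute the same mapping.
sender_tags, next_tag = assign_tags([("a", "b"), "c", "_"], init_tag=0)
receiver_tags, _ = assign_tags(["_", "c", ("a", "b")], init_tag=0)
print(sender_tags == receiver_tags, next_tag)  # True 3
```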
set(key: NestedKey, item: Tensor, inplace: bool = False, *, non_blocking: bool = False, **kwargs: Any) T

Sets a new key-value pair.

  • key (str, tuple of str) – name of the key to be set.

  • item (torch.Tensor or equivalent, TensorDictBase instance) – value to be stored in the tensordict.

  • inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If inplace is True and the entry cannot be found, it will be added. For a more restrictive in-place operation, use set_() instead. Defaults to False.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.




>>> td = TensorDict({}, batch_size=[3, 4])
>>> td.set("x", torch.randn(3, 4))
>>> y = torch.randn(3, 4, 5)
>>> td.set("y", y, inplace=True) # works, even if 'y' is not present yet
>>> td.set("y", torch.zeros_like(y), inplace=True)
>>> assert (y==0).all() # y values are overwritten
>>> td.set("y", torch.ones(5), inplace=True) # raises an exception as shapes mismatch
set_(key: NestedKey, item: Tensor, *, non_blocking: bool = False) T

Sets a value to an existing key while keeping the original storage.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.




>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_("x", torch.zeros_like(x))
>>> assert (x == 0).all()
set_at_(key: NestedKey, value: Tensor, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], *, non_blocking: bool = False) T

Sets the values in-place at the index indicated by index.

  • key (str, tuple of str) – key to be modified.

  • value (torch.Tensor) – value to be set at the index index

  • index (int, tensor or tuple) – index where to write the values.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.




>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_at_("x", value=torch.ones(1, 4), index=slice(1))
>>> assert (x[0] == 1).all()
set_non_tensor(key: NestedKey, value: Any)

Registers a non-tensor value in the tensordict using tensordict.tensorclass.NonTensorData.

The value can be retrieved using TensorDictBase.get_non_tensor() or directly using get, which will return the tensordict.tensorclass.NonTensorData object.

Returns: self.


>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
NonTensorData(
    data='a string!',
    batch_size=torch.Size([]),
    device=None)
setdefault(key: NestedKey, default: Tensor, inplace: bool = False) Tensor

Inserts the key entry with a value of default if key is not in the tensordict.

Returns the value for key if key is in the tensordict, else default.

  • key (str or nested key) – the name of the value.

  • default (torch.Tensor or compatible type, TensorDictBase) – value to be stored in the tensordict if the key is not already present.


Returns: the value of key in the tensordict. Will be default if the key was not previously set.


>>> td = TensorDict({}, batch_size=[3, 4])
>>> val = td.setdefault("a", torch.zeros(3, 4))
>>> assert (val == 0).all()
>>> val = td.setdefault("a", torch.ones(3, 4))
>>> assert (val == 0).all() # output is still 0
property shape: Size

See batch_size.


Places all the tensors in shared memory.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Conversely, once the tensordict is unlocked, the share_memory attribute is turned to False, because cross-process identity is not guaranteed anymore.



sigmoid() T

Computes the sigmoid() value of each element of the TensorDict.

sigmoid_() T

Computes the sigmoid() value of each element of the TensorDict in-place.

sign() T

Computes the sign() value of each element of the TensorDict.

sign_() T

Computes the sign() value of each element of the TensorDict in-place.

sin() T

Computes the sin() value of each element of the TensorDict.

sin_() T

Computes the sin() value of each element of the TensorDict in-place.

sinh() T

Computes the sinh() value of each element of the TensorDict.

sinh_() T

Computes the sinh() value of each element of the TensorDict in-place.

size(dim: int | None = None) torch.Size | int

Returns the size of the dimension indicated by dim.

If dim is not specified, returns the batch_size attribute of the TensorDict.

property sorted_keys: list[NestedKey]

Returns the keys sorted in alphabetical order.

Does not support extra arguments.

If the TensorDict is locked, the keys are cached until the tensordict is unlocked for faster execution.

split(split_size: int | list[int], dim: int = 0) list[TensorDictBase]

Splits each tensor in the TensorDict with the specified size in the given dimension, like torch.split.

Returns a list of TensorDict instances containing views of the split chunks of items.

  • split_size (int or List(int)) – size of a single chunk or list of sizes for each chunk.

  • dim (int) – dimension along which to split the tensor.


Returns: a list of TensorDicts with the specified size in the given dimension.


>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1 = td.split([1, 2], dim=0)
>>> print(td0['x'])
tensor([[0, 1, 2, 3]])
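The chunk boundaries follow torch.split: an int gives equal chunks (the last possibly smaller), a list gives one chunk per size, and the sizes must add up to the dimension. A plain-Python sketch with a hypothetical helper name:

```python
def split_bounds(dim_size, split_size):
    """Hypothetical helper: (start, stop) of each chunk, following torch.split's rules."""
    if isinstance(split_size, int):
        # equal chunks of split_size; the last chunk may be smaller
        return [(start, min(start + split_size, dim_size))
                for start in range(0, dim_size, split_size)]
    if sum(split_size) != dim_size:
        raise ValueError(f"sizes {split_size} do not sum to {dim_size}")
    bounds, start = [], 0
    for size in split_size:
        bounds.append((start, start + size))
        start += size
    return bounds


print(split_bounds(3, [1, 2]))  # [(0, 1), (1, 3)]: td0 gets row 0, td1 rows 1-2
print(split_bounds(5, 2))       # [(0, 2), (2, 4), (4, 5)]
```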
split_keys(*key_sets, inplace=False, strict: bool = True, reproduce_struct: bool = False)

Splits the tensordict into subsets given one or more sets of keys.

The method will return N+1 tensordicts, where N is the number of the arguments provided.

  • inplace (bool, optional) – if True, the keys are removed from self in-place. Defaults to False.

  • strict (bool, optional) – if True, an exception is raised when a key is missing. Defaults to True.

  • reproduce_struct (bool, optional) – if True, all tensordict returned have the same tree structure as self, even if some sub-tensordicts contain no leaves.


Note: None non-tensor values will be ignored and not returned.


Note: the method does not check for duplicates in the provided lists.


>>> td = TensorDict(
...     a=0,
...     b=0,
...     c=0,
...     d=0,
... )
>>> td_a, td_bc, td_d = td.split_keys(["a"], ["b", "c"])
>>> print(td_bc)
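The N key sets yield N + 1 outputs, the last one holding every key that was not selected. Below is a flat-dict sketch of that partition. The helper is hypothetical, nested keys are not handled, and, as the note above says, duplicates are not checked.

```python
def split_keys(entries, *key_sets, strict=True):
    """Hypothetical dict-based model: N key sets -> N + 1 dicts, the last holding the rest."""
    remaining = dict(entries)
    parts = []
    for key_set in key_sets:
        part = {}
        for key in key_set:
            if key in entries:
                part[key] = entries[key]
                remaining.pop(key, None)  # no duplicate check, as in the note
            elif strict:
                raise KeyError(key)
        parts.append(part)
    parts.append(remaining)  # everything that was not selected
    return parts


td_a, td_bc, td_d = split_keys({"a": 0, "b": 0, "c": 0, "d": 0}, ["a"], ["b", "c"])
print(td_bc)  # {'b': 0, 'c': 0}
```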

sqrt()

Computes the element-wise square root of self.

sqrt_()

In-place version of sqrt().

squeeze(*args, **kwargs)

Squeezes all tensors for a dimension in between -self.batch_dims+1 and self.batch_dims-1 and returns them in a new tensordict.


dim (Optional[int]) – dimension along which to squeeze. If dim is None, all singleton dimensions will be squeezed. Defaults to None.


>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> td = td.squeeze()
>>> td.shape
torch.Size([3, 4])
>>> td.get("x").shape
torch.Size([3, 4, 2])

This operation can be used as a context manager too. Changes to the original tensordict will occur out-of-place, i.e. the content of the original tensors will not be altered. This also assumes that the tensordict is not locked (otherwise, unlocking the tensordict is necessary). This functionality is not compatible with implicit squeezing.

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> with td.squeeze(1) as tds:
...     tds.set("y", torch.zeros(3, 4))
>>> assert td.get("y").shape == torch.Size([3, 1, 4])
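Only batch dimensions are candidates for squeezing; trailing dimensions of a leaf stay as they are, even when they have size 1. A plain-Python sketch of the resulting shapes, using a hypothetical helper and torch.squeeze semantics for an explicit dim:

```python
def squeeze_shapes(batch_size, leaf_shape, dim=None):
    """Hypothetical model: squeeze singleton batch dims and apply the same to one leaf."""
    n = len(batch_size)
    if dim is None:
        # drop every batch dim of size 1
        keep = [i for i, size in enumerate(batch_size) if size != 1]
    else:
        dim = dim + n if dim < 0 else dim
        # only `dim` is dropped, and only if it has size 1
        keep = [i for i in range(n) if not (i == dim and batch_size[i] == 1)]
    new_batch = tuple(batch_size[i] for i in keep)
    # trailing (non-batch) dims of the leaf are never squeezed
    new_leaf = tuple(leaf_shape[i] for i in keep) + tuple(leaf_shape[n:])
    return new_batch, new_leaf


print(squeeze_shapes((3, 1, 4), (3, 1, 4, 2)))  # ((3, 4), (3, 4, 2))
```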
classmethod stack(input, dim=0, *, out=None)

Stacks tensordicts into a single tensordict along the given dimension.

This call is equivalent to calling torch.stack() but is compatible with torch.compile.

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) OrderedDict[str, Any]

Produces a state_dict from the tensordict.

The structure of the state-dict will still be nested, unless flatten is set to True.

A tensordict state-dict contains all the tensors and meta-data needed to rebuild the tensordict (names are currently not supported).

  • destination (dict, optional) – If provided, the state of tensordict will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to tensor names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the torch.Tensor items returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

  • flatten (bool, optional) – whether the structure should be flattened with the "." character or not. Defaults to False.


>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> sd = data.state_dict()
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)])
>>> sd = data.state_dict(flatten=True)
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])
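With flatten=True, nested keys are joined with ".", as in "3.3" above. The key-joining part can be sketched on plain nested dicts. The helper is hypothetical, and the "__batch_size" / "__device" metadata entries of a real state-dict are not modelled.

```python
def flatten_keys(nested, prefix="", sep="."):
    """Hypothetical helper: flatten nested dicts, joining keys with `sep`."""
    flat = {}
    for key, value in nested.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            # recurse, extending the prefix with this key and the separator
            flat.update(flatten_keys(value, prefix=f"{full_key}{sep}", sep=sep))
        else:
            flat[full_key] = value
    return flat


print(flatten_keys({"1": 1, "2": 2, "3": {"3": 3}}))  # {'1': 1, '2': 2, '3.3': 3}
```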
std(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, correction: int = 1, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the standard deviation value of all elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the standard deviation of all leaves (if this can be computed). If integer or tuple of integers, std is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
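The correction argument is the number subtracted from the sample size in the denominator of the variance. A worked sketch on a plain list (hypothetical helper name), checked against the standard library: correction=1 matches statistics.stdev and correction=0 matches statistics.pstdev.

```python
import math
import statistics


def std_with_correction(values, correction=1):
    """Standard deviation with an explicit degrees-of-freedom correction."""
    n = len(values)
    if n - correction <= 0:
        raise ValueError("need more values than `correction`")
    mean = sum(values) / n
    # divide the sum of squared deviations by (n - correction)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (n - correction))


data = [1.0, 2.0, 3.0, 4.0]
print(std_with_correction(data), statistics.stdev(data))                 # Bessel-corrected
print(std_with_correction(data, correction=0), statistics.pstdev(data))  # population std
```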

sub(other: TensorDictBase | float, *, alpha: float | None = None, default: str | CompatibleType | None = None)

Subtracts other, scaled by alpha, from self.

\[\text{out}_i = \text{input}_i - \text{alpha} \times \text{other}_i\]

Supports broadcasting, type promotion, and integer, float, and complex inputs.


other (TensorDict, Tensor or Number) – the tensor or number to subtract from self.

Keyword Arguments:
  • alpha (Number) – the multiplier for other.

  • default (torch.Tensor or str, optional) – the default value to use for exclusive entries. If none is provided, the key lists of the two tensordicts must match exactly. If default="intersection" is passed, only the intersecting key sets will be considered and other keys will be ignored. In all other cases, default will be used for all missing entries on both sides of the operation.
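A worked instance of the formula above, with plain dicts of floats standing in for tensordicts. The helper is hypothetical, and alpha=None is assumed to behave like alpha=1, as in torch.sub.

```python
def sub_entries(lhs, other, alpha=None):
    """out_i = input_i - alpha * other_i, applied key by key (alpha=None acts as 1)."""
    scale = 1 if alpha is None else alpha
    if isinstance(other, dict):
        # dict stand-in for a tensordict: subtract matching entries
        return {k: lhs[k] - scale * other[k] for k in lhs}
    # a plain number is subtracted from every entry
    return {k: v - scale * other for k, v in lhs.items()}


print(sub_entries({"a": 5.0, "b": 1.0}, {"a": 1.0, "b": 2.0}, alpha=2))  # {'a': 3.0, 'b': -3.0}
print(sub_entries({"a": 5.0}, 1.0))                                      # {'a': 4.0}
```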

sub_(other: TensorDictBase | float, alpha: float | None = None)

In-place version of sub().


Note: in-place sub does not support the default keyword argument.

sum(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, dtype: torch.dtype | None = None, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the sum value of all elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
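The difference between the default behaviour and reduce=True can be sketched with lists standing in for leaves. The helper `sum_entries` is hypothetical; with dim=None each leaf is fully summed, and reduce=True folds those results into a single value.

```python
def sum_entries(entries, reduce=False):
    """Hypothetical model: full reduction per leaf, or one value across all leaves."""
    # dim=None: every element of each leaf is summed
    per_leaf = {key: sum(values) for key, values in entries.items()}
    if reduce:
        # reduce=True: one value for the whole tensordict
        return sum(per_leaf.values())
    return per_leaf


data = {"a": [1, 2], "b": [3, 4, 5]}
print(sum_entries(data))               # {'a': 3, 'b': 12}
print(sum_entries(data, reduce=True))  # 15
```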

tan() T

Computes the tan() value of each element of the TensorDict.

tan_() T

Computes the tan() value of each element of the TensorDict in-place.

tanh() T

Computes the tanh() value of each element of the TensorDict.

tanh_() T

Computes the tanh() value of each element of the TensorDict in-place.

to(*args, **kwargs: Any) PersistentTensorDict

Maps a TensorDictBase subclass either on another device, dtype or to another TensorDictBase subclass (if permitted).

Casting tensors to a new dtype is not allowed, as tensordicts are not bound to contain a single tensor dtype.

  • device (torch.device, optional) – the desired device of the tensordict.

  • dtype (torch.dtype, optional) – the desired floating point or complex dtype of the tensordict.

  • tensor (torch.Tensor, optional) – Tensor whose dtype and device are the desired dtype and device for all tensors in this TensorDict.

Keyword Arguments:
  • non_blocking (bool, optional) – whether the operations should be blocking.

  • memory_format (torch.memory_format, optional) – the desired memory format for 4D parameters and buffers in this tensordict.

  • batch_size (torch.Size, optional) – resulting batch-size of the output tensordict.

  • other (TensorDictBase, optional) –

    TensorDict instance whose dtype and device are the desired dtype and device for all tensors in this TensorDict.

    Note: since TensorDictBase instances do not have a dtype, the dtype is gathered from the example leaves. If there is more than one dtype, no dtype casting is undertaken.

  • non_blocking_pin (bool, optional) –

    if True, the tensors are pinned before being sent to device. This will be done asynchronously but can be controlled via the num_threads argument.


    Note: calling tensordict.pin_memory().to("cuda") will usually be much slower than tensordict.to("cuda", non_blocking_pin=True), as pin_memory is called asynchronously in the second case. Multithreaded pin_memory will usually be beneficial if the tensors are large and numerous: when there are too few tensors to be sent, the overhead of spawning threads and collecting data outweighs the benefits of multithreading, and if the tensors are small the overhead of iterating over a long list is also prohibitively large.

  • num_threads (int or None, optional) – if non_blocking_pin=True, the number of threads to be used for pin_memory. By default, max(1, torch.get_num_threads()) threads will be spawned. num_threads=0 will cancel any multithreading for the pin_memory() calls.


Returns: a new tensordict instance if the device differs from the tensordict device and/or if the dtype is passed; the same tensordict otherwise. batch_size-only modifications are done in-place.


Note: if the TensorDict is consolidated, the resulting TensorDict will be consolidated too. Each new tensor will be a view on the consolidated storage cast to the desired device.


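>>> import torch
>>> from tensordict import TensorDict
>>> data = TensorDict({"a": 1.0}, [], device=None)
>>> data_cuda = data.to("cuda:0")  # casts to cuda
>>> data_int = data.to(torch.int)  # casts to int
>>> data_cuda_int = data.to("cuda:0", torch.int)  # multiple casting
>>> data_cuda = data.to(torch.randn(3, device="cuda:0"))  # using an example tensor
>>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0"))  # using a tensordict example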
to_dict() dict[str, Any]

Returns a dictionary with key-value pairs matching those of the tensordict.

to_h5(filename, **kwargs)

Converts a tensordict to a PersistentTensorDict with the h5 backend.

  • filename (str or path) – path to the h5 file.

  • device (torch.device or compatible, optional) – the device where the tensors are expected to be once they are returned. Defaults to None (on cpu by default).

  • **kwargs – kwargs to be passed to h5py.File.create_dataset().


A PersistentTensorDict instance linked to the newly created file.


>>> import tempfile
>>> import timeit
>>> import torch
>>> from tensordict import TensorDict, MemoryMappedTensor
>>> td = TensorDict({
...     "a": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000)),
...     "b": {"c": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000, 3))},
... }, [1_000_000])
>>> file = tempfile.NamedTemporaryFile()
>>> td_h5 = td.to_h5(file.name, compression="gzip", compression_opts=9)
>>> print(td_h5)
        a: Tensor(shape=torch.Size([1000000]), device=cpu, dtype=torch.float32, is_shared=False),
        b: PersistentTensorDict(
                c: Tensor(shape=torch.Size([1000000, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
to_module(module: nn.Module, *, inplace: bool | None = None, return_swap: bool = True, swap_dest=None, use_state_dict: bool = False, non_blocking: bool = False, memo=None)

Writes the content of a TensorDictBase instance onto the attributes of a given nn.Module, recursively.


module (nn.Module) – a module to write the parameters into.

Keyword Arguments:
  • inplace (bool, optional) – if True, the parameters or tensors in the module are updated in-place. Defaults to False.

  • return_swap (bool, optional) – if True, the old parameter configuration will be returned. Defaults to True.

  • swap_dest (TensorDictBase, optional) – if return_swap is True, the tensordict where the swap should be written.

  • use_state_dict (bool, optional) – if True, state-dict API will be used to load the parameters (including the state-dict hooks). Defaults to False.

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.


>>> from torch import nn
>>> from tensordict import TensorDict
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> params.zero_()
>>> params.to_module(module)
>>> assert (module.layers[0].linear1.weight == 0).all()
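Conceptually, to_module() walks the nested keys and assigns each leaf to the matching attribute, optionally handing back the values it replaced. The minimal pure-Python sketch below illustrates that idea only; it uses types.SimpleNamespace objects as stand-ins for nn.Module, and the write_to_attributes helper is hypothetical (it ignores parameters, buffers, state-dict hooks and devices):

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def write_to_attributes(
    obj: Any, values: Dict[str, Any], return_swap: bool = True
) -> Optional[Dict[str, Any]]:
    """Recursively write a nested dict onto the attributes of ``obj``.

    Nested dicts are matched with sub-objects of the same name; any other
    value is treated as a leaf and assigned with ``setattr``. When
    ``return_swap`` is True, the previous values are returned in a nested
    dict with the same structure.
    """
    swap: Dict[str, Any] = {}
    for name, value in values.items():
        if isinstance(value, dict):
            # Recurse into the sub-object that mirrors this nested dict.
            swap[name] = write_to_attributes(getattr(obj, name), value, return_swap=True)
        else:
            # Record the old leaf before overwriting it.
            swap[name] = getattr(obj, name)
            setattr(obj, name, value)
    return swap if return_swap else None


# Hypothetical stand-in for a module with one parameter and one sub-module.
module = SimpleNamespace(weight=1.0, layer=SimpleNamespace(bias=2.0))
old = write_to_attributes(module, {"weight": 0.0, "layer": {"bias": 0.0}})
print(module.weight, module.layer.bias)  # both values are now zero
write_to_attributes(module, old)  # write the swapped-out values back
```

Returning the swap as a structure that mirrors the input is what makes restoring the previous configuration a single call.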
to_namedtuple(dest_cls: type | None = None)

Converts a tensordict to a namedtuple.


dest_cls (Type, optional) – an optional namedtuple class to use.


>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "a_tensor": torch.zeros((3)),
...     "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> data.to_namedtuple()
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
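The recursive conversion shown above can be sketched in pure Python with collections.namedtuple. This is an illustrative approximation, not the library's code: lists stand in for tensors, the to_namedtuple function here is a hypothetical stand-alone helper, dest_cls is not supported, and every key must be a valid Python identifier for namedtuple to accept it:

```python
from collections import namedtuple
from collections.abc import Mapping
from typing import Any


def to_namedtuple(data: Mapping, name: str = "GenericDict") -> Any:
    """Recursively convert a nested mapping into namedtuples.

    Each level gets its own namedtuple class whose fields are the mapping keys,
    in insertion order.
    """
    converted = {
        key: to_namedtuple(value, name) if isinstance(value, Mapping) else value
        for key, value in data.items()
    }
    cls = namedtuple(name, list(converted))
    return cls(**converted)


data = {
    "a_tensor": [0.0, 0.0, 0.0],
    "nested": {"a_tensor": [0.0, 0.0, 0.0], "a_string": "zero!"},
}
nt = to_namedtuple(data)
print(nt)
```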
to_padded_tensor(padding=0.0, mask_key: NestedKey | None = None)

Converts all nested tensors to a padded version and adapts the batch-size accordingly.

  • padding (float) – the padding value for the tensors in the tensordict. Defaults to 0.0.

  • mask_key (NestedKey, optional) – if provided, the key where a mask for valid values will be written. This will result in an error if the heterogeneous dimension isn’t part of the tensordict batch-size. Defaults to None.
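Padding a heterogeneous dimension means stretching every entry to the longest length and, optionally, recording which positions hold real data. A minimal pure-Python sketch of that idea, assuming a list of variable-length lists stands in for a nested tensor; pad_ragged is a hypothetical helper, not part of tensordict:

```python
from typing import List, Sequence, Tuple


def pad_ragged(
    rows: Sequence[Sequence[float]], padding: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad variable-length rows to a common length.

    Returns the padded rows and a boolean mask that is True for valid
    (non-padded) entries, similar in spirit to what ``mask_key`` stores.
    """
    max_len = max((len(row) for row in rows), default=0)
    padded = [list(row) + [padding] * (max_len - len(row)) for row in rows]
    mask = [[True] * len(row) + [False] * (max_len - len(row)) for row in rows]
    return padded, mask


padded, mask = pad_ragged([[1.0, 2.0, 3.0], [4.0]], padding=-1.0)
print(padded)
print(mask)
```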


to_pytree()

Converts a tensordict to a PyTree.

If the tensordict was not created from a pytree, this method just returns self without modification.

See from_pytree() for more information and examples.

to_tensordict() T

Returns a regular TensorDict instance from the TensorDictBase.


a new TensorDict object containing the same values.

transpose(dim0, dim1)

Returns a tensordict that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

In-place or out-of-place modifications of the transposed tensordict will impact the original tensordict too, as the memory is shared and the operations are mapped back onto the original tensordict.


>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> tensordict_transpose = tensordict.transpose(0, 1)
>>> print(tensordict_transpose.shape)
torch.Size([4, 3])
>>> tensordict_transpose.set("b", torch.randn(4, 3))
>>> print(tensordict.get("b").shape)
torch.Size([3, 4])
trunc() T

Computes the trunc() value of each element of the TensorDict.

trunc_() T

Computes the trunc() value of each element of the TensorDict in-place.


type(dst_type)

Casts all tensors to dst_type.


dst_type (type or string) – the desired type


uint16()

Casts all tensors to torch.uint16.


uint32()

Casts all tensors to torch.uint32.


uint64()

Casts all tensors to torch.uint64.


uint8()

Casts all tensors to torch.uint8.

unbind(dim: int) tuple[T, ...]

Returns a tuple of indexed tensordicts, unbound along the indicated dimension.


>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1, td2 = td.unbind(0)
>>> td0['x']
tensor([0, 1, 2, 3])
>>> td1['x']
tensor([4, 5, 6, 7])
unflatten(dim, unflattened_size)

Unflattens a tensordict dimension, expanding it to the desired shape.

  • dim (int) – specifies the dimension of the input tensor to be unflattened.

  • unflattened_size (shape) – the new shape of the unflattened dimension of the tensordict.


>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)},
...     batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_unflat = td_flat.unflatten(0, [3, 4])
>>> assert (td == td_unflat).all()
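Because batch dimensions always come first, unflattening a batch dimension applies the same split to the batch size and to every leaf. The pure-Python sketch below only computes shapes and does not touch any data; unflatten_shape is a hypothetical helper, and unlike torch it does not accept -1 inside the new sizes:

```python
from math import prod
from typing import Sequence, Tuple


def unflatten_shape(shape: Sequence[int], dim: int, sizes: Sequence[int]) -> Tuple[int, ...]:
    """Split ``shape[dim]`` into ``sizes`` and return the resulting shape."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a shape of length {ndim}")
    dim %= ndim  # resolve negative dims
    if prod(sizes) != shape[dim]:
        raise ValueError(f"cannot unflatten size {shape[dim]} into {tuple(sizes)}")
    return tuple(shape[:dim]) + tuple(sizes) + tuple(shape[dim + 1:])


# td_flat above has batch_size [12]; its leaves have shapes [12, 5] and [12].
print(unflatten_shape([12], 0, [3, 4]))     # batch size
print(unflatten_shape([12, 5], 0, [3, 4]))  # leaf "a"
```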
unflatten_keys(separator: str = '.', inplace: bool = False) T

Converts a flat tensordict into a nested one, recursively.

The TensorDict type will be lost and the result will be a simple TensorDict instance. The metadata of the nested tensordicts will be inferred from the root: all instances across the data tree will share the same batch-size, dimension names and device.

  • separator (str, optional) – the separator between the nested items.

  • inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.


>>> data = TensorDict({"a": 1, "b - c": 2, "e - f - g": 3}, batch_size=[])
>>> data.unflatten_keys(separator=" - ")
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        e: TensorDict(
                f: TensorDict(
                        g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},

This method and flatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.


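>>> import torch
>>> from tensordict import TensorDict
>>> model = torch.nn.Sequential(torch.nn.Linear(3, 4))
>>> ddp_model = torch.ao.quantization.QuantWrapper(model)  # a wrapper that mimics DDP behaviour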
>>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
        module: TensorDict(
                0: TensorDict(
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
        0: TensorDict(
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
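The round trip above boils down to splitting keys on the separator and rebuilding them. A minimal pure-Python sketch on plain dicts, assuming strings stand in for tensors and that no key is both a leaf and a prefix of another key; both helpers are hypothetical and ignore the batch-size, dimension-name and device metadata that tensordict propagates:

```python
from typing import Any, Dict


def unflatten_keys(flat: Dict[str, Any], separator: str = ".") -> Dict[str, Any]:
    """Turn {"a.b": x} into {"a": {"b": x}}."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(separator)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels on demand
        node[leaf] = value
    return nested


def flatten_keys(nested: Dict[str, Any], separator: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Inverse operation: turn {"a": {"b": x}} into {"a.b": x}."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full_key = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_keys(value, separator, full_key))
        else:
            flat[full_key] = value
    return flat


state_dict = {"module.0.weight": "W", "module.0.bias": "b"}
nested = unflatten_keys(state_dict)
model_state_dict = flatten_keys(nested["module"])  # keys without the "module." prefix
```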
unlock_() T

Unlocks a tensordict for non in-place operations.

Can be used as a decorator.

See lock_() for more details.

unsqueeze(*args, **kwargs)

Unsqueezes all tensors along a dimension lying between -td.batch_dims and td.batch_dims and returns them in a new tensordict.


dim (int) – dimension along which to unsqueeze


>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td = td.unsqueeze(-2)
>>> td.shape
torch.Size([3, 1, 4])
>>> td.get("x").shape
torch.Size([3, 1, 4, 2])
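The shapes above follow from resolving the (possibly negative) dim against the batch dimensions only, then keeping each leaf's trailing dimensions. A pure-Python sketch of that bookkeeping; unsqueeze_batch_size is a hypothetical helper, and it assumes torch.unsqueeze's convention for negative dims (-1 appends a new last batch dimension):

```python
from typing import Sequence, Tuple


def unsqueeze_batch_size(batch_size: Sequence[int], dim: int) -> Tuple[int, ...]:
    """Insert a singleton dimension into ``batch_size`` at position ``dim``."""
    ndim = len(batch_size)
    if not -(ndim + 1) <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range for a batch size of length {ndim}")
    if dim < 0:
        dim += ndim + 1  # negative dims count from the end of the batch dims
    return tuple(batch_size[:dim]) + (1,) + tuple(batch_size[dim:])


new_batch = unsqueeze_batch_size((3, 4), -2)
leaf_shape = new_batch + (2,)  # leaf "x" keeps its trailing feature dim
print(new_batch, leaf_shape)
```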

This operation can also be used as a context manager. Changes made within the context are written back to the original tensordict out-of-place, i.e. the content of the original tensors is not altered. This assumes that the tensordict is not locked (otherwise, it must be unlocked first).

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> with td.unsqueeze(-2) as tds:
...     tds.set("y", torch.zeros(3, 1, 4))
>>> assert td.get("y").shape == [3, 4]
update(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, inplace: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None, is_leaf: Callable[[Type], bool] | None = None) T

Updates the TensorDict with values from either a dictionary or another TensorDict.

  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

  • inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If the entry cannot be found, it will be added. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the keys in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update(data_src.select(*keys_to_update)).

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

  • is_leaf (Callable[[Type], bool], optional) – a callable that indicates whether an object type is to be considered a leaf and swapped or a tensor collection.




>>> td = TensorDict({}, batch_size=[3])
>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> other_td = TensorDict({"a": a, "b": b}, batch_size=[])
>>> td.update(other_td, inplace=True) # writes "a" and "b" even though they can't be found
>>> assert td['a'] is other_td['a']
>>> other_td = other_td.clone().zero_()
>>> td.update(other_td)
>>> assert (td['a'] == 0).all()  # "a" now holds the zeroed values of other_td
update_(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict in-place with values from either a dictionary or another TensorDict.

Unlike update(), this function will throw an error if the key is unknown to self.

  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the keys in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update_(data_src.select(*keys_to_update)).

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.




>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> td = TensorDict({"a": a, "b": b}, batch_size=[3])
>>> other_td = TensorDict({"a": a*0, "b": b*0}, batch_size=[])
>>> td.update_(other_td)
>>> assert td['a'] is not other_td['a']
>>> assert (td['a'] == other_td['a']).all()
>>> assert (td['a'] == 0).all()
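The difference between update() and update_() can be modelled on plain dicts: the former rebinds or adds entries, the latter requires every key to exist and copies data into the existing storage. A minimal sketch, assuming Python lists stand in for tensors (slice assignment plays the role of Tensor.copy_, though lists, unlike tensors, would silently accept a length change); both functions are hypothetical illustrations, not tensordict's implementation:

```python
from typing import Any, Dict, Iterable, List, Optional


def update(
    dest: Dict[str, Any],
    src: Dict[str, Any],
    keys_to_update: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """update()-like semantics: keys missing from ``dest`` are simply added."""
    keys = src.keys() if keys_to_update is None else keys_to_update
    for key in keys:
        dest[key] = src[key]  # rebinds the entry, no copy
    return dest


def update_(
    dest: Dict[str, List[float]],
    src: Dict[str, List[float]],
    keys_to_update: Optional[Iterable[str]] = None,
) -> Dict[str, List[float]]:
    """update_()-like semantics: every key must exist and data is copied into it."""
    keys = src.keys() if keys_to_update is None else keys_to_update
    for key in keys:
        if key not in dest:
            raise KeyError(f"key {key!r} not found in the destination")
        dest[key][:] = src[key]  # write into the existing list, like Tensor.copy_
    return dest


a = [1.0, 2.0]
dest = {"a": a}
update_(dest, {"a": [0.0, 0.0]})  # the list bound to `a` is modified in place
update(dest, {"b": [3.0]})        # the unknown key "b" is added
```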
update_at_(input_dict_or_td: dict[str, CompatibleType] | T, idx: IndexType, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict in-place at the specified index with values from either a dictionary or another TensorDict.

Unlike TensorDict.update, this function will throw an error if the key is unknown to the TensorDict.

  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • idx (int, torch.Tensor, iterable, slice) – index of the tensordict where the update should occur.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Default is False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the keys in keys_to_update will be updated.

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.




>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td.update_at_(
...     TensorDict({
...         'a': torch.ones(1, 4, 5),
...         'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]),
...    slice(1, 2))
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 10]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
>>> assert (td[1] == 1).all()
values(include_nested: bool = False, leaves_only: bool = False, is_leaf=None) Iterator[Tensor]

Returns a generator representing the values for the tensordict.

  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.
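The interplay of include_nested, leaves_only and is_leaf can be illustrated on a nested dict. A minimal pure-Python sketch, assuming that anything that is not a mapping counts as a leaf by default; iter_values is a hypothetical helper, and its ordering (a sub-collection is yielded before its contents) is a choice of this sketch rather than a documented guarantee of tensordict:

```python
from collections.abc import Mapping
from typing import Any, Callable, Iterator, Optional


def _default_is_leaf(cls: type) -> bool:
    # Anything that is not a mapping is treated as a leaf.
    return not issubclass(cls, Mapping)


def iter_values(
    data: Mapping,
    include_nested: bool = False,
    leaves_only: bool = False,
    is_leaf: Optional[Callable[[type], bool]] = None,
) -> Iterator[Any]:
    """Yield the values of a nested mapping, mimicking include_nested/leaves_only."""
    if is_leaf is None:
        is_leaf = _default_is_leaf
    for value in data.values():
        leaf = is_leaf(type(value))
        if leaf or not leaves_only:
            yield value  # sub-mappings are skipped when leaves_only=True
        if not leaf and include_nested:
            yield from iter_values(value, include_nested, leaves_only, is_leaf)


data = {"a": 1, "nested": {"b": 2, "deeper": {"c": 3}}}
print(list(iter_values(data, include_nested=True, leaves_only=True)))
```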

var(dim: int | Tuple[int] = _NoDefault.ZERO, keepdim: bool = _NoDefault.ZERO, *, correction: int = 1, reduce: bool | None = None) TensorDictBase | torch.Tensor

Returns the variance value of all elements in the input tensordict.

  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the variance of all leaves (if this can be computed). If integer or tuple of integers, var is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensor has dim retained or not.

Keyword Arguments:
  • correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.

  • reduce (bool, optional) – if True, the reduction will occur across all TensorDict values and a single reduced tensor will be returned. Defaults to False.
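For each leaf, the statistic is the usual corrected variance, sum((x - mean)^2) / (N - correction), so correction=1 gives Bessel's unbiased estimator and correction=0 the population variance. A pure-Python sketch of that formula on a flat list of numbers; the variance helper is hypothetical, and it raises when N - correction <= 0 instead of returning a non-finite value as torch may:

```python
from typing import Sequence


def variance(values: Sequence[float], correction: int = 1) -> float:
    """Variance with a degrees-of-freedom correction."""
    n = len(values)
    if n - correction <= 0:
        raise ValueError("need more samples than the correction")
    mean = sum(values) / n
    return sum((x - mean) ** 2 for x in values) / (n - correction)


print(variance([1.0, 2.0, 3.0, 4.0]))                # Bessel's correction
print(variance([1.0, 2.0, 3.0, 4.0], correction=0))  # population variance
```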

view(*shape: int, size: list | tuple | torch.Size | None = None, batch_size: torch.Size | None = None)

Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size.

Alternatively, a dtype can be provided as a first unnamed argument. In that case, all tensors will be viewed with the corresponding dtype. Note that this assumes that the new shapes are compatible with the provided dtype. See torch.Tensor.view() for more information on dtype views.

  • *shape (int) – new shape of the resulting tensordict.

  • dtype (torch.dtype) – alternatively, a dtype to use to represent the tensor content.

  • size – iterable

Keyword Arguments:

batch_size (torch.Size, optional) – if a dtype is provided, the batch-size can be reset using this keyword argument. If the view is called with a shape, this has no effect.


a new tensordict with the desired batch_size.


>>> td = TensorDict(source={'a': torch.zeros(3,4,5),
...    'b': torch.zeros(3,4,10,1)}, batch_size=torch.Size([3, 4]))
>>> td_view = td.view(12)
>>> print(td_view.get("a").shape)  # torch.Size([12, 5])
>>> print(td_view.get("b").shape)  # torch.Size([12, 10, 1])
>>> td_view = td.view(-1, 4, 3)
>>> print(td_view.get("a").shape)  # torch.Size([1, 4, 3, 5])
>>> print(td_view.get("b").shape)  # torch.Size([1, 4, 3, 10, 1])
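The leaf shapes printed above can be derived by resolving the -1 against the number of batch elements, then appending each leaf's trailing (non-batch) dimensions. A shape-only pure-Python sketch; infer_view_shape and leaf_view_shape are hypothetical helpers that do not cover the dtype form of view():

```python
from math import prod
from typing import List, Sequence, Tuple


def infer_view_shape(batch_size: Sequence[int], shape: Sequence[int]) -> Tuple[int, ...]:
    """Resolve an optional -1 in ``shape`` against the number of batch elements."""
    numel = prod(batch_size)
    resolved: List[int] = list(shape)
    if resolved.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    if -1 in resolved:
        known = prod(s for s in resolved if s != -1)
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {tuple(shape)} is invalid for {numel} elements")
        resolved[resolved.index(-1)] = numel // known
    if prod(resolved) != numel:
        raise ValueError(f"shape {tuple(shape)} is invalid for {numel} elements")
    return tuple(resolved)


def leaf_view_shape(
    leaf_shape: Sequence[int], batch_size: Sequence[int], shape: Sequence[int]
) -> Tuple[int, ...]:
    """A leaf is reshaped on its batch dims and keeps its trailing dims."""
    return infer_view_shape(batch_size, shape) + tuple(leaf_shape[len(batch_size):])


print(leaf_view_shape((3, 4, 5), (3, 4), (12,)))           # leaf "a" after td.view(12)
print(leaf_view_shape((3, 4, 10, 1), (3, 4), (-1, 4, 3)))  # leaf "b" after td.view(-1, 4, 3)
```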
where(condition, other, *, out=None, pad=None)

Return a TensorDict of elements selected from either self or other, depending on condition.

  • condition (BoolTensor) – When True (nonzero), yields self, otherwise yields other.

  • other (TensorDictBase or Scalar) – value (if other is a scalar) or values selected at indices where condition is False.

Keyword Arguments:
  • out (TensorDictBase, optional) – the output TensorDictBase instance.

  • pad (scalar, optional) – if provided, missing keys from the source or destination tensordict will be written as torch.where(mask, self, pad) or torch.where(mask, pad, other). Defaults to None, i.e. missing keys are not tolerated.
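The role of pad is easiest to see on dict-of-lists data: a key present on only one side is completed with pad before the elementwise selection. A minimal pure-Python sketch under that assumption; where is a hypothetical stand-alone function here (not the tensordict method), lists stand in for tensors, and zip silently truncates if the lengths disagree:

```python
from typing import Any, Dict, List, Optional, Sequence


def where(
    condition: Sequence[bool],
    source: Dict[str, List[Any]],
    other: Dict[str, List[Any]],
    pad: Optional[Any] = None,
) -> Dict[str, List[Any]]:
    """Pick ``source[k][i]`` where ``condition[i]`` is True, else ``other[k][i]``.

    A key found in only one input is completed with ``pad``; without ``pad``
    such a key raises a KeyError.
    """
    keys = list(source) + [k for k in other if k not in source]
    fill = [pad] * len(condition)
    out: Dict[str, List[Any]] = {}
    for key in keys:
        if (key not in source or key not in other) and pad is None:
            raise KeyError(f"key {key!r} is missing from one input and no pad was given")
        left = source.get(key, fill)
        right = other.get(key, fill)
        out[key] = [x if c else y for c, x, y in zip(condition, left, right)]
    return out


result = where([True, False, True], {"a": [1, 2, 3], "b": [7, 8, 9]}, {"a": [10, 20, 30]}, pad=0)
print(result)
```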

zero_() T

Zeros all tensors in the tensordict in-place.

zero_grad(set_to_none: bool = True) T

Zeros all the gradients of the TensorDict recursively.


set_to_none (bool, optional) – if True, tensor.grad will be None, otherwise 0. Defaults to True.


