
TensorDict

class tensordict.TensorDict(source: T | dict[str, CompatibleType], batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, names: Sequence[str] | None = None, non_blocking: bool = None, lock: bool = False, _run_checks: bool = True)

A batched dictionary of tensors.

TensorDict is a tensor container where all tensors are stored in a key-value pair fashion and where each element shares the same first N leading dimensions, where N is an arbitrary number with N >= 0.

Additionally, if the tensordict has a specified device, then each element must share that device.

TensorDict instances support many regular tensor operations with the notable exception of algebraic operations:

  • operations on shape: when a shape operation is called (indexing, reshape, view, expand, transpose, permute, unsqueeze, squeeze, masking, etc.), the operation is performed as if it were executed on a tensor with the same shape as the batch size and then expanded to the right, e.g.:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> td = TensorDict({'a': torch.zeros(3, 4, 5)}, batch_size=[3, 4])
    >>> # returns a TensorDict of batch size [3, 4, 1]:
    >>> td_unsqueeze = td.unsqueeze(-1)
    >>> # returns a TensorDict of batch size [12]
    >>> td_view = td.view(-1)
    >>> # returns a tensor of shape [12, 5]
    >>> a_view = td.view(-1).get("a")
    
  • casting operations: a TensorDict can be cast to a different device, or converted to a regular dictionary, using

    >>> td_cpu = td.to("cpu")
    >>> dictionary = td.to_dict()
    

    A call of the .to() method with a dtype will raise an error.

  • Cloning (clone()) and contiguity (contiguous());

  • Reading: td.get(key), td.get_at(key, index)

  • Content modification: td.set(key, value), td.set_(key, value), td.update(td_or_dict), td.update_(td_or_dict), td.fill_(key, value), td.rename_key_(old_name, new_name), etc.

  • Operations on multiple tensordicts: torch.cat(tensordict_list, dim), torch.stack(tensordict_list, dim), td1 == td2, td.apply(lambda x, y: x + y, other_td), etc.

Parameters:
  • source (TensorDict or Dict[NestedKey, Union[Tensor, TensorDictBase]]) – a data source. If empty, the tensordict can be populated subsequently.

  • batch_size (iterable of int, optional) – a batch size for the tensordict. The batch size can be modified subsequently as long as it is compatible with its content. If no batch-size is provided, an empty batch-size is assumed (it is not inferred automatically from the data). To set the batch-size automatically, refer to auto_batch_size_().

  • device (torch.device or compatible type, optional) – a device for the TensorDict. If provided, all tensors will be stored on that device. If not, tensors on different devices are allowed.

  • names (list of str, optional) – the names of the dimensions of the tensordict. If provided, its length must match that of batch_size. Defaults to None (no dimension names, i.e. None for every dimension).

  • non_blocking (bool, optional) – if True and a device is passed, the tensordict is delivered without synchronization. This is the fastest option but is only safe when casting from cpu to cuda (otherwise a synchronization call must be implemented by the user). If False is passed, every tensor movement will be done synchronously. If None (default), the device casting will be done asynchronously but a synchronization will be executed after creation if required. This option should generally be faster than False and potentially slower than True.

  • lock (bool, optional) – if True, the resulting tensordict will be locked.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> source = {'random': torch.randn(3, 4),
...     'zeros': torch.zeros(3, 4, 5)}
>>> batch_size = [3]
>>> td = TensorDict(source, batch_size=batch_size)
>>> print(td.shape)  # equivalent to td.batch_size
torch.Size([3])
>>> td_unsqueeze = td.unsqueeze(-1)
>>> print(td_unsqueeze.get("zeros").shape)
torch.Size([3, 1, 4, 5])
>>> print(td_unsqueeze[0].shape)
torch.Size([1])
>>> print(td_unsqueeze.view(-1).shape)
torch.Size([3])
>>> print((td.clone()==td).all())
True

all(dim: int = None) bool | TensorDictBase

Checks if all values are True/non-null in the tensordict.

Parameters:

dim (int, optional) – if None, returns a boolean indicating whether all tensors return tensor.all() == True. If an integer, all() is called along the specified dimension, provided that this dimension is compatible with the tensordict shape.

any(dim: int = None) bool | TensorDictBase

Checks if any value is True/non-null in the tensordict.

Parameters:

dim (int, optional) – if None, returns a boolean indicating whether any tensor returns tensor.any() == True. If an integer, any() is called along the specified dimension, provided that this dimension is compatible with the tensordict shape.

apply(fn: Callable, *others: T, batch_size: Sequence[int] | None = None, device: torch.device | None = <tensordict.base._NoDefault object>, names: Sequence[str] | None = None, inplace: bool = False, default: Any = <tensordict.base._NoDefault object>, filter_empty: bool | None = None, propagate_lock: bool = False, **constructor_kwargs) T | None

Applies a callable to all values stored in the tensordict and sets them in a new tensordict.

The callable signature must be Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].

Parameters:
  • fn (Callable) – function to be applied to the tensors in the tensordict.

  • *others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching that of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.

Keyword Arguments:
  • batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.

  • device (torch.device, optional) – the resulting device, if any.

  • names (list of str, optional) – the new dimension names, in case the batch_size is modified.

  • inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.

  • default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.

  • filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost as empty data structures won’t be created and destroyed. Non-tensor data is considered as a leaf and thereby will be kept in the tensordict even if left untouched by the function. Defaults to False for backward compatibility.

  • propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.

  • **constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.

Returns:

a new tensordict with transformed tensors.

Example

>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "b": {"c": torch.ones(3)}},
...     batch_size=[3])
>>> td_1 = td.apply(lambda x: x+1)
>>> assert (td_1["a"] == 0).all()
>>> assert (td_1["b", "c"] == 2).all()
>>> td_2 = td.apply(lambda x, y: x+y, td)
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()

Note

If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def filter(tensor):
...     if tensor == 1:
...         return tensor
>>> td.apply(filter)
TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Note

The apply method will return a TensorDict instance, regardless of the input type. To keep the same type, one can execute

>>> out = td.clone(False).update(td.apply(...))

apply_(fn: Callable, *others, **kwargs) T

Applies a callable to all values stored in the tensordict and re-writes them in-place.

Parameters:
  • fn (Callable) – function to be applied to the tensors in the tensordict.

  • *others (sequence of TensorDictBase, optional) – the other tensordicts to be used.

Keyword Args: See apply().

Returns:

self or a copy of self with the function applied

auto_batch_size_(batch_dims: int | None = None) T

Sets the maximum batch-size for the tensordict, up to an optional batch_dims.

Parameters:

batch_dims (int, optional) – if provided, the batch-size will be at most batch_dims long.

Returns:

self

Examples

>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[])
>>> td.auto_batch_size_()
>>> print(td.batch_size)
torch.Size([3, 4])
>>> td.auto_batch_size_(batch_dims=1)
>>> print(td.batch_size)
torch.Size([3])

property batch_dims: int

Length of the tensordict batch size.

Returns:

int describing the number of dimensions of the tensordict.

property batch_size: Size

Shape (or batch_size) of a TensorDict.

The shape of a tensordict corresponds to the common first N dimensions of the tensors it contains, where N is an arbitrary number. The TensorDict shape is controlled by the user upon initialization (ie, it is not inferred from the tensor shapes).

The batch_size can be edited dynamically if the new size is compatible with the TensorDict content. For instance, setting the batch size to an empty value is always allowed.

Returns:

a Size object describing the TensorDict batch size.

Examples

>>> data = TensorDict({
...     "key 0": torch.randn(3, 4),
...     "key 1": torch.randn(3, 5),
...     "nested": TensorDict({"key 0": torch.randn(3, 4)}, batch_size=[3, 4])},
...     batch_size=[3])
>>> data.batch_size = () # resets the batch-size to an empty value

bfloat16()

Casts all tensors to torch.bfloat16.

bool()

Casts all tensors to torch.bool.

chunk(chunks: int, dim: int = 0) tuple[TensorDictBase, ...]

Splits a tensordict into the specified number of chunks, if possible.

Each chunk is a view of the input tensordict.

Parameters:
  • chunks (int) – number of chunks to return

  • dim (int, optional) – dimension along which to split the tensordict. Default is 0.

Examples

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td0, td1 = td.chunk(dim=-1, chunks=2)
>>> td0['x']
tensor([[[ 0,  1],
         [ 2,  3]],
        [[ 8,  9],
         [10, 11]],
        [[16, 17],
         [18, 19]]])

clear() T

Erases the content of the tensordict.

clear_device_() T

Clears the device of the tensordict.

Returns: self

clone(recurse: bool = True, **kwargs) T

Clones a TensorDictBase subclass instance onto a new TensorDictBase subclass of the same type.

To create a TensorDict instance from any other TensorDictBase subtype, call the to_tensordict() method instead.

Parameters:

recurse (bool, optional) – if True, each tensor contained in the TensorDict will be copied too. Otherwise only the TensorDict tree structure will be copied. Defaults to True.

Note

Unlike many other ops (pointwise arithmetic, shape operations, …) clone does not inherit the original lock attribute. This design choice is made such that a clone can be created to be modified, which is the most frequent usage.

contiguous() T

Returns a new tensordict of the same type with contiguous values (or self if values are already contiguous).

copy()

Return a shallow copy of the tensordict (ie, copies the structure but not the data).

Equivalent to TensorDictBase.clone(recurse=False)

copy_(tensordict: T, non_blocking: bool = False) T

See TensorDictBase.update_.

The non-blocking argument will be ignored and is just present for compatibility with torch.Tensor.copy_().

copy_at_(tensordict: T, idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], non_blocking: bool = False) T

See TensorDictBase.update_at_.

cpu() T

Casts a tensordict to CPU.

create_nested(key)

Creates a nested tensordict of the same shape, device and dim names as the current tensordict.

If the value already exists, it will be overwritten by this operation. This operation is blocked in locked tensordicts.

Examples

>>> data = TensorDict({}, [3, 4, 5])
>>> data.create_nested("root")
>>> data.create_nested(("some", "nested", "value"))
>>> print(data)
TensorDict(
    fields={
        root: TensorDict(
            fields={
            },
            batch_size=torch.Size([3, 4, 5]),
            device=None,
            is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        value: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([3, 4, 5]),
                            device=None,
                            is_shared=False)},
                    batch_size=torch.Size([3, 4, 5]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([3, 4, 5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 4, 5]),
    device=None,
    is_shared=False)

cuda(device: Optional[int] = None) T

Casts a tensordict to a cuda device (if not already on it).

Parameters:

device (int, optional) – if provided, the cuda device on which the tensordict should be cast.

property data

Returns a tensordict containing the .data attributes of the leaf tensors.

del_(key: Union[str, Tuple[str, ...]]) T

Deletes a key of the tensordict.

Parameters:

key (NestedKey) – key to be deleted

Returns:

self

property depth: int

Returns the depth (the maximum number of nesting levels) of a tensordict.

The minimum depth is 0 (no nested tensordict).

detach() T

Detach the tensors in the tensordict.

Returns:

a new tensordict with no tensor requiring gradient.

detach_() T

Detach the tensors in the tensordict in-place.

Returns:

self.

property device: torch.device | None

Device of the tensordict.

Returns None if device hasn’t been provided in the constructor or set via tensordict.to(device).

dim() int

See batch_dims.

double()

Casts all tensors to torch.double.

property dtype

Returns the dtype of the values in the tensordict, if it is unique.

dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Saves the tensordict to disk.

This function is a proxy to memmap().

empty(recurse=False) T

Returns a new, empty tensordict with the same device and batch size.

Parameters:

recurse (bool, optional) – if True, the entire structure of the TensorDict will be reproduced without content. Otherwise, only the root will be duplicated. Defaults to False.

entry_class(key: Union[str, Tuple[str, ...]]) type

Returns the class of an entry, possibly avoiding a call to isinstance(td.get(key), type).

This method should be preferred to type(tensordict.get(key)) whenever get() can be expensive to execute.

exclude(*keys: Union[str, Tuple[str, ...]], inplace: bool = False) T

Excludes the keys of the tensordict and returns a new tensordict without these entries.

The values are not copied: in-place modifications of a tensor in either the original or the new tensordict will be reflected in both tensordicts.

Parameters:
  • *keys (str) – keys to exclude.

  • inplace (bool) – if True, the tensordict is pruned in place. Default is False.

Returns:

A new tensordict (or the same if inplace=True) without the excluded entries.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.exclude("a", ("b", "c"))
TensorDict(
    fields={
        b: TensorDict(
            fields={
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.exclude("a", "b")
TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

expand(*args, **kwargs) T

Expands each tensor of the tensordict according to the expand() function, ignoring the feature dimensions.

Supports iterables to specify the shape.

Examples

>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td_expand = td.expand(10, 3, 4)
>>> assert td_expand.shape == torch.Size([10, 3, 4])
>>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5])

fill_(key: NestedKey, value: float | bool) T

Fills a tensor pointed to by the key with a given scalar value.

Parameters:
  • key (str or nested key) – entry to be filled.

  • value (Number or bool) – value to use for the filling.

Returns:

self

filter_non_tensor_data() T

Filters out all non-tensor-data.

flatten(start_dim=0, end_dim=-1)

Flattens all the tensors of a tensordict.

Parameters:
  • start_dim (int) – the first dim to flatten

  • end_dim (int) – the last dim to flatten

Examples

>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)}, batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_flat.batch_size
torch.Size([12])
>>> td_flat["a"]
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]])
>>> td_flat["b"]
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

flatten_keys(separator: str = '.', inplace: bool = False, is_leaf: Callable[[Type], bool] | None = None) T

Converts a nested tensordict into a flat one, recursively.

The TensorDict type will be lost and the result will be a simple TensorDict instance.

Parameters:
  • separator (str, optional) – the separator between the nested items.

  • inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.

  • is_leaf (callable, optional) – a callable over a class type returning a bool indicating if this class has to be considered as a leaf.

Examples

>>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[])
>>> data.flatten_keys(separator=" - ")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        e - f - g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This method and unflatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.

Examples

>>> model = torch.nn.Sequential(torch.nn.Linear(3, 4))
>>> ddp_model = torch.ao.quantization.QuantWrapper(model)
>>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))

float()

Casts all tensors to torch.float.

classmethod from_dict(input_dict, batch_size=None, device=None, batch_dims=None)

Returns a TensorDict created from a dictionary or another TensorDict.

If batch_size is not specified, returns the maximum batch size possible.

This function works on nested dictionaries too, or can be used to determine the batch-size of a nested tensordict.

Parameters:
  • input_dict (dictionary, optional) – a dictionary to use as a data source (nested keys compatible).

  • batch_size (iterable of int, optional) – a batch size for the tensordict.

  • device (torch.device or compatible type, optional) – a device for the TensorDict.

  • batch_dims (int, optional) – the batch_dims (i.e. the number of leading dimensions to be considered for batch_size). Mutually exclusive with batch_size. Note that this is the maximum number of batch dims of the tensordict; a smaller number is tolerated.

Examples

>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(TensorDict.from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(TensorDict.from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch size of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(TensorDict.from_dict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

from_dict_instance(input_dict, batch_size=None, device=None, batch_dims=None)

Instance method version of from_dict().

Unlike from_dict(), this method will attempt to keep the tensordict types within the existing tree (for any existing leaf).

Examples

>>> from tensordict import TensorDict, tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyClass:
...     x: torch.Tensor
...     y: int
>>>
>>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)})
>>> print(td.from_dict_instance(td.to_dict()))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: MyClass(
            x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
            y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(td.from_dict(td.to_dict()))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

classmethod from_module(module: Module, as_module: bool = False, lock: bool = False, use_state_dict: bool = False)

Copies the params and buffers of a module into a tensordict.

Parameters:
  • module (nn.Module) – the module to get the parameters from.

  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to False.

  • use_state_dict (bool, optional) –

    if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note

    This is particularly useful when state-dict hooks have to be used.

Examples

>>> from torch import nn
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> print(params["layers", "0", "linear1"])
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2048]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2048, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

classmethod from_modules(*modules, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False)

Retrieves the parameters of several modules for ensemble learning/mixture-of-experts applications through vmap.

Parameters:

modules (sequence of nn.Module) – the modules to get the parameters from. If the modules differ in their structure, a lazy stack is needed (see the lazy_stack argument below).

Keyword Arguments:
  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.

  • use_state_dict (bool, optional) –

    if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note

    This is particularly useful when state-dict hooks have to be used.

  • lazy_stack (bool, optional) –

    whether parameters should be densely or lazily stacked. Defaults to False (dense stack).

    Note

    lazy_stack and as_module are exclusive features.

    Warning

    There is a crucial difference between lazy and non-lazy outputs in that non-lazy output will reinstantiate parameters with the desired batch-size, while lazy_stack will just represent the parameters as lazily stacked. This means that whilst the original parameters can safely be passed to an optimizer when lazy_stack=True, the new parameters need to be passed when it is set to False.

    Warning

    Whilst it can be tempting to use a lazy stack to keep the original parameter references, remember that lazy stacks perform a stack each time get() is called. This will require memory (N times the size of the parameters, more if a graph is built) and time to be computed. It also means that the optimizer(s) will contain more parameters, and operations like step() or zero_grad() will take longer to be executed. In general, lazy_stack should be reserved for very few use cases.

Examples

>>> from torch import nn
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> empty_module = nn.Linear(3, 4, device="meta")
>>> n_models = 2
>>> modules = [nn.Linear(3, 4) for _ in range(n_models)]
>>> params = TensorDict.from_modules(*modules)
>>> print(params)
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> # example of batch execution
>>> def exec_module(params, x):
...     with params.to_module(empty_module):
...         return empty_module(x)
>>> x = torch.randn(3)
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> # since lazy_stack = False, backprop leaves the original params untouched
>>> y.sum().backward()
>>> assert params["weight"].grad.norm() > 0
>>> assert modules[0].weight.grad is None

With lazy_stack=True, things are slightly different:

>>> params = TensorDict.from_modules(*modules, lazy_stack=True)
>>> print(params)
LazyStackedTensorDict(
    fields={
        bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> # example of batch execution
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> y.sum().backward()
>>> assert modules[0].weight.grad is not None

classmethod fromkeys(keys: List[Union[str, Tuple[str, ...]]], value: Any = 0)

Creates a tensordict from a list of keys and a single value.

Parameters:
  • keys (list of NestedKey) – An iterable specifying the keys of the new dictionary.

  • value (compatible type, optional) – The value for all keys. Defaults to 0.

gather(dim: int, index: Tensor, out: T | None = None) T

Gathers values along an axis specified by dim.

Parameters:
  • dim (int) – the dimension along which to collect the elements.

  • index (torch.Tensor) – a long tensor whose number of dimensions matches that of the tensordict, with only one dimension differing between the two (the gathering dimension). Its elements refer to the indices to be gathered along the required dimension.

  • out (TensorDictBase, optional) – a destination tensordict. It must have the same shape as the index.

Examples

>>> td = TensorDict(
...     {"a": torch.randn(3, 4, 5),
...      "b": TensorDict({"c": torch.zeros(3, 4, 5)}, [3, 4, 5])},
...     [3, 4])
>>> index = torch.randint(4, (3, 2))
>>> td_gather = td.gather(dim=1, index=index)
>>> print(td_gather)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 2, 5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

Gather keeps the dimension names.

Examples

>>> td.names = ["a", "b"]
>>> td_gather = td.gather(dim=1, index=index)
>>> td_gather.names
['a', 'b']
gather_and_stack(dst: int, group: 'dist.ProcessGroup' | None = None) T | None

Gathers tensordicts from various workers and stacks them onto self in the destination worker.

Parameters:
  • dst (int) – the rank of the destination worker where gather_and_stack() will be called.

  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

Example

>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>>
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     # Create a single tensordict to be sent to server
...     td = TensorDict(
...         {("a", "b"): torch.randn(2),
...          "c": torch.randn(2)}, [2]
...     )
...     td.gather_and_stack(0)
...
>>> def server():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     # Creates the destination tensordict on server.
...     # The first dim must be equal to world_size-1
...     td = TensorDict(
...         {("a", "b"): torch.zeros(2),
...          "c": torch.zeros(2)}, [2]
...     ).expand(1, 2).contiguous()
...     td.gather_and_stack(0)
...     assert (td["a", "b"] != 0).all()
...     print("yuppie")
...
>>> if __name__ == "__main__":
...     mp.set_start_method("spawn")
...
...     main_worker = mp.Process(target=server)
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...
...     main_worker.join()
...     secondary_worker.join()
get(key: Union[str, Tuple[str, ...]], default: Any = <tensordict.base._NoDefault object>) Tensor

Gets the value stored with the input key.

Parameters:
  • key (str, tuple of str) – key to be queried. If a tuple of str, it is equivalent to chained calls of get().

  • default – default value if the key is not found in the tensordict.
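The tuple-key rule can be pictured with plain nested dicts: a tuple key walks one level per element, which is the same as chaining single-key lookups, and the default is returned only when some step of the path is missing. The helper `nested_get` and the `_MISSING` sentinel below are hypothetical stand-ins, not tensordict's own logic.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Tuple, Union

_MISSING = object()  # hypothetical sentinel meaning "no default was given"


def nested_get(data: Mapping, key: Union[str, Tuple[str, ...]], default: Any = _MISSING) -> Any:
    """Hypothetical model of get(): a tuple key is a chain of single-key lookups."""
    path = (key,) if isinstance(key, str) else key
    node: Any = data
    for part in path:
        if not isinstance(node, Mapping) or part not in node:
            # Missing at any level: fall back to the default, or raise.
            if default is _MISSING:
                raise KeyError(key)
            return default
        node = node[part]
    return node


data = {"a": {"b": 1}, "c": 2}
# ("a", "b") behaves like data["a"]["b"]
assert nested_get(data, ("a", "b")) == data["a"]["b"] == 1
print(nested_get(data, ("a", "missing"), default=None))  # None
```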

Examples

>>> td = TensorDict({"x": 1}, batch_size=[])
>>> td.get("x")
tensor(1)
>>> print(td.get("y", default=None))
None
get_at(key: Union[str, Tuple[str, ...]], index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], default: Tensor = <tensordict.base._NoDefault object>) Tensor

Gets the value stored under key at the given index.

Parameters:
  • key (str, tuple of str) – key to be retrieved.

  • index (int, slice, torch.Tensor, iterable) – index of the tensor.

  • default (torch.Tensor) – default value to return if the key is not present in the tensordict.

Returns:

indexed tensor.

Examples

>>> td = TensorDict({"x": torch.arange(3)}, batch_size=[])
>>> td.get_at("x", index=1)
tensor(1)
get_item_shape(key: Union[str, Tuple[str, ...]])

Returns the shape of the entry, possibly avoiding a call to get().

get_non_tensor(key: Union[str, Tuple[str, ...]], default=<tensordict.base._NoDefault object>)

Gets a non-tensor value, if it exists, or default if the non-tensor value is not found.

This method is robust to tensor/TensorDict values, meaning that if the value gathered is a regular tensor it will be returned too (although this method comes with some overhead and should not be used outside its natural scope).

See set_non_tensor() for more information on how to set non-tensor values in a tensordict.

Parameters:
  • key (NestedKey) – the location of the NonTensorData object.

  • default (Any, optional) – the value to be returned if the key cannot be found.

Returns:

the content of the tensordict.tensorclass.NonTensorData, or the entry corresponding to the key if it isn’t a tensordict.tensorclass.NonTensorData (or default if the entry cannot be found).

Examples

>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
NonTensorData(
    data='a string!',
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
property grad

Returns a tensordict containing the .grad attributes of the leaf tensors.

half()

Casts all tensors to torch.half.

int()

Casts all tensors to torch.int.

irecv(src: int, *, group: 'dist.ProcessGroup' | None = None, return_premature: bool = False, init_tag: int = 0, pseudo_rand: bool = False) tuple[int, list[torch.Future]] | list[torch.Future] | None

Receives the content of a tensordict and updates the current tensordict with it asynchronously.

Check the example in the isend() method for context.

Parameters:

src (int) – the rank of the source worker.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • return_premature (bool) – if True, returns a list of futures to wait upon until the tensordict is updated. Defaults to False, i.e. waits until the update is completed within the call.

  • init_tag (int) – the init_tag used by the source worker.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Notice that generating these pseudo-random numbers is expensive (1e-5 sec/number), which could slow down the runtime of your algorithm. This value must match the one passed to isend(). Defaults to False.

Returns:

if return_premature=True, a list of futures to wait upon until the tensordict is updated.

is_contiguous() bool

Returns a boolean indicating if all the tensors are contiguous.

is_empty()

Checks whether the tensordict is empty, i.e. whether it contains no leaf.

is_memmap() bool

Checks if tensordict is memory-mapped.

If a TensorDict instance is memory-mapped, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all memory-mapped, this does not mean that is_memmap will return True (as a new tensor may or may not be memory-mapped). Only if one calls tensordict.memmap_() will the tensordict be considered as memory-mapped.

This is always True for tensordicts on a CUDA device.

is_shared() bool

Checks if tensordict is in shared memory.

If a TensorDict instance is in shared memory, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all in shared memory, this does not mean that is_shared will return True (as a new tensor may or may not be in shared memory). Only if one calls tensordict.share_memory_() or places the tensordict on a device where the content is shared by default (e.g., "cuda") will the tensordict be considered in shared memory.

This is always True for tensordicts on a CUDA device.

isend(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int

Sends the content of the tensordict asynchronously.

Parameters:

dst (int) – the rank of the destination worker where the content should be sent.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the initial tag used to mark the tensors. Note that it will be incremented by the number of tensors contained in the TensorDict.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Notice that generating these pseudo-random numbers is expensive (1e-5 sec/number), which could slow down the runtime of your algorithm. Defaults to False.
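The tag bookkeeping for the non-pseudo-random case can be modelled in a few lines: if each leaf tensor consumes one tag starting at init_tag, the sender ends at init_tag plus the number of tensors, and the receiver must start from the same init_tag (and walk the leaves in the same order) for messages to pair up. This is a hypothetical sketch of that arithmetic only; the helper `assign_tags` and the leaf ordering are assumptions, not tensordict's implementation.

```python
from typing import Dict, Hashable, List, Tuple


def assign_tags(leaf_keys: List[Hashable], init_tag: int = 0) -> Tuple[Dict[Hashable, int], int]:
    """Hypothetical model: one consecutive tag per leaf tensor."""
    tags = {key: init_tag + offset for offset, key in enumerate(leaf_keys)}
    # The next free tag is init_tag + number of tensors.
    return tags, init_tag + len(leaf_keys)


# Both sides must iterate the leaves in the same order with the same init_tag.
leaves = [("a", "b"), "c", "_"]
sender_tags, next_tag = assign_tags(leaves, init_tag=0)
receiver_tags, _ = assign_tags(leaves, init_tag=0)
assert sender_tags == receiver_tags
print(sender_tags, next_tag)  # {('a', 'b'): 0, 'c': 1, '_': 2} 3
```

Under this model, a second transfer over the same channel could start at next_tag so that no tag is reused.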

Example

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import multiprocessing as mp
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.isend(0)
...
>>>
>>> def server(queue, return_premature=True):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     out = td.irecv(1, return_premature=return_premature)
...     if return_premature:
...         for fut in out:
...             fut.wait()
...     assert (td != 0).all()
...     queue.put("yuppie")
...
>>>
>>> if __name__ == "__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(
...         target=server,
...         args=(queue, )
...         )
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
items(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) Iterator[tuple[str, CompatibleType]]

Returns a generator of key-value pairs for the tensordict.

Parameters:
  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.
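The interplay of include_nested and leaves_only can be reproduced on plain nested dicts. The sketch below is a hypothetical pure-Python model (the helper `iter_items` is not a tensordict function): sub-dicts play the role of nested tensordicts, everything else is a leaf, and nested keys are reported as tuples. The exact iteration order used by tensordict may differ.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Iterator, Tuple, Union

Key = Union[str, Tuple[str, ...]]


def iter_items(data: Mapping, include_nested: bool = False, leaves_only: bool = False,
               _prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[Key, Any]]:
    """Hypothetical model of items(include_nested, leaves_only) on nested dicts."""
    for name, value in data.items():
        path = _prefix + (name,)
        # Top-level keys are plain strings, nested ones are tuples.
        key: Key = path[0] if len(path) == 1 else path
        is_node = isinstance(value, Mapping)
        # With leaves_only=True, intermediate nodes are skipped.
        if not (leaves_only and is_node):
            yield key, value
        # With include_nested=True, descend into sub-dicts.
        if include_nested and is_node:
            yield from iter_items(value, include_nested, leaves_only, path)


data = {"0": 0, "1": {"2": 2}}
print([k for k, _ in iter_items(data)])                       # ['0', '1']
print([k for k, _ in iter_items(data, leaves_only=True)])     # ['0']
print([k for k, _ in iter_items(data, True, True)])           # ['0', ('1', '2')]
print([k for k, _ in iter_items(data, include_nested=True)])  # ['0', '1', ('1', '2')]
```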

keys(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) _TensorDictKeysView

Returns a generator of tensordict keys.

Parameters:
  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.

Examples

>>> from tensordict import TensorDict
>>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[])
>>> data.keys()
['0', '1']
>>> list(data.keys(leaves_only=True))
['0']
>>> list(data.keys(include_nested=True, leaves_only=True))
['0', ('1', '2')]
classmethod load(prefix: str | Path) T

Loads a tensordict from disk.

This class method is a proxy to load_memmap().

classmethod load_memmap(prefix: str | Path) T

Loads a memory-mapped tensordict from disk.

Parameters:

prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

This method also allows loading nested tensordicts.

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
load_state_dict(state_dict: OrderedDict[str, Any], strict=True, assign=False, from_flatten=False) T

Loads a state-dict, formatted as in state_dict(), into the tensordict.

Parameters:
  • state_dict (OrderedDict) – the state_dict to be copied.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this tensordict’s torch.nn.Module.state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the tensordict instead of copying them in place into the tensordict’s current tensors. When False, the properties of the tensors in the current tensordict are preserved, while when True, the properties of the tensors in the state dict are preserved. Default: False

  • from_flatten (bool, optional) – if True, the input state_dict is assumed to be flattened. Defaults to False.

Examples

>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> sd = data.state_dict()
>>> data_zeroed.load_state_dict(sd)
>>> print(data_zeroed["3", "3"])
tensor(3)
>>> # with flattening
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True)
>>> print(data_zeroed["3", "3"])
tensor(3)
lock_() T

Locks a tensordict for non in-place operations.

Functions such as set(), __setitem__(), update(), rename_key_() or other operations that add or remove entries will be blocked.

This method can also be used as a context manager, as in the example below.
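The locking rule (no storage added, moved or removed, while in-place writes remain allowed) can be modelled with a small dict subclass. `LockableDict` is a hypothetical illustration, not tensordict's implementation: only `__setitem__`, `__delitem__` and an in-place `set_` are modelled, and its `lock_()` context manager always unlocks on exit, like the example that follows.

```python
from contextlib import contextmanager
from typing import Any, Iterator


class LockableDict(dict):
    """Hypothetical model of lock_() semantics on a plain dict (not tensordict code)."""

    is_locked = False

    def __setitem__(self, key: str, value: Any) -> None:
        # Like set()/__setitem__ on a locked tensordict: refused while locked.
        if self.is_locked:
            raise RuntimeError("container is locked")
        super().__setitem__(key, value)

    def __delitem__(self, key: str) -> None:
        # Removing an entry changes the structure: refused while locked.
        if self.is_locked:
            raise RuntimeError("container is locked")
        super().__delitem__(key)

    def set_(self, key: str, value: Any) -> None:
        # In-place style write: the key must already exist, nothing is added.
        if key not in self:
            raise KeyError(key)
        super().__setitem__(key, value)

    @contextmanager
    def lock_(self) -> Iterator["LockableDict"]:
        # Lock for the duration of the with-block, then unlock.
        self.is_locked = True
        try:
            yield self
        finally:
            self.is_locked = False


td = LockableDict(a=1, b=2)
with td.lock_():
    assert td.is_locked
    td.set_("a", 0)  # allowed: nothing is added or removed
    try:
        td["d"] = 0
    except RuntimeError:
        print("td is locked!")
assert not td.is_locked and td["a"] == 0
```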

Example

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 1, "b": 2, "c": 3}, batch_size=[])
>>> with td.lock_():
...     assert td.is_locked
...     try:
...         td.set("d", 0) # error!
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         del td["d"]
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         td.rename_key_("a", "d")
...     except RuntimeError:
...         print("td is locked!")
...     td.set("a", 0, inplace=True)  # No storage is added, moved or removed
...     td.set_("a", 0) # No storage is added, moved or removed
...     td.update({"a": 0}, inplace=True)  # No storage is added, moved or removed
...     td.update_({"a": 0})  # No storage is added, moved or removed
>>> assert not td.is_locked
map(fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, *, out: TensorDictBase = None, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, index_with_generator: bool = False, pbar: bool = False, mp_start_method: str | None = None)

Maps a function to splits of the tensordict across one dimension.

This method will apply a function to a tensordict instance by chunking it into tensordicts of equal size and dispatching the operations over the desired number of workers.

The function signature should be Callable[[TensorDict], Union[TensorDict, Tensor]]. The output must support the torch.cat() operation. The function must be serializable.
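Conceptually, map follows a split / apply / concatenate pattern. The sequential sketch below splits a Python list into chunks, applies fn to each chunk and concatenates the results; it deliberately leaves out the worker pool. The torch.chunk-style splitting rule (chunks of ceil(n / num_chunks) items, the last one possibly shorter) is an assumption about how num_chunks is honoured, and `split_into_chunks` and `sequential_map` are hypothetical helpers.

```python
import math
from typing import Callable, List, TypeVar

T = TypeVar("T")


def split_into_chunks(items: List[T], num_chunks: int) -> List[List[T]]:
    """Split like torch.chunk: chunks of ceil(n / num_chunks) items, the last possibly shorter."""
    size = max(1, math.ceil(len(items) / num_chunks))
    return [items[start:start + size] for start in range(0, len(items), size)]


def sequential_map(fn: Callable[[List[T]], List[T]], items: List[T], num_chunks: int) -> List[T]:
    """Hypothetical single-process model of map(): split, apply fn, concatenate."""
    results: List[T] = []
    for chunk in split_into_chunks(items, num_chunks):
        # map() would dispatch each chunk to a worker process instead.
        results.extend(fn(chunk))  # stands in for torch.cat on the outputs
    return results


chunks = split_into_chunks(list(range(10)), num_chunks=3)
print([len(c) for c in chunks])  # [4, 4, 2]
print(sequential_map(lambda c: [x + 1 for x in c], list(range(10)), num_chunks=3))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

The concatenation step is why the output of fn must support torch.cat(): each chunk's result is stitched back along the mapped dimension.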

Parameters:
  • fn (callable) – function to apply to the tensordict. Signatures similar to Callable[[TensorDict], Union[TensorDict, Tensor]] are supported.

  • dim (int, optional) – the dim along which the tensordict will be chunked.

  • num_workers (int, optional) – the number of workers. Exclusive with pool. If none is provided, the number of workers will be set to the number of cpus available.

Keyword Arguments:
  • out (TensorDictBase, optional) – an optional container for the output. Its batch-size along the dim provided must match self.ndim. If it is shared or memmap (is_shared() or is_memmap() returns True) it will be populated within the remote processes, avoiding data inward transfers. Otherwise, the data from the self slice will be sent to the process, collected on the current process and written inplace into out.

  • chunksize (int, optional) – The size of each chunk of data. A chunksize of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, whereas chunksize>0 will split the tensordict and call torch.cat() on the resulting list of tensordicts. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with num_chunks.

  • num_chunks (int, optional) – the number of chunks to split the tensordict into. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with chunksize.

  • pool (mp.Pool, optional) – a multiprocess Pool instance to use to execute the job. If none is provided, a pool will be created within the map method.

  • generator (torch.Generator, optional) –

    a generator to use for seeding. A base seed will be generated from it, and each worker of the pool will be seeded with the provided seed incremented by a unique integer from 0 to num_workers - 1. If no generator is provided, a random integer will be used as seed. To work with unseeded workers, a pool should be created separately and passed to map() directly.

    Note

    Caution should be taken when providing a low-valued seed as this can cause autocorrelation between experiments. For example, if 8 workers are requested and the seed is 4, the worker seeds will range from 4 to 11. If the seed is 5, the worker seeds will range from 5 to 12. These two experiments will have an overlap of 7 seeds, which can have unexpected effects on the results.

    Note

    The goal of seeding the workers is to have an independent seed on each worker, and NOT to have reproducible results across calls of the map method. In other words, two experiments may and probably will return different results as it is impossible to know which worker will pick which job. However, we can make sure that each worker has a different seed and that the pseudo-random operations on each will be uncorrelated.

  • max_tasks_per_child (int, optional) – the maximum number of jobs picked by every child process. Defaults to None, i.e., no restriction on the number of jobs.

  • worker_threads (int, optional) – the number of threads for the workers. Defaults to 1.

  • index_with_generator (bool, optional) – if True, the splitting / chunking of the tensordict will be done during the query, sparing init time. Note that chunk() and split() are much more efficient than indexing (which is used within the generator) so a gain of processing time at init time may have a negative impact on the total runtime. Defaults to False.

  • pbar (bool, optional) – if True, a progress bar will be displayed. Requires tqdm to be available. Defaults to False.

  • mp_start_method (str, optional) – the start method for multiprocessing. If not provided, the default start method will be used. Accepted strings are "fork" and "spawn". Keep in mind that "cuda" tensors cannot be shared between processes with the "fork" start method. This is without effect if the pool is passed to the map method.
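The seed-overlap arithmetic from the generator note can be checked directly. Assuming, as described there, that each worker's seed is the base seed plus its worker index (0 to num_workers - 1), the sketch below counts the seeds shared by two runs. `worker_seeds` is a hypothetical helper, not part of tensordict.

```python
from typing import List


def worker_seeds(base_seed: int, num_workers: int) -> List[int]:
    """Seed of worker i is base_seed + i, for i in 0 .. num_workers - 1."""
    return [base_seed + i for i in range(num_workers)]


run_a = worker_seeds(4, 8)  # [4, 5, ..., 11]
run_b = worker_seeds(5, 8)  # [5, 6, ..., 12]
shared = sorted(set(run_a) & set(run_b))
print(len(shared), shared)  # 7 [5, 6, 7, 8, 9, 10, 11]

# Spacing base seeds at least num_workers apart avoids any overlap.
assert not set(worker_seeds(0, 8)) & set(worker_seeds(8, 8))
```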

Examples

>>> import torch
>>> from tensordict import TensorDict
>>>
>>> def process_data(data):
...     data.set("y", data.get("x") + 1)
...     return data
>>> if __name__ == "__main__":
...     data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_()
...     data = data.map(process_data, dim=1)
...     print(data["y"][:, :10])
...
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

Note

This method is particularly useful when working with large datasets stored on disk (e.g. memory-mapped tensordicts) where chunks will be zero-copied slices of the original data which can be passed to the processes at virtually zero cost. This allows very large datasets (e.g. over a terabyte) to be processed at little cost.

masked_fill(mask: Tensor, value: float | bool) T

Out-of-place version of masked_fill_().

Parameters:
  • mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.

  • value – value used to fill the tensors.

Returns:

a new tensordict in which the masked entries are filled with the requested value.

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td1 = td.masked_fill(mask, 1.0)
>>> td1.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
masked_fill_(mask: Tensor, value: float | int | bool) T

Fills the values corresponding to the mask with the desired value.

Parameters:
  • mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.

  • value – value used to fill the tensors.

Returns:

self

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td.masked_fill_(mask, 1.0)
>>> td.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
masked_select(mask: Tensor) T

Masks all tensors of the TensorDict and returns a new TensorDict instance with similar keys pointing to masked values.

Parameters:

mask (torch.Tensor) – boolean mask to be used for the tensors. Shape must match the TensorDict batch_size.

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...    batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td_mask = td.masked_select(mask)
>>> td_mask.get("a")
tensor([[0., 0., 0., 0.]])
mean(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the mean value of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensors have dim retained or not.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Writes all tensors onto a corresponding memory-mapped Tensor, in-place.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

self if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_like(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Creates a contentless memory-mapped tensordict with the same shapes as the original one.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.

Note

This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.

Examples

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
named_apply(fn: Callable, *others: T, nested_keys: bool = False, batch_size: Sequence[int] | None = None, device: torch.device | None = <tensordict.base._NoDefault object>, names: Sequence[str] | None = None, inplace: bool = False, default: Any = <tensordict.base._NoDefault object>, filter_empty: bool | None = None, propagate_lock: bool = False, **constructor_kwargs) T | None

Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new tensordict.

The callable signature must be Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].

Parameters:
  • fn (Callable) – function to be applied to the (name, tensor) pairs in the tensordict. For each leaf, only its leaf name will be used (not the full NestedKey).

  • *others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching the one of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.

  • nested_keys (bool, optional) – if True, the complete path to the leaf will be used. Defaults to False, i.e. only the last string is passed to the function.

  • batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.

  • device (torch.device, optional) – the resulting device, if any.

  • names (list of str, optional) – the new dimension names, in case the batch_size is modified.

  • inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.

  • default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.

  • filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost as empty data structures won’t be created and destroyed. Defaults to False for backward compatibility.

  • propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.

  • **constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.
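A simplified pure-Python model of the traversal may help: fn receives the leaf name (or, with nested_keys=True, the full path, which this model always passes as a tuple), leaves for which fn returns None are dropped, and, assuming filter_empty=True, sub-dicts that end up empty are dropped too. `named_apply_dict` is a hypothetical helper; it ignores *others, batch_size, device and the remaining options.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Callable, Dict, Tuple


def named_apply_dict(data: Mapping, fn: Callable[[Any, Any], Any],
                     nested_keys: bool = False, _prefix: Tuple[str, ...] = ()) -> Dict[str, Any]:
    """Hypothetical model of named_apply on nested dicts (filter_empty=True)."""
    out: Dict[str, Any] = {}
    for name, value in data.items():
        path = _prefix + (name,)
        if isinstance(value, Mapping):
            sub = named_apply_dict(value, fn, nested_keys, path)
            if sub:  # drop sub-dicts that became empty
                out[name] = sub
            continue
        # Leaf: pass either the last key or the full path.
        result = fn(path if nested_keys else name, value)
        if result is not None:  # None means "skip this entry"
            out[name] = result
    return out


data = {"1": 1, "2": 2, "b": {"2": 2, "1": 1}}
print(named_apply_dict(data, lambda name, x: x if name == "1" else None))
# {'1': 1, 'b': {'1': 1}}
print(named_apply_dict(data, lambda key, x: x * 10 if key == ("b", "2") else None, nested_keys=True))
# {'b': {'2': 20}}
```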

Returns:

a new tensordict with the transformed tensors.

Example

>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "nested": {"a": torch.ones(3), "b": torch.zeros(3)}},
...     batch_size=[3])
>>> def name_filter(name, tensor):
...     if name == "a":
...         return tensor
>>> td.named_apply(name_filter)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> def name_filter(name, *tensors):
...     if name == "a":
...         r = 0
...         for tensor in tensors:
...             r = r + tensor
...         return r
>>> out = td.named_apply(name_filter, td)
>>> print(out)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> print(out["a"])
tensor([-2., -2., -2.])

Note

If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def name_filter(name, tensor):
...     if name == "1":
...         return tensor
>>> td.named_apply(name_filter)
TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
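The filtering behaviour described in this note can be sketched in plain Python. This is a minimal illustration, not the library implementation: `named_apply_sketch` is a hypothetical helper, plain dicts stand in for tensordicts, and the function receives leaf names rather than full nested keys, as in the examples above. The `filter_empty` flag mimics the parameter of the same name.

```python
from typing import Any, Callable, Optional


def named_apply_sketch(
    fn: Callable[..., Optional[Any]],
    data: dict,
    *others: dict,
    filter_empty: bool = False,
) -> dict:
    """Hypothetical sketch: call fn(name, value, *other_values) on every leaf."""
    out = {}
    for name, value in data.items():
        # Missing entries in the other mappings raise KeyError (no default given).
        other_values = [other[name] for other in others]
        if isinstance(value, dict):
            # Recurse into nested mappings, mirroring nested tensordicts.
            sub = named_apply_sketch(fn, value, *other_values, filter_empty=filter_empty)
            if sub or not filter_empty:
                out[name] = sub
            continue
        result = fn(name, value, *other_values)
        # Entries for which fn returns None are dropped from the output.
        if result is not None:
            out[name] = result
    return out


def keep_one(name, value):
    return value if name == "1" else None


data = {"1": 1, "2": 2, "b": {"2": 2, "1": 1}, "c": {"2": 2}}
print(named_apply_sketch(keep_one, data))                     # {'1': 1, 'b': {'1': 1}, 'c': {}}
print(named_apply_sketch(keep_one, data, filter_empty=True))  # {'1': 1, 'b': {'1': 1}}
```

With the default `filter_empty=False`, the emptied sub-mapping `"c"` is kept, which matches the parameter description stating that empty tensordicts are only filtered out when the flag is set.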
property names

The dimension names of the tensordict.

The names can be set at construction time using the names argument.

See also refine_names() for details on how to set the names after construction.

nanmean(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the mean of all non-NaN elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the reduced dimension(s) are retained in the output.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
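Ignoring NaN entries amounts to dropping them before averaging. The sketch below shows this on a flat list of floats; `nanmean_sketch` is a hypothetical helper that stands in for the per-leaf reduction and does not model `dim` or `keepdim`. Returning NaN when every entry is NaN follows the torch.nanmean convention (a 0 / 0 division).

```python
import math


def nanmean_sketch(values):
    """Hypothetical helper: mean of the non-NaN entries of a flat sequence of floats."""
    kept = [v for v in values if not math.isnan(v)]
    if not kept:
        # Every entry was NaN: nothing to average, so the result is NaN.
        return float("nan")
    return sum(kept) / len(kept)


print(nanmean_sketch([1.0, float("nan"), 3.0]))  # 2.0
```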

nansum(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the sum of all non-NaN elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the reduced dimension(s) are retained in the output.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

property ndim: int

See batch_dims().

ndimension() int

See batch_dims().

numel() int

Total number of elements in the batch.

Lower-bounded to 1: a stack of two tensordicts with empty shape has two elements, so a tensordict is considered to be at least 1-element big.

permute(*args, **kwargs)

Returns a view of a tensordict with the batch dimensions permuted according to dims.

Parameters:
  • *dims_list (int) – the new ordering of the batch dims of the tensordict. Alternatively, a single iterable of integers can be provided.

  • dims (list of int) – alternative way of calling permute(…).

Returns:

a new tensordict with the batch dimensions in the desired order.

Examples

>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> print(tensordict.permute([1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(1, 0))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(dims=[1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
pin_memory() T

Calls pin_memory() on the stored tensors.

pop(key: ~typing.Union[str, ~typing.Tuple[str, ...]], default: ~typing.Any = <tensordict.base._NoDefault object>) Tensor

Removes and returns a value from a tensordict.

If the value is not present and no default value is provided, a KeyError is thrown.

Parameters:
  • key (str or nested key) – the entry to look for.

  • default (Any, optional) – the value to return if the key cannot be found.

Examples

>>> td = TensorDict({"1": 1}, [])
>>> one = td.pop("1")
>>> assert one == 1
>>> none = td.pop("1", default=None)
>>> assert none is None
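The `<tensordict.base._NoDefault object>` shown as a default value in several signatures on this page is a sentinel: it lets `pop` tell "no default was passed" apart from an explicit `default=None`, which is what the example above relies on. The pattern can be sketched in plain Python; `_NoDefaultSketch` and `pop_sketch` are hypothetical names and plain dicts stand in for tensordicts.

```python
class _NoDefaultSketch:
    """Hypothetical sentinel type: marks 'no default value was passed'."""


NO_DEFAULT = _NoDefaultSketch()


def pop_sketch(data: dict, key, default=NO_DEFAULT):
    """Remove and return data[key]; nested keys are tuples of strings."""
    *parents, leaf = key if isinstance(key, tuple) else (key,)
    node = data
    try:
        for parent in parents:
            node = node[parent]
        return node.pop(leaf)
    except KeyError:
        # Only re-raise when the caller gave no default at all;
        # an explicit default=None is returned like any other value.
        if default is NO_DEFAULT:
            raise
        return default


td = {"1": 1, "nested": {"a": 2}}
print(pop_sketch(td, "1"))                # 1
print(pop_sketch(td, "1", default=None))  # None
```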
popitem() Tuple[Union[str, Tuple[str, ...]], Tensor]

Removes the item that was last inserted into the TensorDict.

popitem will only return non-nested values.

prod(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the product of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the prod value of all leaves (if this can be computed). If integer or tuple of integers, prod is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the reduced dimension(s) are retained in the output.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

recv(src: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int

Receives the content of a tensordict and updates this tensordict's content with it.

Check the example in the send method for context.

Parameters:

src (int) – the rank of the source worker.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the init_tag used by the source worker.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, making it possible to send multiple data from different nodes without overlap. Note that generating these pseudo-random numbers is expensive (1e-5 sec/number), which could slow down the runtime of your algorithm. This value must match the one passed to send(). Defaults to False.

reduce(dst, op=None, async_op=False, return_premature=False, group=None)

Reduces the tensordict across all machines.

Only the process with rank dst is going to receive the final result.

refine_names(*names)

Refines the dimension names of self according to names.

Refining is a special case of renaming that “lifts” unnamed dimensions. A None dim can be refined to have any name; a named dim can only be refined to have the same name.

Because named tensors can coexist with unnamed tensors, refining names gives a nice way to write named-tensor-aware code that works with both named and unnamed tensors.

names may contain up to one Ellipsis (…). The Ellipsis is expanded greedily; it is expanded in-place to fill names to the same length as self.dim() using names from the corresponding indices of self.names.

Returns: the same tensordict with dimensions named according to the input.

Examples

>>> td = TensorDict({}, batch_size=[3, 4, 5, 6])
>>> tdr = td.refine_names(None, None, None, "d")
>>> assert tdr.names == [None, None, None, "d"]
>>> tdr = td.refine_names("a", None, None, "d")
>>> assert tdr.names == ["a", None, None, "d"]
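The refinement rules above (a None dim accepts any name, a named dim only its own name, and at most one Ellipsis expanded greedily from the current names) can be sketched on plain lists of names. `refine_names_sketch` is a hypothetical helper written for illustration; it does not reproduce the library's error messages or modify any tensordict.

```python
from typing import List, Optional, Sequence


def refine_names_sketch(current: Sequence[Optional[str]], *names) -> List[Optional[str]]:
    """Hypothetical sketch of the refinement rules for dimension names."""
    names = list(names)
    if names.count(...) > 1:
        raise ValueError("names may contain at most one Ellipsis")
    if ... in names:
        pos = names.index(...)
        # The Ellipsis absorbs as many dims as needed to reach len(current),
        # reusing the current names at the corresponding positions.
        n_fill = len(current) - (len(names) - 1)
        if n_fill < 0:
            raise ValueError("too many names")
        names[pos:pos + 1] = list(current[pos:pos + n_fill])
    if len(names) != len(current):
        raise ValueError(f"expected {len(current)} names, got {len(names)}")
    for old, new in zip(current, names):
        # A None dim accepts any name; a named dim only accepts its own name.
        if old is not None and old != new:
            raise ValueError(f"cannot refine dim named {old!r} to {new!r}")
    return names


print(refine_names_sketch(["a", None, None, None], "a", ..., "d"))  # ['a', None, None, 'd']
```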
rename(*names, **rename_map)

Returns a clone of the tensordict with dimensions renamed.

Examples

>>> td = TensorDict({}, batch_size=[1, 2, 3, 4])
>>> td.names = list("abcd")
>>> td_rename = td.rename(c="g")
>>> assert td_rename.names == list("abgd")
rename_(*names, **rename_map)

Same as rename(), but executes the renaming in-place.

Examples

>>> td = TensorDict({}, batch_size=[1, 2, 3, 4])
>>> td.names = list("abcd")
>>> td.rename_(c="g")
>>> assert td.names == list("abgd")
rename_key_(old_key: Union[str, Tuple[str, ...]], new_key: Union[str, Tuple[str, ...]], safe: bool = False) T

Renames a key with a new string and returns the same tensordict with the updated key name.

Parameters:
  • old_key (str or nested key) – key to be renamed.

  • new_key (str or nested key) – new name of the entry.

  • safe (bool, optional) – if True, an error is thrown when the new key is already present in the TensorDict. Defaults to False.

Returns:

self
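A plain-Python sketch of the same contract on nested dicts, for illustration only: `rename_key_sketch` and `_split_key` are hypothetical, tuple keys play the role of nested keys, and creating missing intermediate levels for `new_key` is an assumption rather than documented behaviour. Like the method, the sketch mutates its input and returns the same object.

```python
def _split_key(key):
    # A nested key is a tuple of strings; a plain string is a top-level key.
    key = key if isinstance(key, tuple) else (key,)
    return key[:-1], key[-1]


def rename_key_sketch(data: dict, old_key, new_key, safe: bool = False) -> dict:
    """Hypothetical sketch of rename_key_ on nested dicts; returns the same dict."""
    old_parents, old_leaf = _split_key(old_key)
    new_parents, new_leaf = _split_key(new_key)
    src = data
    for part in old_parents:
        src = src[part]
    value = src[old_leaf]  # raises KeyError if old_key is missing
    dst = data
    for part in new_parents:
        # Assumption: missing intermediate levels of new_key are created on demand.
        dst = dst.setdefault(part, {})
    if safe and new_leaf in dst:
        raise KeyError(f"{new_key!r} is already present")
    del src[old_leaf]
    dst[new_leaf] = value
    return data


td = {"a": 1, "b": {"c": 2}}
assert rename_key_sketch(td, "a", ("b", "a")) is td
print(td)  # {'b': {'c': 2, 'a': 1}}
```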

reshape(*args, **kwargs) T

Returns a contiguous, reshaped tensordict of the desired shape.

Parameters:

*shape (int) – new shape of the resulting tensordict.

Returns:

A TensorDict with reshaped entries.

Examples

>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td = td.reshape(12)
>>> print(td['x'])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
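Because shape operations act on the batch dimensions only (see the class description), the resulting shape of every leaf can be worked out from the batch sizes alone. The hypothetical helper `reshaped_leaf_shape` below does that arithmetic, inferring a single `-1` the way torch.reshape does; it is a sketch of the shape rule, not of how tensordict moves data.

```python
import math
from typing import Sequence, Tuple


def reshaped_leaf_shape(
    batch_size: Sequence[int],
    new_batch_size: Sequence[int],
    leaf_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Hypothetical helper: shape of one leaf after reshaping the batch dims."""
    n_batch = len(batch_size)
    if tuple(leaf_shape[:n_batch]) != tuple(batch_size):
        raise ValueError("the leaf must start with the batch dimensions")
    numel = math.prod(batch_size)
    new = list(new_batch_size)
    if new.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    if -1 in new:
        known = math.prod(d for d in new if d != -1)
        if known == 0 or numel % known:
            raise ValueError(f"cannot infer -1 in {tuple(new)}")
        new[new.index(-1)] = numel // known
    if math.prod(new) != numel:
        raise ValueError(f"cannot reshape batch {tuple(batch_size)} into {tuple(new)}")
    # Only the batch dims change; the trailing (feature) dims are carried over.
    return tuple(new) + tuple(leaf_shape[n_batch:])


print(reshaped_leaf_shape([3, 4], [12], [3, 4]))     # (12,)
print(reshaped_leaf_shape([3, 4], [-1], [3, 4, 5]))  # (12, 5)
```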
save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Saves the tensordict to disk.

This function is a proxy to memmap().

select(*keys: Union[str, Tuple[str, ...]], inplace: bool = False, strict: bool = True) T

Selects the keys of the tensordict and returns a new tensordict with only the selected keys.

The values are not copied: an in-place modification of a tensor in either the original or the new tensordict will be reflected in both tensordicts.

Parameters:
  • *keys (str) – keys to select

  • inplace (bool) – if True, the tensordict is pruned in place. Default is False.

  • strict (bool, optional) – whether selecting a key that is not present raises an error. Default: True.

Returns:

A new tensordict (or the same if inplace=True) with the selected keys only.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.select("a", ("b", "c"))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("a", "b")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("this key does not exist", strict=False)
TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
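The sharing semantics can be illustrated with plain dicts. In this sketch (`select_sketch` and `_copy_structure` are hypothetical helpers), the returned structure is made of new mappings while the leaves are the very same objects, so an in-place change to a leaf is visible from both sides; lists stand in for tensors.

```python
def _copy_structure(node):
    # New containers, same leaf objects: nothing below the mapping level is copied.
    if isinstance(node, dict):
        return {k: _copy_structure(v) for k, v in node.items()}
    return node


def select_sketch(data: dict, *keys, strict: bool = True) -> dict:
    """Hypothetical sketch of select() on nested dicts with tuple keys."""
    out: dict = {}
    for key in keys:
        path = key if isinstance(key, tuple) else (key,)
        try:
            value = data
            for part in path:
                value = value[part]
        except KeyError:
            if strict:
                raise
            continue  # strict=False: silently skip missing keys
        dst = out
        for part in path[:-1]:
            dst = dst.setdefault(part, {})
        dst[path[-1]] = _copy_structure(value)
    return out


# Lists stand in for tensors so that in-place changes are observable.
td = {"a": [0], "b": {"c": [1], "d": [2]}}
sel = select_sketch(td, "a", ("b", "c"))
print(sel)        # {'a': [0], 'b': {'c': [1]}}
sel["a"][0] = 42  # in-place change on a leaf shared by both structures
print(td["a"])    # [42]
```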
send(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) None

Sends the content of a tensordict to a distant worker.

Parameters:

dst (int) – the rank of the destination worker where the content should be sent.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, making it possible to send multiple data from different nodes without overlap. Note that generating these pseudo-random numbers is expensive (1e-5 sec/number), which could slow down the runtime of your algorithm. Defaults to False.

Example

>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>>
>>>
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.send(0)
...
>>>
>>> def server(queue):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     td.recv(1)
...     assert (td != 0).all()
...     queue.put("yuppie")
...
>>>
>>> if __name__=="__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(target=server, args=(queue,))
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
set(key: Union[str, Tuple[str, ...]], item: Tensor, inplace: bool = False, *, non_blocking: bool = False, **kwargs: Any) T

Sets a new key-value pair.

Parameters:
  • key (str, tuple of str) – name of the key to be set.

  • item (torch.Tensor or equivalent, TensorDictBase instance) – value to be stored in the tensordict.

  • inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If inplace is True and the entry cannot be found, it will be added. For a more restrictive in-place operation, use set_() instead. Defaults to False.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> td.set("x", torch.randn(3, 4))
>>> y = torch.randn(3, 4, 5)
>>> td.set("y", y, inplace=True) # works, even if 'y' is not present yet
>>> td.set("y", torch.zeros_like(y), inplace=True)
>>> assert (y==0).all() # y values are overwritten
>>> td.set("y", torch.ones(5), inplace=True) # raises an exception as shapes mismatch
set_(key: Union[str, Tuple[str, ...]], item: Tensor, *, non_blocking: bool = False) T

Sets a value to an existing key while keeping the original storage.

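Parameters:
  • key (str, tuple of str) – name of the key to be modified.

  • item (torch.Tensor or compatible type, TensorDictBase) – value to be written into the existing entry.

Keyword Arguments: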

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_("x", torch.zeros_like(x))
>>> assert (x == 0).all()
set_at_(key: Union[str, Tuple[str, ...]], value: Tensor, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], *, non_blocking: bool = False) T

Sets the values in-place at the index indicated by index.

Parameters:
  • key (str, tuple of str) – key to be modified.

  • value (torch.Tensor) – value to be written at index.

  • index (int, tensor or tuple) – index where to write the values.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_at_("x", value=torch.ones(1, 4), index=slice(1))
>>> assert (x[0] == 1).all()
set_non_tensor(key: Union[str, Tuple[str, ...]], value: Any)

Registers a non-tensor value in the tensordict using tensordict.tensorclass.NonTensorData.

The value can be retrieved using TensorDictBase.get_non_tensor() or directly using get, which will return the tensordict.tensorclass.NonTensorData object.

Returns:

self

Examples

>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
NonTensorData(
    data='a string!',
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
setdefault(key: Union[str, Tuple[str, ...]], default: Tensor, inplace: bool = False) Tensor

Inserts the key entry with a value of default if key is not in the tensordict.

Returns the value for key if key is in the tensordict, else default.

Parameters:
  • key (str or nested key) – the name of the value.

  • default (torch.Tensor or compatible type, TensorDictBase) – value to be stored in the tensordict if the key is not already present.

Returns:

The value of key in the tensordict. Will be default if the key was not previously set.

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> val = td.setdefault("a", torch.zeros(3, 4))
>>> assert (val == 0).all()
>>> val = td.setdefault("a", torch.ones(3, 4))
>>> assert (val == 0).all() # output is still 0
property shape: Size

See batch_size.

share_memory_() T

Places all the tensors in shared memory.

The TensorDict is then locked, meaning that any write operation that isn’t in-place will raise an exception (e.g., renaming, setting or removing an entry). Conversely, once the tensordict is unlocked, the share_memory attribute is set to False, because cross-process identity is no longer guaranteed.

Returns:

self

size(dim: int | None = None) torch.Size | int

Returns the size of the dimension indicated by dim.

If dim is not specified, returns the batch_size attribute of the TensorDict.

property sorted_keys: list[NestedKey]

Returns the keys sorted in alphabetical order.

Does not support extra arguments.

If the TensorDict is locked, the keys are cached until the tensordict is unlocked for faster execution.

split(split_size: int | list[int], dim: int = 0) list[TensorDictBase]

Splits each tensor in the TensorDict with the specified size in the given dimension, like torch.split.

Returns a list of TensorDict instances, each holding a view of the corresponding chunk of the items.

Parameters:
  • split_size (int or list of int) – size of a single chunk or list of sizes for each chunk.

  • dim (int) – dimension along which to split the tensor.

Returns:

A list of TensorDict with specified size in given dimension.

Examples

>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1 = td.split([1, 2], dim=0)
>>> print(td0['x'])
tensor([[0, 1, 2, 3]])
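For a single dimension, the chunk boundaries follow torch.split: an integer `split_size` yields equal chunks with a possibly smaller last one, while a list gives the size of each chunk and must sum to the size of the dimension. `split_bounds` is a hypothetical helper that only computes these boundaries; it does not touch any tensordict.

```python
from typing import List, Sequence, Tuple, Union


def split_bounds(length: int, split_size: Union[int, Sequence[int]]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, stop) bounds of each chunk along one dim."""
    if isinstance(split_size, int):
        if split_size <= 0:
            raise ValueError("split_size must be positive")
        # Equal chunks; the last one holds the remainder, if any.
        sizes = [split_size] * (length // split_size)
        if length % split_size:
            sizes.append(length % split_size)
    else:
        sizes = list(split_size)
        if sum(sizes) != length:
            raise ValueError(f"split sizes {sizes} do not sum to {length}")
    bounds, start = [], 0
    for size in sizes:
        bounds.append((start, start + size))
        start += size
    return bounds


# Rows of torch.arange(12).reshape(3, 4), written as plain lists.
rows = [list(range(4 * i, 4 * i + 4)) for i in range(3)]
chunks = [rows[a:b] for a, b in split_bounds(len(rows), [1, 2])]
print(chunks[0])  # [[0, 1, 2, 3]]
```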
squeeze(*args, **kwargs)

Squeezes all tensors along a dimension between -self.batch_dims+1 and self.batch_dims-1 and returns them in a new tensordict.

Parameters:

dim (Optional[int]) – dimension along which to squeeze. If dim is None, all singleton dimensions will be squeezed. Defaults to None.

Examples

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> td = td.squeeze()
>>> td.shape
torch.Size([3, 4])
>>> td.get("x").shape
torch.Size([3, 4, 2])

This operation can be used as a context manager too. Changes to the original tensordict will occur out-of-place, i.e. the content of the original tensors will not be altered. This also assumes that the tensordict is not locked (otherwise, it must be unlocked first). This functionality is not compatible with implicit squeezing.

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> with td.squeeze(1) as tds:
...     tds.set("y", torch.zeros(3, 4))
>>> assert td.get("y").shape == torch.Size([3, 1, 4])
state_dict(destination=None, prefix='', keep_vars=False, flatten=False) OrderedDict[str, Any]

Produces a state_dict from the tensordict.

The structure of the state-dict will still be nested, unless flatten is set to True.

A tensordict state-dict contains all the tensors and meta-data needed to rebuild the tensordict (names are currently not supported).

Parameters:
  • destination (dict, optional) – If provided, the state of tensordict will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to tensor names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the torch.Tensor items returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

  • flatten (bool, optional) – whether the structure should be flattened with the "." character or not. Defaults to False.

Examples

>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> sd = data.state_dict()
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)])
>>> sd = data.state_dict(flatten=True)
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])
std(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, correction: int = 1) bool | TensorDictBase

Returns the standard deviation value of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the standard deviation of all leaves (if this can be computed). If integer or tuple of integers, std is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the reduced dimension(s) are retained in the output.

Keyword Arguments:

correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.
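With `correction`, the variance divisor becomes N - correction, so the standard deviation is sqrt(sum_i (x_i - mean)^2 / (N - correction)). The plain-Python sketch below (the helper `std_sketch` is hypothetical) applies that formula to a flat list and checks it against the standard library: correction=1 matches `statistics.stdev` and correction=0 matches `statistics.pstdev`.

```python
import math
import statistics
from typing import Sequence


def std_sketch(values: Sequence[float], correction: int = 1) -> float:
    """Hypothetical helper: standard deviation with divisor N - correction."""
    n = len(values)
    if n - correction <= 0:
        # This sketch refuses a non-positive divisor instead of returning nan/inf.
        raise ValueError("need more elements than the correction")
    mean = sum(values) / n
    squared_dev = sum((v - mean) ** 2 for v in values)
    return math.sqrt(squared_dev / (n - correction))


vals = [1.0, 2.0, 3.0, 4.0]
# Bessel's correction (the default) vs. the population standard deviation.
print(std_sketch(vals), statistics.stdev(vals))
print(std_sketch(vals, correction=0), statistics.pstdev(vals))
```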

sum(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the sum value of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the reduced dimension(s) are retained in the output.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

to(*args, **kwargs: Any) T

Maps a TensorDictBase subclass either on another device, dtype or to another TensorDictBase subclass (if permitted).

Tensordicts are not bound to contain a single tensor dtype: when a dtype is passed, every stored tensor is cast to it.

Parameters:
  • device (torch.device, optional) – the desired device of the tensordict.

  • dtype (torch.dtype, optional) – the desired floating point or complex dtype of the tensordict.

  • tensor (torch.Tensor, optional) – Tensor whose dtype and device are the desired dtype and device for all tensors in this TensorDict.

Keyword Arguments:
  • non_blocking (bool, optional) – whether the operations should be non-blocking.

  • memory_format (torch.memory_format, optional) – the desired memory format for 4D parameters and buffers in this tensordict.

  • batch_size (torch.Size, optional) – resulting batch-size of the output tensordict.

  • other (TensorDictBase, optional) –

    TensorDict instance whose dtype and device are the desired dtype and device for all tensors in this TensorDict.

    Note: since TensorDictBase instances do not have a dtype, the dtype is gathered from the example leaves. If there is more than one dtype, no dtype casting is undertaken.

Returns:

a new tensordict instance if the device differs from the tensordict device and/or if the dtype is passed; the same tensordict otherwise. Modifications that only affect batch_size are done in-place.

Examples

>>> data = TensorDict({"a": 1.0}, [], device=None)
>>> data_cuda = data.to("cuda:0")  # casts to cuda
>>> data_int = data.to(torch.int)  # casts to int
>>> data_cuda_int = data.to("cuda:0", torch.int)  # multiple casting
>>> data_cuda = data.to(torch.randn(3, device="cuda:0"))  # using an example tensor
>>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0"))  # using a tensordict example
to_dict() dict[str, Any]

Returns a dictionary with key-value pairs matching those of the tensordict.

to_h5(filename, **kwargs)

Converts a tensordict to a PersistentTensorDict with the h5 backend.

Parameters:
  • filename (str or path) – path to the h5 file.

  • device (torch.device or compatible, optional) – the device where the tensors are expected once they are returned. Defaults to None (on cpu by default).

  • **kwargs – kwargs to be passed to h5py.File.create_dataset().

Returns:

A PersistentTensorDict instance linked to the newly created file.

Examples

>>> import tempfile
>>> import torch
>>>
>>> from tensordict import TensorDict, MemoryMappedTensor
>>> td = TensorDict({
...     "a": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000)),
...     "b": {"c": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000, 3))},
... }, [1_000_000])
>>>
>>> file = tempfile.NamedTemporaryFile()
>>> td_h5 = td.to_h5(file.name, compression="gzip", compression_opts=9)
>>> print(td_h5)
PersistentTensorDict(
    fields={
        a: Tensor(shape=torch.Size([1000000]), device=cpu, dtype=torch.float32, is_shared=False),
        b: PersistentTensorDict(
            fields={
                c: Tensor(shape=torch.Size([1000000, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1000000]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1000000]),
    device=None,
    is_shared=False)
to_module(module: nn.Module, *, inplace: bool | None = None, return_swap: bool = True, swap_dest=None, use_state_dict: bool = False, non_blocking: bool = False, memo=None)

Writes the content of a TensorDictBase instance onto the attributes of a given nn.Module, recursively.

Parameters:

module (nn.Module) – a module to write the parameters into.

Keyword Arguments:
  • inplace (bool, optional) – if True, the parameters or tensors in the module are updated in-place. Defaults to True.

  • return_swap (bool, optional) – if True, the old parameter configuration will be returned. Defaults to True.

  • swap_dest (TensorDictBase, optional) – if return_swap is True, the tensordict where the swap should be written.

  • use_state_dict (bool, optional) – if True, state-dict API will be used to load the parameters (including the state-dict hooks). Defaults to False.

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Examples

>>> from torch import nn
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> params.zero_()
>>> params.to_module(module)
>>> assert (module.layers[0].linear1.weight == 0).all()
to_padded_tensor(padding=0.0, mask_key: NestedKey | None = None)

Converts all nested tensors to a padded version and adapts the batch-size accordingly.

Parameters:
  • padding (float) – the padding value for the tensors in the tensordict. Defaults to 0.0.

  • mask_key (NestedKey, optional) – if provided, the key where a mask for valid values will be written. Will result in an error if the heterogeneous dimension isn’t part of the tensordict batch-size. Defaults to None.
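The padding and masking idea can be sketched on plain lists of variable length, standing in for a nested tensor with one ragged dimension. `pad_ragged` is a hypothetical helper; treating True as "valid value" in the mask is an assumption based on the description of `mask_key` above.

```python
from typing import List, Sequence, Tuple


def pad_ragged(
    rows: Sequence[Sequence[float]], padding: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Hypothetical sketch: pad variable-length rows and build a validity mask."""
    width = max((len(r) for r in rows), default=0)
    padded = [list(r) + [padding] * (width - len(r)) for r in rows]
    # Assumption: True marks a real value, False a padded slot.
    mask = [[True] * len(r) + [False] * (width - len(r)) for r in rows]
    return padded, mask


padded, mask = pad_ragged([[1.0, 2.0, 3.0], [4.0]])
print(padded)  # [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]
print(mask)    # [[True, True, True], [True, False, False]]
```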

to_tensordict() T

Returns a regular TensorDict instance from the TensorDictBase.

Returns:

a new TensorDict object containing the same values.

transpose(dim0, dim1)

Returns a tensordict that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

In-place or out-of-place modifications of the transposed tensordict will impact the original tensordict too, as the memory is shared and the operations are mapped back onto the original tensordict.

Examples

>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> tensordict_transpose = tensordict.transpose(0, 1)
>>> print(tensordict_transpose.shape)
torch.Size([4, 3])
>>> tensordict_transpose.set("b", torch.randn(4, 3))
>>> print(tensordict.get("b").shape)
torch.Size([3, 4])
type(dst_type)

Casts all tensors to dst_type.

Parameters:

dst_type (type or string) – the desired type

unbind(dim: int) tuple[T, ...]

Returns a tuple of indexed tensordicts, unbound along the indicated dimension.

Examples

>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1, td2 = td.unbind(0)
>>> td0['x']
tensor([0, 1, 2, 3])
>>> td1['x']
tensor([4, 5, 6, 7])
unflatten(dim, unflattened_size)

Unflattens a tensordict dimension, expanding it to the desired shape.

Parameters:
  • dim (int) – specifies the dimension of the input tensordict to be unflattened.

  • unflattened_size (shape) – the new shape of the unflattened dimension of the tensordict.

Examples

>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)},
...     batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_unflat = td_flat.unflatten(0, [3, 4])
>>> assert (td == td_unflat).all()
unflatten_keys(separator: str = '.', inplace: bool = False) T

Converts a flat tensordict into a nested one, recursively.

The TensorDict type will be lost and the result will be a simple TensorDict instance. The metadata of the nested tensordicts will be inferred from the root: all instances across the data tree will share the same batch-size, dimension names and device.

Parameters:
  • separator (str, optional) – the separator between the nested items.

  • inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.

Examples

>>> data = TensorDict({"a": 1, "b - c": 2, "e - f - g": 3}, batch_size=[])
>>> data.unflatten_keys(separator=" - ")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        e: TensorDict(
            fields={
                f: TensorDict(
                    fields={
                        g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This method and flatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.

Examples

>>> model = torch.nn.Sequential(torch.nn.Linear(3, 4))
>>> ddp_model = torch.ao.quantization.QuantWrapper(model)
>>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
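The nesting rule can be illustrated without torch. Below is a minimal pure-Python sketch on plain dicts. It assumes string keys and that no key is both a leaf and a prefix of another key. `unflatten` and `flatten` are hypothetical helpers, not part of tensordict, and they ignore the batch-size, names and device metadata that the real method propagates.

```python
from typing import Any, Dict


def unflatten(flat: Dict[str, Any], separator: str = ".") -> Dict[str, Any]:
    """Hypothetical helper: turn {"b.c": 2} into {"b": {"c": 2}}."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(separator)
        node = nested
        for part in parents:
            # Create intermediate levels on demand.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def flatten(nested: Dict[str, Any], separator: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: the inverse of unflatten()."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full_key = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-dicts, extending the key prefix.
            flat.update(flatten(value, separator, full_key))
        else:
            flat[full_key] = value
    return flat


flat = {"a": 1, "b - c": 2, "e - f - g": 3}
nested = unflatten(flat, separator=" - ")
print(nested)  # {'a': 1, 'b': {'c': 2}, 'e': {'f': {'g': 3}}}
assert flatten(nested, separator=" - ") == flat
```

The round trip mirrors the state-dict example above: flattening with the same separator recovers the original keys.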
unlock_() T

Unlocks a tensordict for non-in-place operations.

Can be used as a decorator.

See lock_() for more details.

unsqueeze(*args, **kwargs)

Unsqueezes all tensors along a dimension lying between -td.batch_dims and td.batch_dims and returns them in a new tensordict.

Parameters:

dim (int) – dimension along which to unsqueeze

Examples

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td = td.unsqueeze(-2)
>>> td.shape
torch.Size([3, 1, 4])
>>> td.get("x").shape
torch.Size([3, 1, 4, 2])

This operation can also be used as a context manager. Changes to the original tensordict then occur out-of-place, i.e. the content of the original tensors is not altered. This also assumes that the tensordict is not locked (otherwise, it must be unlocked first).

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> with td.unsqueeze(-2) as tds:
...     tds.set("y", torch.zeros(3, 1, 4))
>>> assert td.get("y").shape == [3, 4]
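The shape bookkeeping behind unsqueeze, and behind the squeeze applied when the context manager exits, can be sketched with plain lists. This assumes the torch.unsqueeze convention for negative dims (counted against the new, one-longer batch shape). `unsqueeze_shapes` and `squeeze_shape` are hypothetical helpers, and the exact bounds tensordict enforces may differ.

```python
from typing import List, Tuple


def unsqueeze_shapes(batch_size: List[int], leaf_shape: List[int],
                     dim: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: batch size and leaf shape after td.unsqueeze(dim)."""
    batch_dims = len(batch_size)
    # Negative dims are counted against the new batch shape, as in torch.unsqueeze.
    if dim < 0:
        dim += batch_dims + 1
    if not 0 <= dim <= batch_dims:
        raise IndexError(f"dim out of range for a tensordict with {batch_dims} batch dims")
    new_batch = batch_size[:dim] + [1] + batch_size[dim:]
    # Every leaf gets the singleton at the same position; trailing feature dims are untouched.
    new_leaf = leaf_shape[:dim] + [1] + leaf_shape[dim:]
    return new_batch, new_leaf


def squeeze_shape(shape: List[int], dim: int) -> List[int]:
    """Hypothetical helper: the inverse step applied when the context manager exits."""
    if shape[dim] != 1:
        raise ValueError(f"dim {dim} has size {shape[dim]}, expected 1")
    return shape[:dim] + shape[dim + 1:]


print(unsqueeze_shapes([3, 4], [3, 4, 2], -2))  # ([3, 1, 4], [3, 1, 4, 2])
# A leaf set inside the `with` block with shape [3, 1, 4] is stored back as [3, 4]:
print(squeeze_shape([3, 1, 4], 1))  # [3, 4]
```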
update(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, inplace: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict with values from either a dictionary or another TensorDict.

Parameters:
  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

  • inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If the entry cannot be found, it will be added. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the keys listed in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update(data_src.select(*keys_to_update)).

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3])
>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> other_td = TensorDict({"a": a, "b": b}, batch_size=[])
>>> td.update(other_td, inplace=True) # writes "a" and "b" even though they can't be found
>>> assert td['a'] is other_td['a']
>>> other_td = other_td.clone().zero_()
>>> td.update(other_td, inplace=True)  # "a" now exists, so it is overwritten in place
>>> assert td['a'] is not other_td['a']
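The identity semantics of `clone` and `inplace` can be mimicked with plain Python lists standing in for tensors. This is a sketch under that assumption, not tensordict's implementation. `update` here is a hypothetical free function, and `copy.copy` stands in for cloning a tensor.

```python
import copy
from typing import Dict, Iterable, List, Optional


def update(dest: Dict[str, List[float]], src: Dict[str, List[float]], clone: bool = False,
           inplace: bool = False,
           keys_to_update: Optional[Iterable[str]] = None) -> Dict[str, List[float]]:
    """Hypothetical sketch of TensorDict.update() semantics, with lists standing in for tensors."""
    allowed = None if keys_to_update is None else set(keys_to_update)
    for key, value in src.items():
        if allowed is not None and key not in allowed:
            continue  # keys outside keys_to_update are ignored
        if clone:
            value = copy.copy(value)  # store a copy instead of the caller's object
        if inplace and key in dest:
            dest[key][:] = value  # copy into the existing storage; its identity is kept
        else:
            dest[key] = value  # add or rebind the entry to the (possibly cloned) object
    return dest  # the destination itself, mirroring "Returns: self"


a = [0.5, 1.5, 2.5]
td: Dict[str, List[float]] = {}
other = {"a": a}
update(td, other, inplace=True)  # "a" is missing, so it is simply added
assert td["a"] is other["a"]

zeros = {"a": [0.0, 0.0, 0.0]}
update(td, zeros, inplace=True)  # "a" exists now, so its contents are overwritten
assert td["a"] is a and td["a"] is not zeros["a"]
assert a == [0.0, 0.0, 0.0]
```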
update_(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict in-place with values from either a dictionary or another TensorDict.

Unlike update(), this function will throw an error if the key is unknown to self.

Parameters:
  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the keys listed in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update_(data_src.select(*keys_to_update)).

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> td = TensorDict({"a": a, "b": b}, batch_size=[3])
>>> other_td = TensorDict({"a": a*0, "b": b*0}, batch_size=[])
>>> td.update_(other_td)
>>> assert td['a'] is not other_td['a']
>>> assert (td['a'] == other_td['a']).all()
>>> assert (td['a'] == 0).all()
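By contrast, update_() never rebinds an entry. A minimal sketch, again with lists standing in for tensors and a hypothetical `update_` function; validating all keys before writing anything is a choice of this sketch, not documented tensordict behavior.

```python
from typing import Dict, List


def update_(dest: Dict[str, List[float]], src: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Hypothetical sketch of TensorDict.update_() on dicts of lists."""
    # Reject unknown keys up front so that dest is left untouched on error.
    unknown = [key for key in src if key not in dest]
    if unknown:
        raise KeyError(f"keys not found in destination: {unknown}")
    for key, value in src.items():
        dest[key][:] = value  # always copy into the existing storage
    return dest


a = [0.5, 1.0, 2.0]
td = {"a": a}
update_(td, {"a": [x * 0 for x in a]})
assert td["a"] is a and a == [0.0, 0.0, 0.0]
try:
    update_(td, {"c": [1.0]})
except KeyError as err:
    print("rejected:", err)
```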
update_at_(input_dict_or_td: dict[str, CompatibleType] | T, idx: IndexType, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict in-place at the specified index with values from either a dictionary or another TensorDict.

Unlike TensorDict.update, this function will throw an error if the key is unknown to the TensorDict.

Parameters:
  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • idx (int, torch.Tensor, iterable, slice) – index of the tensordict where the update should occur.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the keys listed in keys_to_update will be updated.

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td.update_at_(
...     TensorDict({
...         'a': torch.ones(1, 4, 5),
...         'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]),
...     slice(1, 2))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 4, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)
>>> assert (td[1] == 1).all()
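Indexed in-place writes can be sketched the same way for a slice index along the leading dimension. `update_at_` below is a hypothetical function that only handles `slice` indices (the real method also accepts ints, tensors and iterables) and uses lists of rows in place of tensors.

```python
from typing import Dict, List

Rows = List[List[float]]


def update_at_(dest: Dict[str, Rows], src: Dict[str, Rows], idx: slice) -> Dict[str, Rows]:
    """Hypothetical sketch of TensorDict.update_at_() for a slice index on lists of rows."""
    for key, new_rows in src.items():
        if key not in dest:
            raise KeyError(key)  # unknown keys are an error, as with update_()
        rows = dest[key]
        positions = range(len(rows))[idx]  # leading-dim positions selected by the slice
        if len(positions) != len(new_rows):
            raise ValueError(f"{key}: expected {len(positions)} rows, got {len(new_rows)}")
        for i, new_row in zip(positions, new_rows):
            rows[i][:] = new_row  # copy values into the existing row objects
    return dest


td = {"a": [[0.0] * 4 for _ in range(3)], "b": [[0.0] * 2 for _ in range(3)]}
before = td["a"]
update_at_(td, {"a": [[1.0] * 4], "b": [[1.0] * 2]}, slice(1, 2))
assert td["a"] is before
assert td["a"][1] == [1.0, 1.0, 1.0, 1.0] and td["a"][0] == [0.0, 0.0, 0.0, 0.0]
```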
values(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) Iterator[tuple[str, CompatibleType]]

Returns a generator representing the values for the tensordict.

Parameters:
  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.
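The interplay of include_nested and leaves_only can be sketched on nested dicts, treating every non-dict value as a leaf by default. `values` here is a hypothetical generator, not the library's implementation, and yielding a sub-dict before its contents is an ordering choice of this sketch.

```python
from typing import Any, Callable, Dict, Iterator, Optional


def _not_a_dict(cls: type) -> bool:
    # Default leaf test of this sketch: anything that is not a dict is a leaf.
    return not issubclass(cls, dict)


def values(td: Dict[str, Any], include_nested: bool = False, leaves_only: bool = False,
           is_leaf: Optional[Callable[[type], bool]] = None) -> Iterator[Any]:
    """Hypothetical sketch of TensorDict.values() on nested dicts."""
    leaf_test = is_leaf if is_leaf is not None else _not_a_dict
    for value in td.values():
        if leaf_test(type(value)):
            yield value
            continue
        if not leaves_only:
            yield value  # the sub-dict itself
        if include_nested:
            yield from values(value, include_nested, leaves_only, leaf_test)


td = {"a": 1, "sub": {"b": 2, "deeper": {"c": 3}}}
print(list(values(td)))  # [1, {'b': 2, 'deeper': {'c': 3}}]
print(list(values(td, include_nested=True, leaves_only=True)))  # [1, 2, 3]
```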

var(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, correction: int = 1) bool | TensorDictBase

Returns the variance of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the variance of all leaves (if this can be computed). If an integer or tuple of integers, var is computed along the specified dimension(s), provided they are compatible with the tensordict shape.

  • keepdim (bool) – whether the reduced dimension(s) are retained (with size 1) in the output, as in torch.var().

Keyword Arguments:

correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.
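The role of `correction` is easiest to see on a plain list. This sketch assumes the same divisor as torch.var(), N - correction, and raises when that divisor is not positive (an assumption of the sketch). `var` is a hypothetical helper, not the tensordict reduction over leaves.

```python
import statistics
from typing import Sequence


def var(xs: Sequence[float], correction: int = 1) -> float:
    """Variance with a configurable correction: sum((x - mean)^2) / (N - correction)."""
    n = len(xs)
    if n - correction <= 0:
        # This sketch refuses a non-positive divisor instead of returning a non-finite value.
        raise ValueError("need more samples than the correction term")
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs) / (n - correction)


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(var(data))                # Bessel-corrected sample variance, 32 / 7
print(var(data, correction=0))  # population variance: 4.0
assert abs(var(data) - statistics.variance(data)) < 1e-12
assert abs(var(data, correction=0) - statistics.pvariance(data)) < 1e-12
```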

view(*shape: int, size: list | tuple | torch.Size | None = None)

Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size.

Parameters:
  • *shape (int) – new shape of the resulting tensordict.

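  • size (list, tuple or torch.Size, optional) – the new shape given as a single iterable, as an alternative to *shape.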

Returns:

a new tensordict with the desired batch_size.

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10, 1)}, batch_size=torch.Size([3, 4]))
>>> td_view = td.view(12)
>>> print(td_view.get("a").shape)
torch.Size([12, 5])
>>> print(td_view.get("b").shape)
torch.Size([12, 10, 1])
>>> td_view = td.view(-1, 4, 3)
>>> print(td_view.get("a").shape)
torch.Size([1, 4, 3, 5])
>>> print(td_view.get("b").shape)
torch.Size([1, 4, 3, 10, 1])
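The shape rule used by view (reshape only the batch dims, carry the trailing feature dims along, infer at most one -1) can be checked with a small pure-Python helper. `view_shapes` is hypothetical and mirrors the arithmetic only; it does not model memory layout or the contiguity requirements of Tensor.view().

```python
import math
from typing import List, Sequence, Tuple


def view_shapes(batch_size: Sequence[int], leaf_shape: Sequence[int],
                new_shape: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: batch size and leaf shape after td.view(*new_shape)."""
    numel = math.prod(batch_size)  # number of batch elements; feature dims are excluded
    resolved = list(new_shape)
    if resolved.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    if -1 in resolved:
        known = math.prod(d for d in resolved if d != -1)
        if known == 0 or numel % known != 0:
            raise ValueError(f"cannot infer -1 in {list(new_shape)} for {numel} elements")
        resolved[resolved.index(-1)] = numel // known
    if math.prod(resolved) != numel:
        raise ValueError(f"shape {resolved} is incompatible with batch_size {list(batch_size)}")
    # Only the leading batch dims are reshaped; trailing feature dims are carried along.
    new_leaf = resolved + list(leaf_shape[len(batch_size):])
    return resolved, new_leaf


print(view_shapes([3, 4], [3, 4, 5], [-1, 4, 3]))  # ([1, 4, 3], [1, 4, 3, 5])
```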
where(condition, other, *, out=None, pad=None)

Return a TensorDict of elements selected from either self or other, depending on condition.

Parameters:
  • condition (BoolTensor) – When True (nonzero), yields self, otherwise yields other.

  • other (TensorDictBase or Scalar) – value (if other is a scalar) or values selected at indices where condition is False.

Keyword Arguments:
  • out (TensorDictBase, optional) – the output TensorDictBase instance.

  • pad (scalar, optional) – if provided, missing keys from the source or destination tensordict will be written as torch.where(mask, self, pad) or torch.where(mask, pad, other). Defaults to None, i.e. missing keys are not tolerated.
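The pad rule can be sketched with dicts of 1-D lists standing in for tensordicts. `where` below is a hypothetical function, not the library's implementation: it broadcasts scalars, applies where(mask, self, pad) or where(mask, pad, other) to one-sided keys, and raises when pad is None.

```python
from typing import Dict, List, Optional, Union

Scalar = Union[int, float]
Column = List[Scalar]


def _broadcast(value: Union[Column, Scalar], n: int) -> Column:
    # A scalar stands for a column filled with that value.
    return list(value) if isinstance(value, list) else [value] * n


def where(condition: List[bool], self_td: Dict[str, Column],
          other: Union[Dict[str, Column], Scalar],
          pad: Optional[Scalar] = None) -> Dict[str, Column]:
    """Hypothetical sketch of TensorDict.where on dicts of 1-D lists."""
    n = len(condition)

    def pick(a: Union[Column, Scalar], b: Union[Column, Scalar]) -> Column:
        # Elementwise torch.where(condition, a, b).
        return [x if c else y for c, x, y in zip(condition, _broadcast(a, n), _broadcast(b, n))]

    if not isinstance(other, dict):
        return {key: pick(value, other) for key, value in self_td.items()}
    out: Dict[str, Column] = {}
    keys = list(self_td) + [key for key in other if key not in self_td]
    for key in keys:
        if key in self_td and key in other:
            out[key] = pick(self_td[key], other[key])
        elif pad is None:
            raise KeyError(f"{key!r} is missing from one side and no pad was given")
        elif key in self_td:
            out[key] = pick(self_td[key], pad)  # where(mask, self, pad)
        else:
            out[key] = pick(pad, other[key])  # where(mask, pad, other)
    return out


mask = [True, False, True]
a = {"x": [1, 2, 3], "only_self": [7, 7, 7]}
b = {"x": [10, 20, 30], "only_other": [9, 9, 9]}
print(where(mask, a, b, pad=0))
# {'x': [1, 20, 3], 'only_self': [7, 0, 7], 'only_other': [0, 9, 0]}
print(where(mask, a, -1))  # scalar other
# {'x': [1, -1, 3], 'only_self': [7, -1, 7]}
```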

zero_() T

Zeros all tensors in the tensordict in-place.
