TensorDictParams

class tensordict.TensorDictParams(parameters: TensorDictBase, *, no_convert=False, lock: bool = False)

Holds a TensorDictBase instance full of parameters.

This class exposes the contained parameters to a parent nn.Module such that iterating over the parameters of the module also iterates over the leaves of the tensordict.

Indexing works exactly like the indexing of the wrapped tensordict. The parameter names will be registered within this module using flatten_keys("_"). Therefore, the result of named_parameters() and the content of the tensordict will differ slightly in terms of key names.

Any operation that sets a tensor in the tensordict will be augmented by a torch.nn.Parameter conversion.

Parameters:

parameters (TensorDictBase) – a tensordict to represent as parameters. Values will be converted to parameters unless no_convert=True.

Keyword Arguments:
  • no_convert (bool) – if True, no conversion to nn.Parameter will occur at construction time or afterwards (unless the no_convert attribute is changed). If no_convert is True and non-parameters are present, they will be registered as buffers. Defaults to False.

  • lock (bool) – if True, the tensordict hosted by TensorDictParams will be locked. This can be useful to avoid unwanted modifications, but also restricts the operations that can be done over the object (and can have significant performance impact when unlock_() is required). Defaults to False.

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict, TensorDictParams
>>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
>>> params = TensorDict.from_module(module)
>>> params.lock_()
>>> p = TensorDictParams(params)
>>> print(p)
TensorDictParams(params=TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False))
>>> class CustomModule(nn.Module):
...     def __init__(self, params):
...         super().__init__()
...         self.params = params
>>> m = CustomModule(p)
>>> # the wrapper supports assignment and values are turned into Parameters
>>> m.params['other'] = torch.randn(3)
>>> assert isinstance(m.params['other'], nn.Parameter)
add_module(name: str, module: Optional[Module]) None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Parameters:
  • name (str) – name of the child module. The child module can be accessed from this module using the given name

  • module (Module) – child module to be added to the module.

all(dim: int = None) bool | TensorDictBase

Checks if all values are True/non-null in the tensordict.

Parameters:

dim (int, optional) – if None, returns a boolean indicating whether all tensors return tensor.all() == True. If an integer, all is called over the specified dimension if and only if this dimension is compatible with the tensordict shape.
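
For illustration, a minimal sketch of both call modes (not part of the original reference; reduction over a batch dimension is assumed to behave as described above):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(2, 3, dtype=torch.bool)}, batch_size=[2])
>>> assert td.all()  # every element of every leaf is True
>>> td0 = td.all(dim=0)  # reduces over the first batch dimension
>>> assert td0.batch_size == torch.Size([])
>>> assert td0["a"].shape == torch.Size([3])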

any(dim: int = None) bool | TensorDictBase

Checks if any value is True/non-null in the tensordict.

Parameters:

dim (int, optional) – if None, returns a boolean indicating whether all tensors return tensor.any() == True. If an integer, any is called over the specified dimension if and only if this dimension is compatible with the tensordict shape.
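
Likewise for any(), a short illustrative sketch (not from the original reference), here checking a boolean mask:

>>> import torch
>>> from tensordict import TensorDict
>>> mask = TensorDict({"done": torch.zeros(4, dtype=torch.bool)}, batch_size=[4])
>>> assert not mask.any()  # no True entry anywhere
>>> mask["done"][0] = True
>>> assert mask.any()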

apply(fn: Callable, *others: TensorDictBase, batch_size: Sequence[int] | None = None, device: torch.device | None = <tensordict.base._NoDefault object>, names: Sequence[str] | None = None, inplace: bool = False, default: Any = <tensordict.base._NoDefault object>, filter_empty: bool | None = None, **constructor_kwargs) TensorDictBase | None

Applies a callable to all values stored in the tensordict and sets them in a new tensordict.

The callable signature must be Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].

Parameters:
  • fn (Callable) – function to be applied to the tensors in the tensordict.

  • *others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching the one of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.

Keyword Arguments:
  • batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.

  • device (torch.device, optional) – the resulting device, if any.

  • names (list of str, optional) – the new dimension names, in case the batch_size is modified.

  • inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.

  • default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.

  • filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost as empty data structures won’t be created and destroyed. Non-tensor data is considered as a leaf and thereby will be kept in the tensordict even if left untouched by the function. Defaults to False for backward compatibility.

  • propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.

  • **constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.

Returns:

a new tensordict with the transformed tensors.

Example

>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "b": {"c": torch.ones(3)}},
...     batch_size=[3])
>>> td_1 = td.apply(lambda x: x+1)
>>> assert (td_1["a"] == 0).all()
>>> assert (td_1["b", "c"] == 2).all()
>>> td_2 = td.apply(lambda x, y: x+y, td)
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()

Note

If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def filter(tensor):
...     if tensor == 1:
...         return tensor
>>> td.apply(filter)
TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Note

The apply method will return a TensorDict instance, regardless of the input type. To keep the same type, one can execute

>>> out = td.clone(False).update(td.apply(...))
apply_(fn: Callable, *others, **kwargs) T

Applies a callable to all values stored in the tensordict and re-writes them in-place.

Parameters:
  • fn (Callable) – function to be applied to the tensors in the tensordict.

  • *others (sequence of TensorDictBase, optional) – the other tensordicts to be used.

Keyword Args: See apply().

Returns:

self or a copy of self with the function applied
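
A minimal illustrative sketch (not from the original reference); apply_ returns self, which is discarded here:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3), "b": {"c": torch.ones(3)}}, batch_size=[3])
>>> _ = td.apply_(lambda x: x * 2)  # leaves are rewritten in-place
>>> assert (td["a"] == 2).all() and (td["b", "c"] == 2).all()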

auto_batch_size_(batch_dims: int | None = None) T

Sets the maximum batch-size for the tensordict, up to an optional batch_dims.

Parameters:

batch_dims (int, optional) – if provided, the batch-size will be at most batch_dims long.

Returns:

self

Examples

>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[])
>>> td.auto_batch_size_()
>>> print(td.batch_size)
torch.Size([3, 4])
>>> td.auto_batch_size_(batch_dims=1)
>>> print(td.batch_size)
torch.Size([3])
property batch_dims: int

Length of the tensordict batch size.

Returns:

int describing the number of dimensions of the tensordict.

property batch_size: Size

Shape (or batch_size) of a TensorDict.

The shape of a tensordict corresponds to the common first N dimensions of the tensors it contains, where N is an arbitrary number. The TensorDict shape is controlled by the user upon initialization (ie, it is not inferred from the tensor shapes).

The batch_size can be edited dynamically if the new size is compatible with the TensorDict content. For instance, setting the batch size to an empty value is always allowed.

Returns:

a Size object describing the TensorDict batch size.

Examples

>>> data = TensorDict({
...     "key 0": torch.randn(3, 4),
...     "key 1": torch.randn(3, 5),
...     "nested": TensorDict({"key 0": torch.randn(3, 4)}, batch_size=[3, 4])},
...     batch_size=[3])
>>> data.batch_size = () # resets the batch-size to an empty value
bfloat16()

Casts all tensors to torch.bfloat16.

bool()

Casts all tensors to torch.bool.

buffers(recurse: bool = True) Iterator[Tensor]

Return an iterator over module buffers.

Parameters:

recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields:

torch.Tensor – module buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
children() Iterator[Module]

Return an iterator over immediate children modules.

Yields:

Module – a child module

chunk(chunks: int, dim: int = 0) tuple[TensorDictBase, ...]

Splits a tensordict into the specified number of chunks, if possible.

Each chunk is a view of the input tensordict.

Parameters:
  • chunks (int) – number of chunks to return

  • dim (int, optional) – dimension along which to split the tensordict. Default is 0.

Examples

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td0, td1 = td.chunk(dim=-1, chunks=2)
>>> td0['x']
tensor([[[ 0,  1],
         [ 2,  3]],
        [[ 8,  9],
         [10, 11]],
        [[16, 17],
         [18, 19]]])
clear() T

Erases the content of the tensordict.

clear_device_() T

Clears the device of the tensordict.

Returns: self

clone(recurse: bool = True, **kwargs) T

Clones a TensorDictBase subclass instance onto a new TensorDictBase subclass of the same type.

To create a TensorDict instance from any other TensorDictBase subtype, call the to_tensordict() method instead.

Parameters:

recurse (bool, optional) – if True, each tensor contained in the TensorDict will be copied too. Otherwise only the TensorDict tree structure will be copied. Defaults to True.

Note

Unlike many other ops (pointwise arithmetic, shape operations, …) clone does not inherit the original lock attribute. This design choice is made such that a clone can be created to be modified, which is the most frequent usage.
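
An illustrative sketch of both points above (hedged, not part of the original reference): the clone copies the data and does not inherit the lock:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3)}, batch_size=[3]).lock_()
>>> td_clone = td.clone()
>>> assert not td_clone.is_locked  # the lock attribute is not inherited
>>> td_clone["a"] += 1
>>> assert (td["a"] == 0).all()  # recurse=True copied the leaf tensors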

compile(*args, **kwargs)

Compile this Module’s forward using torch.compile().

This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile().

See torch.compile() for details on the arguments for this function.

contiguous(*args, **kwargs)

Returns a new tensordict of the same type with contiguous values (or self if values are already contiguous).

copy()

Return a shallow copy of the tensordict (ie, copies the structure but not the data).

Equivalent to TensorDictBase.clone(recurse=False)

copy_(tensordict: T, non_blocking: bool = None) T

See TensorDictBase.update_.

The non_blocking argument will be ignored and is just present for compatibility with torch.Tensor.copy_().

copy_at_(tensordict: T, idx: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], non_blocking: bool = False) T

See TensorDictBase.update_at_.

cpu()

Casts a tensordict to CPU.

create_nested(key)

Creates a nested tensordict of the same shape, device and dim names as the current tensordict.

If the value already exists, it will be overwritten by this operation. This operation is blocked in locked tensordicts.

Examples

>>> data = TensorDict({}, [3, 4, 5])
>>> data.create_nested("root")
>>> data.create_nested(("some", "nested", "value"))
>>> print(data)
TensorDict(
    fields={
        root: TensorDict(
            fields={
            },
            batch_size=torch.Size([3, 4, 5]),
            device=None,
            is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        value: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([3, 4, 5]),
                            device=None,
                            is_shared=False)},
                    batch_size=torch.Size([3, 4, 5]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([3, 4, 5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 4, 5]),
    device=None,
    is_shared=False)
cuda(device=None)

Casts a tensordict to a cuda device (if not already on it).

Parameters:

device (int, optional) – if provided, the cuda device on which the tensor should be cast.

property data

Returns a tensordict containing the .data attributes of the leaf tensors.

del_(*args, **kwargs)

Deletes a key of the tensordict.

Parameters:

key (NestedKey) – key to be deleted

Returns:

self

property depth: int

Returns the depth (i.e., the maximum number of nesting levels) of a tensordict.

The minimum depth is 0 (no nested tensordict).
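
For instance (an illustrative sketch, not from the original reference): a flat tensordict has depth 0, and each nesting level adds one:

>>> import torch
>>> from tensordict import TensorDict
>>> flat = TensorDict({"a": torch.zeros(())}, batch_size=[])
>>> nested = TensorDict({"a": {"b": torch.zeros(())}}, batch_size=[])
>>> assert flat.depth == 0
>>> assert nested.depth == 1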

detach() T

Detach the tensors in the tensordict.

Returns:

a new tensordict with no tensor requiring gradient.

detach_(*args, **kwargs)

Detach the tensors in the tensordict in-place.

Returns:

self.

property device

Device of a TensorDict.

If the TensorDict has a specified device, all its tensors (incl. nested ones) must live on the same device. If the TensorDict device is None, different values can be located on different devices.

Returns:

torch.device object indicating the device where the tensors are placed, or None if TensorDict does not have a device.

Examples

>>> td = TensorDict({
...     "cpu": torch.randn(3, device='cpu'),
...     "cuda": torch.randn(3, device='cuda'),
... }, batch_size=[], device=None)
>>> td['cpu'].device
device(type='cpu')
>>> td['cuda'].device
device(type='cuda')
>>> td = TensorDict({
...     "x": torch.randn(3, device='cpu'),
...     "y": torch.randn(3, device='cuda'),
... }, batch_size=[], device='cuda')
>>> td['x'].device
device(type='cuda')
>>> td['y'].device
device(type='cuda')
>>> td = TensorDict({
...     "x": torch.randn(3, device='cpu'),
...     "y": TensorDict({'z': torch.randn(3, device='cpu')}, batch_size=[], device=None),
... }, batch_size=[], device='cuda')
>>> td['x'].device
device(type='cuda')
>>> td['y'].device # nested tensordicts are also mapped onto the appropriate device.
device(type='cuda')
>>> td['y', 'z'].device
device(type='cuda')
dim() int

See batch_dims().

double()

Casts all tensors to torch.double.

property dtype

Returns the dtype of the values in the tensordict, if it is unique.

dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Saves the tensordict to disk.

This function is a proxy to memmap().

empty(recurse=False) T

Returns a new, empty tensordict with the same device and batch size.

Parameters:

recurse (bool, optional) – if True, the entire structure of the TensorDict will be reproduced without content. Otherwise, only the root will be duplicated. Defaults to False.
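
A hedged sketch of the recurse flag, assuming the behavior described above (nested nodes are kept, leaves are dropped):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3), "b": {"c": torch.zeros(3)}}, batch_size=[3])
>>> assert len(list(td.empty().keys())) == 0  # only the root is duplicated
>>> assert list(td.empty(recurse=True).keys()) == ["b"]  # structure kept, leaves dropped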

entry_class(*args, **kwargs)

Returns the class of an entry, possibly avoiding a call to isinstance(td.get(key), type).

This method should be preferred to type(tensordict.get(key)) whenever get() can be expensive to execute.
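
A short illustrative sketch (for a plain TensorDict; in a TensorDictParams the leaves would be nn.Parameter instances):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3), "b": {"c": torch.zeros(3)}}, batch_size=[3])
>>> assert td.entry_class("a") is torch.Tensor
>>> assert issubclass(td.entry_class("b"), TensorDict)  # nested nodes report their class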

eval() T

Set the module in evaluation mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

This is equivalent to self.train(False).

See Locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

self

Return type:

Module

exclude(*keys: Union[str, Tuple[str, ...]], inplace: bool = False) T

Excludes the keys of the tensordict and returns a new tensordict without these entries.

The values are not copied: in-place modifications of a tensor in either the original or the new tensordict will be reflected in both.

Parameters:
  • *keys (str) – keys to exclude.

  • inplace (bool) – if True, the tensordict is pruned in place. Default is False.

Returns:

A new tensordict (or the same if inplace=True) without the excluded entries.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.exclude("a", ("b", "c"))
TensorDict(
    fields={
        b: TensorDict(
            fields={
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.exclude("a", "b")
TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
expand(*args, **kwargs) T

Expands each tensor of the tensordict according to the expand() function, ignoring the feature dimensions.

Supports iterables to specify the shape.

Examples

>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td_expand = td.expand(10, 3, 4)
>>> assert td_expand.shape == torch.Size([10, 3, 4])
>>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5])
extra_repr() str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

fill_(key: NestedKey, value: float | bool) T

Fills a tensor pointed by the key with a given scalar value.

Parameters:
  • key (str or nested key) – entry to be filled.

  • value (Number or bool) – value to use for the filling.

Returns:

self
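
A minimal illustrative sketch (nested keys are assumed to be supported, as elsewhere in this API):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3), "b": {"c": torch.zeros(3)}}, batch_size=[3])
>>> _ = td.fill_(("b", "c"), 1.0)  # fill_ returns self
>>> assert (td["b", "c"] == 1).all()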

filter_non_tensor_data() T

Filters out all non-tensor-data.

flatten(start_dim=0, end_dim=-1)

Flattens all the tensors of a tensordict.

Parameters:
  • start_dim (int) – the first dim to flatten

  • end_dim (int) – the last dim to flatten

Examples

>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)}, batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_flat.batch_size
torch.Size([12])
>>> td_flat["a"]
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]])
>>> td_flat["b"]
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
flatten_keys(separator: str = '.', inplace: bool = False) TensorDictBase

Converts a nested tensordict into a flat one, recursively.

The TensorDict type will be lost and the result will be a simple TensorDict instance.

Parameters:
  • separator (str, optional) – the separator between the nested items.

  • inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.

  • is_leaf (callable, optional) – a callable over a class type returning a bool indicating if this class has to be considered as a leaf.

Examples

>>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[])
>>> data.flatten_keys(separator=" - ")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        e - f - g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This method and unflatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.

Examples

>>> model = torch.nn.Sequential(torch.nn.Linear(3 ,4))
>>> ddp_model = torch.ao.quantization.QuantWrapper(model)
>>> state_dict = TensorDict(ddp_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
float()

Casts all tensors to torch.float.

forward(*input: Any) None

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

from_dict_instance(input_dict, batch_size=None, device=None, batch_dims=None)

Instance method version of from_dict().

Unlike from_dict(), this method will attempt to keep the tensordict types within the existing tree (for any existing leaf).

Examples

>>> from tensordict import TensorDict, tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyClass:
...     x: torch.Tensor
...     y: int
>>>
>>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)})
>>> print(td.from_dict_instance(td.to_dict()))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: MyClass(
            x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
            y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(td.from_dict(td.to_dict()))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
classmethod from_module(module, as_module: bool = False, lock: bool = True, use_state_dict: bool = False)

Copies the params and buffers of a module in a tensordict.

Parameters:
  • module (nn.Module) – the module to get the parameters from.

  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.

  • use_state_dict (bool, optional) – if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note

    This is particularly useful when state-dict hooks have to be used.

Examples

>>> from torch import nn
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> print(params["layers", "0", "linear1"])
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2048]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2048, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
classmethod from_modules(*modules, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False)

Retrieves the parameters of several modules for ensemble learning/mixture-of-experts style applications through vmap.

Parameters:

modules (sequence of nn.Module) – the modules to get the parameters from. If the modules differ in their structure, a lazy stack is needed (see the lazy_stack argument below).

Keyword Arguments:
  • as_module (bool, optional) – if True, a TensorDictParams instance will be returned which can be used to store parameters within a torch.nn.Module. Defaults to False.

  • lock (bool, optional) – if True, the resulting tensordict will be locked. Defaults to True.

  • use_state_dict (bool, optional) – if True, the state-dict from the module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to False.

    Note

    This is particularly useful when state-dict hooks have to be used.

  • lazy_stack (bool, optional) –

    whether parameters should be densely or lazily stacked. Defaults to False (dense stack).

    Note

    lazy_stack and as_module are exclusive features.

    Warning

    There is a crucial difference between lazy and non-lazy outputs in that non-lazy output will reinstantiate parameters with the desired batch-size, while lazy_stack will just represent the parameters as lazily stacked. This means that whilst the original parameters can safely be passed to an optimizer when lazy_stack=True, the new parameters need to be passed when it is set to False.

    Warning

    Whilst it can be tempting to use a lazy stack to keep the original parameter references, remember that lazy stacks perform a stack each time get() is called. This will require memory (N times the size of the parameters, more if a graph is built) and time to compute. It also means that the optimizer(s) will contain more parameters, and operations like step() or zero_grad() will take longer to execute. In general, lazy_stack should be reserved for very few use cases.

Examples

>>> from torch import nn
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> empty_module = nn.Linear(3, 4, device="meta")
>>> n_models = 2
>>> modules = [nn.Linear(3, 4) for _ in range(n_models)]
>>> params = TensorDict.from_modules(*modules)
>>> print(params)
TensorDict(
    fields={
        bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> # example of batch execution
>>> def exec_module(params, x):
...     with params.to_module(empty_module):
...         return empty_module(x)
>>> x = torch.randn(3)
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> # since lazy_stack = False, backprop leaves the original params untouched
>>> y.sum().backward()
>>> assert params["weight"].grad.norm() > 0
>>> assert modules[0].weight.grad is None

With lazy_stack=True, things are slightly different:

>>> params = TensorDict.from_modules(*modules, lazy_stack=True)
>>> print(params)
LazyStackedTensorDict(
    fields={
        bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> # example of batch execution
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> y.sum().backward()
>>> assert modules[0].weight.grad is not None
classmethod fromkeys(keys: List[Union[str, Tuple[str, ...]]], value: Any = 0)

Creates a tensordict from a list of keys and a single value.

Parameters:
  • keys (list of NestedKey) – An iterable specifying the keys of the new dictionary.

  • value (compatible type, optional) – The value for all keys. Defaults to 0.
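
A minimal illustrative sketch (not part of the original reference):

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", ("b", "c")], value=0)
>>> assert td["a"] == 0
>>> assert td["b", "c"] == 0  # nested keys create nested tensordicts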

gather(dim: int, index: Tensor, out: T | None = None) T

Gathers values along an axis specified by dim.

Parameters:
  • dim (int) – the dimension along which to collect the elements

  • index (torch.Tensor) – a long tensor whose number of dimensions matches that of the tensordict, with only one dimension differing between the two (the gathering dimension). Its elements refer to the index to be gathered along the required dimension.

  • out (TensorDictBase, optional) – a destination tensordict. It must have the same shape as the index.

Examples

>>> td = TensorDict(
...     {"a": torch.randn(3, 4, 5),
...      "b": TensorDict({"c": torch.zeros(3, 4, 5)}, [3, 4, 5])},
...     [3, 4])
>>> index = torch.randint(4, (3, 2))
>>> td_gather = td.gather(dim=1, index=index)
>>> print(td_gather)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 2, 5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

Gather keeps the dimension names.

Examples

>>> td.names = ["a", "b"]
>>> td_gather = td.gather(dim=1, index=index)
>>> td_gather.names
["a", "b"]
gather_and_stack(dst: int, group: 'dist.ProcessGroup' | None = None) T | None

Gathers tensordicts from various workers and stacks them onto self in the destination worker.

Parameters:
  • dst (int) – the rank of the destination worker where gather_and_stack() will be called.

  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

Example

>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>>
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     # Create a single tensordict to be sent to server
...     td = TensorDict(
...         {("a", "b"): torch.randn(2),
...          "c": torch.randn(2)}, [2]
...     )
...     td.gather_and_stack(0)
...
>>> def server():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     # Creates the destination tensordict on server.
...     # The first dim must be equal to world_size-1
...     td = TensorDict(
...         {("a", "b"): torch.zeros(2),
...          "c": torch.zeros(2)}, [2]
...     ).expand(1, 2).contiguous()
...     td.gather_and_stack(0)
...     assert (td["a", "b"] != 0).all()
...     print("yuppie")
...
>>> if __name__ == "__main__":
...     mp.set_start_method("spawn")
...
...     main_worker = mp.Process(target=server)
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...
...     main_worker.join()
...     secondary_worker.join()
get(key: Union[str, Tuple[str, ...]], default: Any = <tensordict.base._NoDefault object>) Tensor

Gets the value stored with the input key.

Parameters:
  • key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.

  • default – default value if the key is not found in the tensordict.

Examples

>>> td = TensorDict({"x": 1}, batch_size=[])
>>> td.get("x")
tensor(1)
>>> td.get("y", default=None)
None
get_at(key: Union[str, Tuple[str, ...]], index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]], default: Tensor = <tensordict.base._NoDefault object>) Tensor

Gets the value of the tensordict stored under key at the given index.

Parameters:
  • key (str, tuple of str) – key to be retrieved.

  • index (int, slice, torch.Tensor, iterable) – index of the tensor.

  • default (torch.Tensor) – default value to return if the key is not present in the tensordict.

Returns:

indexed tensor.

Examples

>>> td = TensorDict({"x": torch.arange(3)}, batch_size=[])
>>> td.get_at("x", index=1)
tensor(1)
get_buffer(target: str) Tensor

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Parameters:

target – The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

The buffer referenced by target

Return type:

torch.Tensor

Raises:

AttributeError – If the target string references an invalid path or resolves to something that is not a buffer

get_extra_state() Any

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Any extra state to store in the module’s state_dict

Return type:

object

get_item_shape(key: Union[str, Tuple[str, ...]])

Returns the shape of the entry, possibly avoiding a call to get().
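
A short illustrative sketch (not from the original reference):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3, 4, 5)}, batch_size=[3])
>>> td.get_item_shape("a")
torch.Size([3, 4, 5])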

get_non_tensor(key: Union[str, Tuple[str, ...]], default=<tensordict.base._NoDefault object>)

Gets a non-tensor value, if it exists, or default if the non-tensor value is not found.

This method is robust to tensor/TensorDict values, meaning that if the value gathered is a regular tensor it will be returned too (although this method comes with some overhead and should not be used out of its natural scope).

See set_non_tensor() for more information on how to set non-tensor values in a tensordict.

Parameters:
  • key (NestedKey) – the location of the NonTensorData object.

  • default (Any, optional) – the value to be returned if the key cannot be found.

Returns:

the content of the tensordict.tensorclass.NonTensorData, or the entry corresponding to the key if it isn’t a tensordict.tensorclass.NonTensorData (or default if the entry cannot be found).

Examples

>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
NonTensorData(
    data='a string!',
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
get_parameter(target: str) Parameter

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Parameters:

target – The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

The Parameter referenced by target

Return type:

torch.nn.Parameter

Raises:

AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Parameter

get_submodule(target: str) Module

Return the submodule given by target if it exists, otherwise throw an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:

target – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

Returns:

The submodule referenced by target

Return type:

torch.nn.Module

Raises:

AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module

property grad

Returns a tensordict containing the .grad attributes of the leaf tensors.

half()

Casts all tensors to torch.half.

int()

Casts all tensors to torch.int.

ipu(device: Optional[Union[int, device]] = None) T

Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on the IPU while being optimized.

Note

This method modifies the module in-place.

Parameters:

device (int, optional) – if specified, all parameters will be copied to that device

Returns:

self

Return type:

Module

irecv(src: int, *, group: 'dist.ProcessGroup' | None = None, return_premature: bool = False, init_tag: int = 0, pseudo_rand: bool = False) tuple[int, list[torch.Future]] | list[torch.Future] | None

Receives the content of a tensordict and updates its own content with it asynchronously.

Check the example in the isend() method for context.

Parameters:

src (int) – the rank of the source worker.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • return_premature (bool) – if True, returns a list of futures to wait upon until the tensordict is updated. Defaults to False, i.e., the call waits until the update is completed.

  • init_tag (int) – the init_tag used by the source worker.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, making it possible to send multiple data from different nodes without overlap. Note that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. This value must match the one passed to isend(). Defaults to False.

Returns:

if return_premature=True, a list of futures to wait

upon until the tensordict is updated.

is_contiguous(*args, **kwargs)

Returns a boolean indicating if all the tensors are contiguous.

is_empty() bool

Checks if the tensordict contains any leaf.
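
For example (an illustrative sketch, not from the original reference):

>>> import torch
>>> from tensordict import TensorDict
>>> assert TensorDict({}, batch_size=[]).is_empty()
>>> assert not TensorDict({"a": torch.zeros(())}, batch_size=[]).is_empty()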

property is_memmap: bool

Checks if tensordict is memory-mapped.

If a TensorDict instance is memory-mapped, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all memory-mapped, this does __not__ mean that is_memmap will return True (as a new tensor may or may not be memory-mapped). Only if one calls tensordict.memmap_() will the tensordict be considered as memory-mapped.

property is_shared: bool

Checks if tensordict is in shared memory.

If a TensorDict instance is in shared memory, it is locked (entries cannot be renamed, removed or added). If a TensorDict is created with tensors that are all in shared memory, this does __not__ mean that is_shared will return True (as a new tensor may or may not be in shared memory). Only if one calls tensordict.share_memory_() or places the tensordict on a device where the content is shared by default (eg, "cuda") will the tensordict be considered in shared memory.

This is always True for tensordicts on a CUDA device.

isend(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int

Sends the content of the tensordict asynchronously.

Parameters:

dst (int) – the rank of the destination worker where the content should be sent.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, making it possible to send multiple data from different nodes without overlap. Note that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. Defaults to False.

Example

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import multiprocessing as mp
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.isend(0)
...
>>>
>>> def server(queue, return_premature=True):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     out = td.irecv(1, return_premature=return_premature)
...     if return_premature:
...         for fut in out:
...             fut.wait()
...     assert (td != 0).all()
...     queue.put("yuppie")
...
>>>
>>> if __name__ == "__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(
...         target=server,
...         args=(queue, )
...         )
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
items(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) Iterator[CompatibleType]

Returns a generator of key-value pairs for the tensordict.

Parameters:
  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.
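
A minimal illustrative sketch combining both flags (not from the original reference):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.ones(())}}, batch_size=[])
>>> keys = {key for key, _ in td.items(include_nested=True, leaves_only=True)}
>>> assert keys == {"a", ("b", "c")}  # nested leaves are reported as tuples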

keys(*args, **kwargs)

Returns a generator of tensordict keys.

Parameters:
  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.

Examples

>>> from tensordict import TensorDict
>>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[])
>>> list(data.keys())
['0', '1']
>>> list(data.keys(leaves_only=True))
['0']
>>> list(data.keys(include_nested=True, leaves_only=True))
['0', '1', ('1', '2')]
classmethod load(prefix: str | Path) T

Loads a tensordict from disk.

This class method is a proxy to load_memmap().

classmethod load_memmap(prefix: str | Path) T

Loads a memory-mapped tensordict from disk.

Parameters:

prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

This method also allows loading nested tensordicts.

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
load_state_dict(state_dict: OrderedDict[str, Any], strict=True, assign=False, from_flatten=False)

Loads a state-dict, formatted as in state_dict(), into the tensordict.

Parameters:
  • state_dict (OrderedDict) – the state dict to be copied.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this tensordict’s torch.nn.Module.state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the tensordict instead of copying them inplace into the tensordict’s current tensors. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

  • from_flatten (bool, optional) – if True, the input state_dict is assumed to be flattened. Defaults to False.

Examples

>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> sd = data.state_dict()
>>> data_zeroed.load_state_dict(sd)
>>> print(data_zeroed["3", "3"])
tensor(3)
>>> # with flattening
>>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, [])
>>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True)
>>> print(data_zeroed["3", "3"])
tensor(3)
lock_() T

Locks a tensordict against non in-place operations.

Functions such as set(), __setitem__(), update(), rename_key_() or other operations that add or remove entries will be blocked.

This method can be used as a decorator.

Example

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 1, "b": 2, "c": 3}, batch_size=[])
>>> with td.lock_():
...     assert td.is_locked
...     try:
...         td.set("d", 0) # error!
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         del td["d"]
...     except RuntimeError:
...         print("td is locked!")
...     try:
...         td.rename_key_("a", "d")
...     except RuntimeError:
...         print("td is locked!")
...     td.set("a", 0, inplace=True)  # No storage is added, moved or removed
...     td.set_("a", 0) # No storage is added, moved or removed
...     td.update({"a": 0}, inplace=True)  # No storage is added, moved or removed
...     td.update_({"a": 0})  # No storage is added, moved or removed
>>> assert not td.is_locked
map(fn: Callable, dim: int = 0, num_workers: int = None, chunksize: int = None, num_chunks: int = None, pool: mp.Pool = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, mp_start_method: str | None = None)

Maps a function to splits of the tensordict across one dimension.

This method will apply a function to a tensordict instance by chunking it into tensordicts of equal size and dispatching the operations over the desired number of workers.

The function signature should be Callable[[TensorDict], Union[TensorDict, Tensor]]. The output must support the torch.cat() operation. The function must be serializable.

Parameters:
  • fn (callable) – function to apply to the tensordict. Signatures similar to Callable[[TensorDict], Union[TensorDict, Tensor]] are supported.

  • dim (int, optional) – the dim along which the tensordict will be chunked.

  • num_workers (int, optional) – the number of workers. Exclusive with pool. If none is provided, the number of workers will be set to the number of CPUs available.

Keyword Arguments:
  • out (TensorDictBase, optional) – an optional container for the output. Its batch-size along the dim provided must match self.ndim. If it is shared or memmap (is_shared() or is_memmap() returns True) it will be populated within the remote processes, avoiding data inward transfers. Otherwise, the data from the self slice will be sent to the process, collected on the current process and written inplace into out.

  • chunksize (int, optional) – The size of each chunk of data. A chunksize of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, whereas chunksize>0 will split the tensordict and call torch.cat() on the resulting list of tensordicts. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with num_chunks.

  • num_chunks (int, optional) – the number of chunks to split the tensordict into. If none is provided, the number of chunks will equal the number of workers. For very large tensordicts, such large chunks may not fit in memory for the operation to be done and more chunks may be needed to make the operation practically doable. This argument is exclusive with chunksize.

  • pool (mp.Pool, optional) – a multiprocess Pool instance to use to execute the job. If none is provided, a pool will be created within the map method.

  • generator (torch.Generator, optional) –

    a generator to use for seeding. A base seed will be generated from it, and each worker of the pool will be seeded with the provided seed incremented by a unique integer from 0 to num_workers. If no generator is provided, a random integer will be used as seed. To work with unseeded workers, a pool should be created separately and passed to map() directly.

    Note

    Caution should be taken when providing a low-valued seed, as this can cause autocorrelation between experiments. For example, if 8 workers are asked and the seed is 4, the workers' seeds will range from 4 to 11; if the seed is 5, they will range from 5 to 12. These two experiments will have an overlap of 7 seeds, which can have unexpected effects on the results.

    Note

    The goal of seeding the workers is to have an independent seed on each worker, NOT to have reproducible results across calls of the map method. In other words, two experiments may and probably will return different results as it is impossible to know which worker will pick which job. However, we can make sure that each worker has a different seed and that the pseudo-random operations on each will be uncorrelated.

  • max_tasks_per_child (int, optional) – the maximum number of jobs picked by every child process. Defaults to None, i.e., no restriction on the number of jobs.

  • worker_threads (int, optional) – the number of threads for the workers. Defaults to 1.

  • index_with_generator (bool, optional) – if True, the splitting / chunking of the tensordict will be done during the query, sparing init time. Note that chunk() and split() are much more efficient than indexing (which is used within the generator) so a gain of processing time at init time may have a negative impact on the total runtime. Defaults to False.

  • pbar (bool, optional) – if True, a progress bar will be displayed. Requires tqdm to be available. Defaults to False.

  • mp_start_method (str, optional) – the start method for multiprocessing. If not provided, the default start method will be used. Accepted strings are "fork" and "spawn". Keep in mind that "cuda" tensors cannot be shared between processes with the "fork" start method. This is without effect if the pool is passed to the map method.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>>
>>> def process_data(data):
...     data.set("y", data.get("x") + 1)
...     return data
>>> if __name__ == "__main__":
...     data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_()
...     data = data.map(process_data, dim=1)
...     print(data["y"][:, :10])
...
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

Note

This method is particularly useful when working with large datasets stored on disk (e.g., memory-mapped tensordicts), where chunks will be zero-copied slices of the original data that can be passed to the processes at virtually zero cost. This allows very large datasets (e.g., over a terabyte) to be processed at little cost.

masked_fill(*args, **kwargs)

Out-of-place version of masked_fill_().

Parameters:
  • mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.

  • value – value used to fill the tensors.

Returns:

a new tensordict with the values filled where the mask is True.

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td1 = td.masked_fill(mask, 1.0)
>>> td1.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
masked_fill_(*args, **kwargs)

Fills the values corresponding to the mask with the desired value.

Parameters:
  • mask (boolean torch.Tensor) – mask of values to be filled. Shape must match the tensordict batch-size.

  • value – value used to fill the tensors.

Returns:

self

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...     batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td.masked_fill_(mask, 1.0)
>>> td.get("a")
tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
masked_select(mask: Tensor) T

Masks all tensors of the TensorDict and returns a new TensorDict instance with similar keys pointing to masked values.

Parameters:

mask (torch.Tensor) – boolean mask to be used for the tensors. Shape must match the TensorDict batch_size.

Examples

>>> td = TensorDict(source={'a': torch.zeros(3, 4)},
...    batch_size=[3])
>>> mask = torch.tensor([True, False, False])
>>> td_mask = td.masked_select(mask)
>>> td_mask.get("a")
tensor([[0., 0., 0., 0.]])
mean(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the mean value of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If an integer or tuple of integers, mean is called over the specified dimension(s) if and only if they are compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.
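
A hedged sketch of both call patterns, assuming the behavior described above (not part of the original reference):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.arange(4.0).reshape(2, 2)}, batch_size=[2])
>>> assert td.mean()["a"] == 1.5  # full reduction: dimensionless result
>>> assert td.mean(dim=0)["a"].shape == torch.Size([2])  # reduction over the first batch dim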

memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operation (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
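
As an illustration, here is a minimal sketch of writing a tensordict to disk; the temporary directory below is only an example location:

>>> import tempfile
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
>>> with tempfile.TemporaryDirectory() as prefix:
...     td_mm = td.memmap(prefix)  # new, locked tensordict backed by files under prefix
...     assert td_mm.is_memmap()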

memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) TensorDictBase

Writes all tensors onto a corresponding memory-mapped Tensor, in-place.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operation (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

self if return_early=False, otherwise a TensorDictFuture instance.

Note

Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.

memmap_like(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Creates a contentless Memory-mapped tensordict with the same shapes as the original one.

Parameters:
  • prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.

  • copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.

Keyword Arguments:
  • num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.

  • return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.

  • share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes and writing operation (such as inplace update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to False.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.

Note

This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.

Examples

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
modules() Iterator[Module]

Return an iterator over all modules in the network.

Yields:

Module – a module in the network

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
named_apply(fn: Callable, *others: TensorDictBase, batch_size: Sequence[int] | None = None, device: torch.device | None = <tensordict.base._NoDefault object>, names: Sequence[str] | None = None, inplace: bool = False, default: Any = <tensordict.base._NoDefault object>, filter_empty: bool | None = None, **constructor_kwargs) TensorDictBase | None

Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new tensordict.

The callable signature must be Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]].

Parameters:
  • fn (Callable) – function to be applied to the (name, tensor) pairs in the tensordict. For each leaf, only its leaf name will be used (not the full NestedKey).

  • *others (TensorDictBase instances, optional) – if provided, these tensordict instances should have a structure matching the one of self. The fn argument should receive as many unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the default keyword argument.

  • nested_keys (bool, optional) – if True, the complete path to the leaf will be used. Defaults to False, i.e. only the last string is passed to the function.

  • batch_size (sequence of int, optional) – if provided, the resulting TensorDict will have the desired batch_size. The batch_size argument should match the batch_size after the transformation. This is a keyword only argument.

  • device (torch.device, optional) – the resulting device, if any.

  • names (list of str, optional) – the new dimension names, in case the batch_size is modified.

  • inplace (bool, optional) – if True, changes are made in-place. Default is False. This is a keyword only argument.

  • default (Any, optional) – default value for missing entries in the other tensordicts. If not provided, missing entries will raise a KeyError.

  • filter_empty (bool, optional) – if True, empty tensordicts will be filtered out. This also comes with a lower computational cost as empty data structures won’t be created and destroyed. Defaults to False for backward compatibility.

  • propagate_lock (bool, optional) – if True, a locked tensordict will produce another locked tensordict. Defaults to False.

  • **constructor_kwargs – additional keyword arguments to be passed to the TensorDict constructor.

Returns:

a new tensordict with the transformed tensors.

Example

>>> td = TensorDict({
...     "a": -torch.ones(3),
...     "nested": {"a": torch.ones(3), "b": torch.zeros(3)}},
...     batch_size=[3])
>>> def name_filter(name, tensor):
...     if name == "a":
...         return tensor
>>> td.named_apply(name_filter)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> def name_filter(name, *tensors):
...     if name == "a":
...         r = 0
...         for tensor in tensors:
...             r = r + tensor
...         return r
>>> out = td.named_apply(name_filter, td)
>>> print(out)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> print(out["a"])
tensor([-2., -2., -2.])

Note

If None is returned by the function, the entry is ignored. This can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def name_filter(name, tensor):
...     if name == "1":
...         return tensor
>>> td.named_apply(name_filter)
TensorDict(
    fields={
        1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_children() Iterator[Tuple[str, Module]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

(str, Module) – Tuple containing a name and child module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
named_modules(memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:
  • memo – a memo to store the set of modules already added to the result

  • prefix – a prefix that will be added to the name of the module

  • remove_duplicate – whether to remove the duplicated module instances in the result or not

Yields:

(str, Module) – Tuple of name and module

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
property names

The dimension names of the tensordict.

The names can be set at construction time using the names argument.

See also refine_names() for details on how to set the names after construction.

nanmean(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the mean of all non-NaN elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the mean value of all leaves (if this can be computed). If integer or tuple of integers, mean is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

nansum(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the sum of all non-NaN elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

property ndim: int

See batch_dims().

ndimension() int

See batch_dims().

numel() int

Total number of elements in the batch.

Lower-bounded to 1, as a stack of two tensordicts with empty shape has two elements; we therefore consider that a tensordict is at least 1-element big.
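
To illustrate the lower bound (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({}, batch_size=[])
>>> td.numel()  # empty batch shape, yet at least one element
1
>>> torch.stack([td, td]).numel()  # a stack of two such tensordicts
2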

parameters(recurse: bool = True) Iterator[Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter – module parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
permute(*args, **kwargs)

Returns a view of a tensordict with the batch dimensions permuted according to dims.

Parameters:
  • *dims_list (int) – the new ordering of the batch dims of the tensordict. Alternatively, a single iterable of integers can be provided.

  • dims (list of int) – alternative way of calling permute(…).

Returns:

a new tensordict with the batch dimensions in the desired order.

Examples

>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> print(tensordict.permute([1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(1, 0))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
>>> print(tensordict.permute(dims=[1, 0]))
PermutedTensorDict(
    source=TensorDict(
        fields={
            a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32)},
        batch_size=torch.Size([3, 4]),
        device=cpu,
        is_shared=False),
    op=permute(dims=[1, 0]))
pin_memory(*args, **kwargs)

Calls pin_memory() on the stored tensors.

pop(key: ~typing.Union[str, ~typing.Tuple[str, ...]], default: ~typing.Any = <tensordict.base._NoDefault object>) Tensor

Removes and returns a value from a tensordict.

If the value is not present and no default value is provided, a KeyError is thrown.

Parameters:
  • key (str or nested key) – the entry to look for.

  • default (Any, optional) – the value to return if the key cannot be found.

Examples

>>> td = TensorDict({"1": 1}, [])
>>> one = td.pop("1")
>>> assert one == 1
>>> none = td.pop("1", default=None)
>>> assert none is None
popitem()

Removes the item that was last inserted into the TensorDict.

popitem will only return non-nested values.
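
A minimal sketch, assuming popitem() mirrors dict.popitem() and returns the last-inserted (key, value) pair:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(())}, batch_size=[])
>>> td.set("b", torch.ones(()))
>>> key, value = td.popitem()  # assumed to return the most recent entry
>>> assert key == "b"
>>> assert "b" not in td.keys()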

prod(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the product of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the prod value of all leaves (if this can be computed). If integer or tuple of integers, prod is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

recv(src: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) int

Receives the content of a tensordict and updates the current tensordict with it.

Check the example in the send method for context.

Parameters:

src (int) – the rank of the source worker.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the init_tag used by the source worker.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Note that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. This value must match the one passed to send(). Defaults to False.

reduce(dst, op=None, async_op=False, return_premature=False, group=None)

Reduces the tensordict across all machines.

Only the process with rank dst is going to receive the final result.

refine_names(*names)

Refines the dimension names of self according to names.

Refining is a special case of renaming that “lifts” unnamed dimensions. A None dim can be refined to have any name; a named dim can only be refined to have the same name.

Because named tensors can coexist with unnamed tensors, refining names gives a nice way to write named-tensor-aware code that works with both named and unnamed tensors.

names may contain up to one Ellipsis (…). The Ellipsis is expanded greedily; it is expanded in-place to fill names to the same length as self.dim() using names from the corresponding indices of self.names.

Returns: the same tensordict with dimensions named according to the input.

Examples

>>> td = TensorDict({}, batch_size=[3, 4, 5, 6])
>>> tdr = td.refine_names(None, None, None, "d")
>>> assert tdr.names == [None, None, None, "d"]
>>> tdr = td.refine_names("a", None, None, "d")
>>> assert tdr.names == ["a", None, None, "d"]
register_backward_hook(hook: Callable[[Module, Union[Tuple[Tensor, ...], Tensor], Union[Tuple[Tensor, ...], Tensor]], Union[None, Tuple[Tensor, ...], Tensor]]) RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

Buffers can be accessed as attributes using given names.

Parameters:
  • name (str) – name of the buffer. The buffer can be accessed from this module using the given name

  • tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored, and the buffer is not included in the module’s state_dict.

  • persistent (bool) – whether the buffer is part of this module’s state_dict.

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
register_forward_hook(hook: Union[Callable[[T, Tuple[Any, ...], Any], Optional[Any]], Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle

Register a forward hook on the module.

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have an effect on forward since this is called after forward() is called. The hook should have the following signature:

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:

hook(module, args, kwargs, output) -> None or modified output
Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided hook will be fired before all existing forward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.modules.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False

  • with_kwargs (bool) – If True, the hook will be passed the kwargs given to the forward function. Default: False

  • always_call (bool) – If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_forward_pre_hook(hook: Union[Callable[[T, Tuple[Any, ...]], Optional[Any]], Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before forward() is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the input. The user can either return a tuple or a single modified value from the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:

hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If true, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.modules.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False

  • with_kwargs (bool) – If true, the hook will be passed the kwargs given to the forward function. Default: False

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_full_backward_hook(hook: Callable[[Module, Union[Tuple[Tensor, ...], Tensor], Union[Tuple[Tensor, ...], Tensor]], Union[None, Tuple[Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:
  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.modules.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_full_backward_pre_hook(hook: Callable[[Module, Union[Tuple[Tensor, ...], Tensor]], Union[None, Tuple[Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature:

hook(module, grad_output) -> tuple[Tensor] or None

The grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of grad_output in subsequent computations. Entries in grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:
  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.modules.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_get_post_hook(hook)

Register a hook to be called after any get operation on leaf tensors.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module’s load_state_dict is called.

It should have the following signature::

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_module(name: str, module: Optional[Module]) None

Alias for add_module().

register_parameter(name: str, param: Optional[Parameter]) None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters:
  • name (str) – name of the parameter. The parameter can be accessed from this module using the given name

  • param (Parameter or None) – parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored, and the parameter is not included in the module’s state_dict.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.
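
A minimal sketch following the signature described above; the params object stands for any module (e.g., a TensorDictParams instance) and is assumed to be defined elsewhere:

>>> # xdoctest: +SKIP("undefined vars")
>>> def log_state_dict(module, prefix, keep_vars):
...     print(f"state_dict requested (prefix={prefix!r}, keep_vars={keep_vars})")
>>> params.register_state_dict_pre_hook(log_state_dict)
>>> sd = params.state_dict()  # the hook fires before the state-dict is built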

rename(*names, **rename_map)

Returns a clone of the tensordict with dimensions renamed.

Examples

>>> td = TensorDict({}, batch_size=[1, 2, 3 ,4])
>>> td.names = list("abcd")
>>> td_rename = td.rename(c="g")
>>> assert td_rename.names == list("abgd")
rename_(*names, **rename_map)

Same as rename(), but executes the renaming in-place.

Examples

>>> td = TensorDict({}, batch_size=[1, 2, 3 ,4])
>>> td.names = list("abcd")
>>> assert td.rename_(c="g")
>>> assert td.names == list("abgd")
rename_key_(old_key: Union[str, Tuple[str, ...]], new_key: Union[str, Tuple[str, ...]], safe: bool = False) TensorDictBase

Renames a key with a new string and returns the same tensordict with the updated key name.

Parameters:
  • old_key (str or nested key) – key to be renamed.

  • new_key (str or nested key) – new name of the entry.

  • safe (bool, optional) – if True, an error is thrown when the new key is already present in the TensorDict.

Returns:

self
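
For instance (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> td.rename_key_("a", "b")  # returns self
>>> assert "b" in td.keys() and "a" not in td.keys()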

requires_grad_(requires_grad: bool = True) T

Change if autograd should record operations on parameters in this module.

This method sets the parameters’ requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See Locally disabling gradient computation for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters:

requires_grad (bool) – whether autograd should record operations on parameters in this module. Default: True.

Returns:

self

Return type:

Module

reshape(*shape: int)

Returns a contiguous, reshaped tensordict of the desired shape.

Parameters:

*shape (int) – new shape of the resulting tensordict.

Returns:

A TensorDict with reshaped keys

Examples

>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td = td.reshape(12)
>>> print(td['x'])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

Saves the tensordict to disk.

This function is a proxy to memmap().

select(*keys: Union[str, Tuple[str, ...]], inplace: bool = False, strict: bool = True) T

Selects the keys of the tensordict and returns a new tensordict with only the selected keys.

The values are not copied: in-place modifications of a tensor in either the original or the new tensordict will be reflected in both.

Parameters:
  • *keys (str) – keys to select

  • inplace (bool) – if True, the tensordict is pruned in place. Default is False.

  • strict (bool, optional) – whether selecting a key that is not present will return an error or not. Default: True.

Returns:

A new tensordict (or the same if inplace=True) with the selected keys only.

Examples

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": 0, "b": {"c": 1, "d": 2}}, [])
>>> td.select("a", ("b", "c"))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("a", "b")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> td.select("this key does not exist", strict=False)
TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
send(dst: int, *, group: 'dist.ProcessGroup' | None = None, init_tag: int = 0, pseudo_rand: bool = False) None

Sends the content of a tensordict to a distant worker.

Parameters:

dst (int) – the rank of the destination worker where the content should be sent.

Keyword Arguments:
  • group (torch.distributed.ProcessGroup, optional) – if set, the specified process group will be used for communication. Otherwise, the default process group will be used. Defaults to None.

  • init_tag (int) – the initial tag to be used to mark the tensors. Note that this will be incremented by as much as the number of tensors contained in the TensorDict.

  • pseudo_rand (bool) – if True, the sequence of tags will be pseudo-random, allowing multiple data to be sent from different nodes without overlap. Note that the generation of these pseudo-random numbers is expensive (1e-5 sec/number), meaning that it could slow down the runtime of your algorithm. Defaults to False.

Example

>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>> import torch
>>>
>>>
>>> def client():
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=1,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...
...     td = TensorDict(
...         {
...             ("a", "b"): torch.randn(2),
...             "c": torch.randn(2, 3),
...             "_": torch.ones(2, 1, 5),
...         },
...         [2],
...     )
...     td.send(0)
...
>>>
>>> def server(queue):
...     torch.distributed.init_process_group(
...         "gloo",
...         rank=0,
...         world_size=2,
...         init_method=f"tcp://localhost:10003",
...     )
...     td = TensorDict(
...         {
...             ("a", "b"): torch.zeros(2),
...             "c": torch.zeros(2, 3),
...             "_": torch.zeros(2, 1, 5),
...         },
...         [2],
...     )
...     td.recv(1)
...     assert (td != 0).all()
...     queue.put("yuppie")
...
>>>
>>> if __name__=="__main__":
...     queue = mp.Queue(1)
...     main_worker = mp.Process(target=server, args=(queue,))
...     secondary_worker = mp.Process(target=client)
...
...     main_worker.start()
...     secondary_worker.start()
...     out = queue.get(timeout=10)
...     assert out == "yuppie"
...     main_worker.join()
...     secondary_worker.join()
set(key: Union[str, Tuple[str, ...]], item: Tensor, inplace: bool = False, **kwargs: Any) TensorDictBase

Sets a new key-value pair.

Parameters:
  • key (str, tuple of str) – name of the key to be set.

  • item (torch.Tensor or equivalent, TensorDictBase instance) – value to be stored in the tensordict.

  • inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If inplace is True and the entry cannot be found, it will be added. For a more restrictive in-place operation, use set_() instead. Defaults to False.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> td.set("x", torch.randn(3, 4))
>>> y = torch.randn(3, 4, 5)
>>> td.set("y", y, inplace=True) # works, even if 'y' is not present yet
>>> td.set("y", torch.zeros_like(y), inplace=True)
>>> assert (y==0).all() # y values are overwritten
>>> td.set("y", torch.ones(5), inplace=True) # raises an exception as shapes mismatch
set_(key: Union[str, Tuple[str, ...]], item: Tensor) T

Sets a value to an existing key while keeping the original storage.

Parameters:
  • key (str, tuple of str) – name of the key to be set.

  • item (torch.Tensor or equivalent, TensorDictBase instance) – value to be stored in the tensordict.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_("x", torch.zeros_like(x))
>>> assert (x == 0).all()
set_at_(key: Union[str, Tuple[str, ...]], value: Tensor, index: Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]) T

Sets the values in-place at the index indicated by index.

Parameters:
  • key (str, tuple of str) – key to be modified.

  • value (torch.Tensor) – value to be set at the index index

  • index (int, tensor or tuple) – index where to write the values.

Keyword Arguments:

non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> x = torch.randn(3, 4)
>>> td.set("x", x)
>>> td.set_at_("x", value=torch.ones(1, 4), index=slice(1))
>>> assert (x[0] == 1).all()
set_extra_state(state: Any) None

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Parameters:

state (dict) – Extra state from the state_dict
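
A minimal sketch of the set_extra_state() / get_extra_state() pair on a custom module; the version attribute is purely illustrative:

>>> from torch import nn
>>> class MyModule(nn.Module):
...     def get_extra_state(self):
...         return {"version": 1}  # stored in the state_dict
...     def set_extra_state(self, state):
...         self.version = state["version"]  # restored by load_state_dict
>>> m, m2 = MyModule(), MyModule()
>>> _ = m2.load_state_dict(m.state_dict())
>>> assert m2.version == 1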

set_non_tensor(key: Union[str, Tuple[str, ...]], value: Any)

Registers a non-tensor value in the tensordict using tensordict.tensorclass.NonTensorData.

The value can be retrieved using TensorDictBase.get_non_tensor() or directly using get, which will return the tensordict.tensorclass.NonTensorData object.

Returns:

self

Examples

>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("nested", "the string"), "a string!")
>>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
>>> # regular `get` works but returns a NonTensorData object
>>> data.get(("nested", "the string"))
NonTensorData(
    data='a string!',
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
setdefault(key: Union[str, Tuple[str, ...]], default: Tensor, inplace: bool = False) Tensor

Insert the key entry with a value of default if key is not in the tensordict.

Return the value for key if key is in the tensordict, else default.

Parameters:
  • key (str or nested key) – the name of the value.

  • default (torch.Tensor or compatible type, TensorDictBase) – value to be stored in the tensordict if the key is not already present.

Returns:

The value of key in the tensordict. Will be default if the key was not previously set.

Examples

>>> td = TensorDict({}, batch_size=[3, 4])
>>> val = td.setdefault("a", torch.zeros(3, 4))
>>> assert (val == 0).all()
>>> val = td.setdefault("a", torch.ones(3, 4))
>>> assert (val == 0).all() # output is still 0
property shape: Size

See batch_size.

share_memory() T

See torch.Tensor.share_memory_().

share_memory_(*args, **kwargs)

Places all the tensors in shared memory.

The TensorDict is then locked, meaning that any writing operation that isn't in-place will throw an exception (e.g., renaming, setting or removing an entry). Conversely, once the tensordict is unlocked, the share_memory attribute is turned to False, because cross-process identity is not guaranteed anymore.

Returns:

self
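
For instance (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> td.share_memory_()  # returns self; the tensordict is now locked
>>> assert td["a"].is_shared()  # the leaf tensors live in shared memory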

size(dim: int | None = None) torch.Size | int

Returns the size of the dimension indicated by dim.

If dim is not specified, returns the batch_size attribute of the TensorDict.
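
For instance (a minimal sketch):

>>> from tensordict import TensorDict
>>> td = TensorDict({}, batch_size=[3, 4])
>>> td.size()
torch.Size([3, 4])
>>> td.size(1)
4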

property sorted_keys: list[NestedKey]

Returns the keys sorted in alphabetical order.

Does not support extra arguments.

If the TensorDict is locked, the keys are cached until the tensordict is unlocked for faster execution.
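
For instance (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"b": torch.zeros(()), "a": torch.zeros(())}, batch_size=[])
>>> assert td.sorted_keys == ["a", "b"]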

split(split_size: int | list[int], dim: int = 0) list[TensorDictBase]

Splits each tensor in the TensorDict with the specified size in the given dimension, like torch.split.

Returns a list of TensorDict instances with views over the split chunks of items.

Parameters:
  • split_size (int or List(int)) – size of a single chunk or list of sizes for each chunk.

  • dim (int) – dimension along which to split the tensor.

Returns:

A list of TensorDict with specified size in given dimension.

Examples

>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1 = td.split([1, 2], dim=0)
>>> print(td0['x'])
tensor([[0, 1, 2, 3]])
squeeze(*args, **kwargs)

Squeezes all tensors along a dimension between -self.batch_dims+1 and self.batch_dims-1 and returns them in a new tensordict.

Parameters:

dim (Optional[int]) – dimension along which to squeeze. If dim is None, all singleton dimensions will be squeezed. Defaults to None.

Examples

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> td = td.squeeze()
>>> td.shape
torch.Size([3, 4])
>>> td.get("x").shape
torch.Size([3, 4, 2])

This operation can also be used as a context manager. Changes to the original tensordict will occur out-of-place, i.e. the content of the original tensors will not be altered. This also assumes that the tensordict is not locked (otherwise, unlocking the tensordict is necessary). This functionality is not compatible with implicit squeezing.

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 1, 4, 2),
... }, batch_size=[3, 1, 4])
>>> with td.squeeze(1) as tds:
...     tds.set("y", torch.zeros(3, 4))
>>> assert td.get("y").shape == [3, 1, 4]
state_dict(*args, destination=None, prefix='', keep_vars=False, flatten=False)

Produces a state_dict from the tensordict.

The structure of the state-dict will still be nested, unless flatten is set to True.

A tensordict state-dict contains all the tensors and meta-data needed to rebuild the tensordict (names are currently not supported).

Parameters:
  • destination (dict, optional) – If provided, the state of tensordict will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to tensor names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the torch.Tensor items returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

  • flatten (bool, optional) – whether the structure should be flattened with the "." character or not. Defaults to False.

Examples

>>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
>>> sd = data.state_dict()
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)])
>>> sd = data.state_dict(flatten=True)
>>> print(sd)
OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])
std(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, correction: int = 1) bool | TensorDictBase

Returns the standard deviation of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the standard deviation of all leaves (if this can be computed). If integer or tuple of integers, std is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.

sum(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, dtype: torch.dtype | None = None) bool | TensorDictBase

Returns the sum value of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the sum value of all leaves (if this can be computed). If integer or tuple of integers, sum is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

dtype (torch.dtype, optional) – If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

to(*args, **kwargs) TensorDictBase

Maps a TensorDictBase subclass either on another device, dtype or to another TensorDictBase subclass (if permitted).

Since tensordicts are not bound to contain a single tensor dtype, a dtype passed to to() is applied to all the leaf tensors.

Parameters:
  • device (torch.device, optional) – the desired device of the tensordict.

  • dtype (torch.dtype, optional) – the desired floating point or complex dtype of the tensordict.

  • tensor (torch.Tensor, optional) – Tensor whose dtype and device are the desired dtype and device for all tensors in this TensorDict.

Keyword Arguments:
  • non_blocking (bool, optional) – whether the operations should be blocking.

  • memory_format (torch.memory_format, optional) – the desired memory format for 4D parameters and buffers in this tensordict.

  • batch_size (torch.Size, optional) – resulting batch-size of the output tensordict.

  • other (TensorDictBase, optional) –

    TensorDict instance whose dtype and device are the desired dtype and device for all tensors in this TensorDict. Note: since TensorDictBase instances do not have a dtype, the dtype is gathered from the example leaves. If there is more than one dtype, no dtype casting is undertaken.

Returns:

a new tensordict instance if the device differs from the tensordict device and/or if the dtype is passed. The same tensordict otherwise. Modifications of the batch_size alone are done in-place.

Examples

>>> data = TensorDict({"a": 1.0}, [], device=None)
>>> data_cuda = data.to("cuda:0")  # casts to cuda
>>> data_int = data.to(torch.int)  # casts to int
>>> data_cuda_int = data.to("cuda:0", torch.int)  # multiple casting
>>> data_cuda = data.to(torch.randn(3, device="cuda:0"))  # using an example tensor
>>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0"))  # using a tensordict example
to_dict() dict[str, Any]

Returns a dictionary with key-value pairs matching those of the tensordict.
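
For instance, nested tensordicts become nested dictionaries (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3), "b": {"c": torch.ones(3)}}, batch_size=[3])
>>> d = td.to_dict()
>>> assert isinstance(d["b"], dict)  # nested tensordicts turn into plain dicts
>>> assert (d["b"]["c"] == 1).all()  # leaf tensors are kept as-is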

to_empty(*, device: Optional[Union[int, str, device]], recurse: bool = True) T

Move the parameters and buffers to the specified device without copying storage.

Parameters:
  • device (torch.device) – The desired device of the parameters and buffers in this module.

  • recurse (bool) – Whether parameters and buffers of submodules should be recursively moved to the specified device.

Returns:

self

Return type:

Module

to_h5(filename, **kwargs)

Converts a tensordict to a PersistentTensorDict with the h5 backend.

Parameters:
  • filename (str or path) – path to the h5 file.

  • device (torch.device or compatible, optional) – the device on which to expect the tensors once they are returned. Defaults to None (on cpu by default).

  • **kwargs – kwargs to be passed to h5py.File.create_dataset().

Returns:

A PersistentTensorDict instance linked to the newly created file.

Examples

>>> import tempfile
>>> import timeit
>>>
>>> from tensordict import TensorDict, MemoryMappedTensor
>>> td = TensorDict({
...     "a": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000)),
...     "b": {"c": MemoryMappedTensor.from_tensor(torch.zeros(()).expand(1_000_000, 3))},
... }, [1_000_000])
>>>
>>> file = tempfile.NamedTemporaryFile()
>>> td_h5 = td.to_h5(file.name, compression="gzip", compression_opts=9)
>>> print(td_h5)
PersistentTensorDict(
    fields={
        a: Tensor(shape=torch.Size([1000000]), device=cpu, dtype=torch.float32, is_shared=False),
        b: PersistentTensorDict(
            fields={
                c: Tensor(shape=torch.Size([1000000, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([1000000]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1000000]),
    device=None,
    is_shared=False)
to_module(module: nn.Module, *, inplace: bool | None = None, return_swap: bool = True, swap_dest=None, use_state_dict: bool = False, non_blocking: bool = False, memo=None)

Writes the content of a TensorDictBase instance onto a given nn.Module's attributes, recursively.

Parameters:

module (nn.Module) – a module to write the parameters into.

Keyword Arguments:
  • inplace (bool, optional) – if True, the parameters or tensors in the module are updated in-place. Defaults to False.

  • return_swap (bool, optional) – if True, the old parameter configuration will be returned. Defaults to True.

  • swap_dest (TensorDictBase, optional) – if return_swap is True, the tensordict where the swap should be written.

  • use_state_dict (bool, optional) – if True, state-dict API will be used to load the parameters (including the state-dict hooks). Defaults to False.

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Examples

>>> from torch import nn
>>> module = nn.TransformerDecoder(
...     decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
...     num_layers=1)
>>> params = TensorDict.from_module(module)
>>> params.zero_()
>>> params.to_module(module)
>>> assert (module.layers[0].linear1.weight == 0).all()
to_padded_tensor(padding=0.0, mask_key: NestedKey | None = None)

Converts all nested tensors to a padded version and adapts the batch-size accordingly.

Parameters:
  • padding (float) – the padding value for the tensors in the tensordict. Defaults to 0.0.

  • mask_key (NestedKey, optional) – if provided, the key where a mask for valid values will be written. Will result in an error if the heterogeneous dimension isn’t part of the tensordict batch-size. Defaults to None.

to_tensordict()

Returns a regular TensorDict instance from the TensorDictBase.

Returns:

a new TensorDict object containing the same values.

train(mode: bool = True) T

Set the module in training mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Parameters:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

self

Return type:

Module

transpose(dim0, dim1)

Returns a tensordict that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

In-place or out-of-place modifications of the transposed tensordict will also affect the original tensordict, as the memory is shared and the operations are mapped back onto the original tensordict.

Examples

>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> tensordict_transpose = tensordict.transpose(0, 1)
>>> print(tensordict_transpose.shape)
torch.Size([4, 3])
>>> tensordict_transpose.set("b",, torch.randn(4, 3))
>>> print(tensordict.get("b").shape)
torch.Size([3, 4])
type(dst_type)

Casts all tensors to dst_type.

Parameters:

dst_type (type or string) – the desired type
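
For instance (a minimal sketch):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> td_double = td.type(torch.double)
>>> assert td_double["a"].dtype is torch.float64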

unbind(dim: int) tuple[T, ...]

Returns a tuple of indexed tensordicts, unbound along the indicated dimension.

Examples

>>> td = TensorDict({
...     'x': torch.arange(12).reshape(3, 4),
... }, batch_size=[3, 4])
>>> td0, td1, td2 = td.unbind(0)
>>> td0['x']
tensor([0, 1, 2, 3])
>>> td1['x']
tensor([4, 5, 6, 7])
unflatten(dim, unflattened_size)

Unflattens a tensordict dimension, expanding it to a desired shape.

Parameters:
  • dim (int) – specifies the dimension of the input tensor to be unflattened.

  • unflattened_size (shape) – is the new shape of the unflattened dimension of the tensordict.

Examples

>>> td = TensorDict({
...     "a": torch.arange(60).view(3, 4, 5),
...     "b": torch.arange(12).view(3, 4)},
...     batch_size=[3, 4])
>>> td_flat = td.flatten(0, 1)
>>> td_unflat = td_flat.unflatten(0, [3, 4])
>>> assert (td == td_unflat).all()
unflatten_keys(separator: str = '.', inplace: bool = False) TensorDictBase

Converts a flat tensordict into a nested one, recursively.

The TensorDict type will be lost and the result will be a simple TensorDict instance. The metadata of the nested tensordicts will be inferred from the root: all instances across the data tree will share the same batch-size, dimension names and device.

Parameters:
  • separator (str, optional) – the separator between the nested items.

  • inplace (bool, optional) – if True, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to False.

Examples

>>> data = TensorDict({"a": 1, "b - c": 2, "e - f - g": 3}, batch_size=[])
>>> data.unflatten_keys(separator=" - ")
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        e: TensorDict(
            fields={
                f: TensorDict(
                    fields={
                        g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This method and flatten_keys() are particularly useful when handling state-dicts, as they make it possible to seamlessly convert flat dictionaries into data structures that mimic the structure of the model.

Examples

>>> model = torch.nn.Sequential(torch.nn.Linear(3, 4))
>>> wrapped_model = torch.ao.quantization.QuantWrapper(model)
>>> state_dict = TensorDict(wrapped_model.state_dict(), batch_size=[]).unflatten_keys(".")
>>> print(state_dict)
TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model_state_dict = state_dict.get("module")
>>> print(model_state_dict)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> model.load_state_dict(dict(model_state_dict.flatten_keys(".")))
unlock_() T

Unlocks a tensordict for non in-place operations.

Can be used as a decorator.

See lock_() for more details.

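A minimal sketch of the context-manager form (assuming the tensordict is re-locked when the context exits, as with the other context-manager operations on this page):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(3)}, batch_size=[3]).lock_()
>>> with td.unlock_():
...     td["b"] = torch.ones(3)
>>> assert td.is_locked
>>> assert (td["b"] == 1).all()
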
unsqueeze(*args, **kwargs)

Unsqueezes all tensors along a dimension comprised between -td.batch_dims and td.batch_dims and returns them in a new tensordict.

Parameters:

dim (int) – dimension along which to unsqueeze

Examples

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> td = td.unsqueeze(-2)
>>> td.shape
torch.Size([3, 1, 4])
>>> td.get("x").shape
torch.Size([3, 1, 4, 2])

This operation can also be used as a context manager. Changes to the original tensordict will occur out-of-place, i.e. the content of the original tensors will not be altered. This also assumes that the tensordict is not locked (otherwise, unlocking the tensordict is necessary).

>>> td = TensorDict({
...     'x': torch.arange(24).reshape(3, 4, 2),
... }, batch_size=[3, 4])
>>> with td.unsqueeze(-2) as tds:
...     tds.set("y", torch.zeros(3, 1, 4))
>>> assert td.get("y").shape == torch.Size([3, 4])
update(input_dict_or_td: dict[str, CompatibleType] | TensorDictBase, clone: bool = False, inplace: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) TensorDictBase

Updates the TensorDict with values from either a dictionary or another TensorDict.

Parameters:
  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

  • inplace (bool, optional) – if True and if a key matches an existing key in the tensordict, then the update will occur in-place for that key-value pair. If the entry cannot be found, it will be added. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the list of keys in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update(data_src.select(*keys_to_update)).

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({}, batch_size=[3])
>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> other_td = TensorDict({"a": a, "b": b}, batch_size=[])
>>> td.update(other_td, inplace=True) # writes "a" and "b" even though they can't be found
>>> assert td['a'] is other_td['a']
>>> other_td = other_td.clone().zero_()
>>> td.update(other_td)
>>> assert td['a'] is not other_td['a']
update_(input_dict_or_td: dict[str, CompatibleType] | T, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict in-place with values from either a dictionary or another TensorDict.

Unlike update(), this function will throw an error if the key is unknown to self.

Parameters:
  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the list of keys in keys_to_update will be updated. This is aimed at avoiding calls to data_dest.update_(data_src.select(*keys_to_update)).

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> a = torch.randn(3)
>>> b = torch.randn(3, 4)
>>> td = TensorDict({"a": a, "b": b}, batch_size=[3])
>>> other_td = TensorDict({"a": a*0, "b": b*0}, batch_size=[])
>>> td.update_(other_td)
>>> assert td['a'] is not other_td['a']
>>> assert (td['a'] == other_td['a']).all()
>>> assert (td['a'] == 0).all()
update_at_(input_dict_or_td: dict[str, CompatibleType] | T, idx: IndexType, clone: bool = False, *, non_blocking: bool = False, keys_to_update: Sequence[NestedKey] | None = None) T

Updates the TensorDict in-place at the specified index with values from either a dictionary or another TensorDict.

Unlike update(), this function will throw an error if the key is unknown to the TensorDict.

Parameters:
  • input_dict_or_td (TensorDictBase or dict) – input data to be written in self.

  • idx (int, torch.Tensor, iterable, slice) – index of the tensordict where the update should occur.

  • clone (bool, optional) – whether the tensors in the input (tensor) dict should be cloned before being set. Defaults to False.

Keyword Arguments:
  • keys_to_update (sequence of NestedKeys, optional) – if provided, only the list of keys in keys_to_update will be updated.

  • non_blocking (bool, optional) – if True and this copy is between different devices, the copy may occur asynchronously with respect to the host.

Returns:

self

Examples

>>> td = TensorDict({
...     'a': torch.zeros(3, 4, 5),
...     'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4])
>>> td.update_at_(
...     TensorDict({
...         'a': torch.ones(1, 4, 5),
...         'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]),
...    slice(1, 2))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 4, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)
>>> assert (td[1] == 1).all()
values(include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None) Iterator[CompatibleType]

Returns a generator representing the values for the tensordict.

Parameters:
  • include_nested (bool, optional) – if True, nested values will be returned. Defaults to False.

  • leaves_only (bool, optional) – if True, only leaves will be returned. Defaults to False.

  • is_leaf – an optional callable that indicates if a class is to be considered a leaf or not.

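A minimal sketch of the three flavours:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.ones(())}}, batch_size=[])
>>> len(list(td.values()))  # "a" plus the nested tensordict stored at "b"
2
>>> len(list(td.values(include_nested=True)))  # also yields the leaf ("b", "c")
3
>>> len(list(td.values(include_nested=True, leaves_only=True)))  # tensors only
2
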
var(dim: int | Tuple[int] = <tensordict.base._NoDefault object>, keepdim: bool = <tensordict.base._NoDefault object>, *, correction: int = 1) bool | TensorDictBase

Returns the variance value of all elements in the input tensordict.

Parameters:
  • dim (int, tuple of int, optional) – if None, returns a dimensionless tensordict containing the variance of all leaves (if this can be computed). If integer or tuple of integers, var is called upon the dimension specified if and only if this dimension is compatible with the tensordict shape.

  • keepdim (bool) – whether the output tensordict has dim retained or not.

Keyword Arguments:

correction (int) – difference between the sample size and sample degrees of freedom. Defaults to Bessel’s correction, correction=1.

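A minimal sketch (keepdim follows the usual torch reduction semantics):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.randn(3, 4)}, batch_size=[3])
>>> td.var(dim=0)["a"].shape
torch.Size([4])
>>> td.var(dim=0, keepdim=True)["a"].shape
torch.Size([1, 4])
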
view(*shape: int, size: list | tuple | torch.Size | None = None)

Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size.

Parameters:
  • *shape (int) – new shape of the resulting tensordict.

  • size (list, tuple or torch.Size, optional) – an alternative way of passing the new shape as a single iterable keyword argument instead of positional integers.

Returns:

a new tensordict with the desired batch_size.

Examples

>>> td = TensorDict(source={'a': torch.zeros(3,4,5),
...    'b': torch.zeros(3,4,10,1)}, batch_size=torch.Size([3, 4]))
>>> td_view = td.view(12)
>>> print(td_view.get("a").shape)  # torch.Size([12, 5])
>>> print(td_view.get("b").shape)  # torch.Size([12, 10, 1])
>>> td_view = td.view(-1, 4, 3)
>>> print(td_view.get("a").shape)  # torch.Size([1, 4, 3, 5])
>>> print(td_view.get("b").shape)  # torch.Size([1, 4, 3, 10, 1])
where(condition, other, *, out=None, pad=None)

Return a TensorDict of elements selected from either self or other, depending on condition.

Parameters:
  • condition (BoolTensor) – When True (nonzero), yields self, otherwise yields other.

  • other (TensorDictBase or Scalar) – value (if other is a scalar) or values selected at indices where condition is False.

Keyword Arguments:
  • out (TensorDictBase, optional) – the output TensorDictBase instance.

  • pad (scalar, optional) – if provided, missing keys from the source or destination tensordict will be written as torch.where(mask, self, pad) or torch.where(mask, pad, other). Defaults to None, i.e. missing keys are not tolerated.

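A minimal sketch with a scalar other:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.arange(3, dtype=torch.float)}, batch_size=[3])
>>> cond = torch.tensor([True, False, True])
>>> td.where(cond, 0.0)["a"]
tensor([0., 0., 2.])
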
xpu(device: Optional[Union[int, device]] = None) T

Move all model parameters and buffers to the XPU.

This also makes the associated parameters and buffers different objects. It should therefore be called before constructing the optimizer if the module will live on XPU while being optimized.

Note

This method modifies the module in-place.

Parameters:

device (int, optional) – if specified, all parameters will be copied to that device

Returns:

self

Return type:

Module

zero_() T

Zeros all tensors in the tensordict in-place.

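A minimal sketch (the assignment to _ only silences the returned self in the doctest):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3)}, batch_size=[3])
>>> _ = td.zero_()
>>> assert (td["a"] == 0).all()
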
zero_grad(set_to_none: bool = True) None

Reset gradients of all model parameters.

See similar function under torch.optim.Optimizer for more context.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details.

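A minimal sketch of clearing gradients accumulated through the wrapped parameters (assuming gradients propagate to the registered parameters as for any nn.Module):

>>> from torch import nn
>>> from tensordict import TensorDict, TensorDictParams
>>> p = TensorDictParams(TensorDict.from_module(nn.Linear(3, 4)))
>>> p["weight"].sum().backward()
>>> p.zero_grad(set_to_none=False)
>>> assert (p["weight"].grad == 0).all()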