tensorclass

class tensordict.tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False)

A decorator to create tensorclass classes.

tensorclass classes are specialized dataclasses.dataclass() classes that support a set of predefined tensor operations out of the box, such as indexing, item assignment, reshaping, casting to a device or storage, and many others.
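The core idea, one operation applied to every field at once, can be illustrated without torch. The pure-Python sketch below uses a hypothetical `batched` decorator and plain lists standing in for tensors; it is not how tensordict implements tensorclass, and a real tensorclass additionally tracks `batch_size` and `device`.

```python
from dataclasses import dataclass, fields


def batched(cls):
    """Hypothetical decorator: make cls a dataclass whose __getitem__
    applies the same index to every field (a stand-in for batch indexing)."""
    cls = dataclass(cls)

    def __getitem__(self, index):
        # Index each field identically and rebuild an instance of the same class.
        return cls(**{f.name: getattr(self, f.name)[index] for f in fields(self)})

    cls.__getitem__ = __getitem__
    return cls


@batched
class Pair:
    x: list
    y: list


p = Pair(x=[1, 2, 3], y=["a", "b", "c"])
print(p[1])    # Pair(x=2, y='b')
print(p[0:2])  # Pair(x=[1, 2], y=['a', 'b'])
```
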

Parameters:
  • autocast (bool, optional) – if True, the types given by the field annotations are enforced (values are cast to them) when an argument is set. Defaults to False.

  • frozen (bool, optional) – if True, the content of the tensorclass cannot be modified. This argument is provided for dataclass compatibility; similar behavior can be obtained through the lock argument of the class constructor. Defaults to False.
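The two parameters can be pictured with the standard library. The first half of this sketch is a hypothetical `autocasting` decorator showing one way annotated types could be enforced on assignment; tensordict's own autocast mechanism may work differently. The second half shows the standard `dataclass(frozen=True)` behavior that the frozen argument is meant to be compatible with.

```python
import dataclasses
from dataclasses import dataclass
from typing import get_type_hints


def autocasting(cls):
    """Hypothetical decorator: cast each assigned value to its annotated type."""
    cls = dataclass(cls)
    hints = get_type_hints(cls)

    def __setattr__(self, name, value):
        target = hints.get(name)
        # Only cast when the annotation is a plain class and the value is not already one.
        if isinstance(target, type) and not isinstance(value, target):
            value = target(value)
        object.__setattr__(self, name, value)

    cls.__setattr__ = __setattr__
    return cls


@autocasting
class Point:
    x: float


p = Point(1)
print(p.x)  # 1.0: the int passed to the constructor was cast to float
p.x = 3
print(p.x)  # 3.0: later assignments are cast as well


# Standard-library frozen dataclass: any assignment after __init__ raises.
@dataclass(frozen=True)
class FrozenPoint:
    x: float


fp = FrozenPoint(1.0)
try:
    fp.x = 2.0
except dataclasses.FrozenInstanceError:
    print("FrozenPoint cannot be modified")
```
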

tensorclass can be used with or without arguments:

Examples

>>> import torch
>>> from tensordict import tensorclass
>>> @tensorclass
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=False)
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=True)
... class X:
...     y: torch.Tensor
>>> X(1).y
tensor(1)
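Supporting both `@tensorclass` and `@tensorclass(...)` relies on the signature `tensorclass(cls=None, /, *, autocast=False, frozen=False)`, the same pattern the standard library's `dataclasses.dataclass` follows. The sketch below uses a hypothetical `my_decorator` to show the mechanism: when called bare, the class arrives as `cls`; when called with keyword arguments, `cls` is None and the real decorator is returned.

```python
from dataclasses import dataclass


def my_decorator(cls=None, /, *, autocast: bool = False, frozen: bool = False):
    """Hypothetical decorator usable both as @my_decorator and @my_decorator(...)."""

    def wrap(c):
        c = dataclass(c, frozen=frozen)
        c._autocast = autocast  # record the option so it can be inspected
        return c

    if cls is None:
        # Called with arguments, e.g. @my_decorator(autocast=True):
        # return wrap, which Python then applies to the class.
        return wrap
    # Called bare, e.g. @my_decorator: cls is the class being decorated.
    return wrap(cls)


@my_decorator
class A:
    y: int


@my_decorator(autocast=True)
class B:
    y: int


print(A._autocast, B._autocast)  # False True
```
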

Examples

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test",
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test",
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
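
The shape bookkeeping in expand_and_mask can be checked by hand. The pure-Python helper below, a hypothetical `expand_shape` that mirrors PyTorch's expand rule (size-1 dimensions can be stretched, other sizes must match, new dimensions are added on the left), shows why the unsqueeze is needed: (3, 4, 1) becomes (3, 4, 1, 1) and can then be expanded to y's shape (3, 4, 2, 2). Because y was created with torch.zeros(..., dtype=torch.bool), every mask entry is False, so X[self.y] selects nothing and the method returns an empty tensor.

```python
def expand_shape(src, target):
    """Hypothetical helper: the shape obtained by expanding a tensor of shape
    `src` to `target`, following PyTorch's expand rule."""
    if len(src) > len(target):
        raise ValueError("cannot expand to fewer dimensions")
    # New dimensions can only be added on the left, so align shapes on the right.
    padded = (1,) * (len(target) - len(src)) + tuple(src)
    for s, t in zip(padded, target):
        # A size-1 dimension may be stretched; any other size must already match.
        if s != t and s != 1:
            raise ValueError(f"size {s} cannot be expanded to {t}")
    return tuple(target)


x_shape = (3, 4, 1)          # shape of MyData.X
y_shape = (3, 4, 2, 2)       # shape of MyData.y
unsqueezed = x_shape + (1,)  # X.unsqueeze(-1) -> (3, 4, 1, 1)
print(expand_shape(unsqueezed, y_shape))  # (3, 4, 2, 2)

# Without the unsqueeze, right-alignment pairs X's batch sizes with the wrong dims.
try:
    expand_shape(x_shape, y_shape)
except ValueError as err:
    print(err)  # size 3 cannot be expanded to 4
```

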

It is also possible to nest tensorclass instances within each other:

Examples

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class NestingMyData:
...     nested: MyData
...
>>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4])
>>> # although the data is stored as a TensorDict, the type hint helps us
>>> # to appropriately cast the data to the right type
>>> assert isinstance(nesting_data.nested, type(data))
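
The comment above notes that nested data is stored in an untyped container while the type hint is used to restore the right class on access. The pure-Python sketch below illustrates that idea with hypothetical `Inner` and `Outer` classes and a plain dict standing in for a TensorDict; it is an analogy only, not tensordict's actual mechanism.

```python
from dataclasses import asdict, dataclass
from typing import get_type_hints


@dataclass
class Inner:
    a: int


class Outer:
    """Hypothetical container: stores nested fields as plain dicts and uses the
    class annotations to rebuild the annotated type on attribute access."""

    nested: Inner

    def __init__(self, nested: Inner):
        # Keep the payload untyped, loosely analogous to storing it in a TensorDict.
        self._storage = {"nested": asdict(nested)}

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for the stored fields.
        hints = get_type_hints(type(self))
        if name in hints:
            return hints[name](**self._storage[name])
        raise AttributeError(name)


outer = Outer(nested=Inner(a=1))
print(outer._storage)               # {'nested': {'a': 1}}
print(type(outer.nested).__name__)  # Inner
```
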