tensorclass

class tensordict.tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False)
A decorator to create tensorclass classes. These are specialized dataclasses.dataclass() instances that can execute some pre-defined tensor operations out of the box, such as indexing, item assignment, reshaping, casting to device or storage, and many others.

Parameters:
  autocast (bool, optional) – if True, the types indicated by the field annotations are enforced when an argument is set. Defaults to False.
  frozen (bool, optional) – if True, the content of the tensorclass cannot be modified. This argument is provided for dataclass compatibility; a similar behavior can be obtained through the lock argument in the class constructor. Defaults to False.
tensorclass can be used with or without arguments:

>>> @tensorclass
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=False)
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=True)
... class X:
...     y: torch.Tensor
>>> X(1).y
tensor(1)
Examples
>>> from tensordict import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test",
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test",
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
It is also possible to nest tensorclass instances within each other:
Examples

>>> from tensordict import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class NestingMyData:
...     nested: MyData
...
>>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4])
>>> # although the data is stored as a TensorDict, the type hint helps us
>>> # to appropriately cast the data to the right type
>>> assert isinstance(nesting_data.nested, type(data))