TensorClass
class tensordict.TensorClass
TensorClass is the inheritance-based version of the @tensorclass decorator.
TensorClass lets you write dataclasses that are better type-checked and more Pythonic than those built with the @tensorclass decorator.
Examples
>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
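Class options such as autocast, nocast and frozen can be passed in two ways: as strings inside brackets (TensorClass["autocast"]) or as keyword arguments in the class definition (class Foo(TensorClass, autocast=True)).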
Examples
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple options can be combined
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)
Warning
TensorClass itself is not decorated as a tensorclass, but its subclasses are. This is because we cannot anticipate whether the frozen argument will be set, and if it is, it may conflict with the parent class (a subclass cannot be frozen if its parent class is not).
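The parent/child constraint mentioned in the warning is the same rule that Python's standard dataclasses module enforces. The sketch below uses plain dataclasses with hypothetical Parent and Child classes and no tensordict involved. It shows that a frozen subclass of a non-frozen parent is rejected as soon as the class is created, not when it is instantiated.

```python
import dataclasses


@dataclasses.dataclass
class Parent:
    x: int


error_message = None
try:
    # Declaring a frozen dataclass on top of a non-frozen one raises
    # TypeError at class-definition time.
    @dataclasses.dataclass(frozen=True)
    class Child(Parent):
        y: int
except TypeError as err:
    error_message = str(err)

print(error_message)  # cannot inherit frozen dataclass from a non-frozen one
```

Leaving TensorClass undecorated avoids creating a non-frozen base that a frozen subclass would then conflict with in this way.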