from_dataclass
- tensordict.from_dataclass(obj: Any, *, dest_cls: Optional[Type] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, device: Optional[device] = None)
Converts a dataclass instance or a type into a tensorclass instance or type, respectively.
This function takes a dataclass instance or a dataclass type and converts it into a tensor-compatible class, optionally applying various configurations such as auto-batching, immutability, and type casting.
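The type-versus-instance contract can be sketched with the standard library alone. This is a toy under stated assumptions, not tensordict's implementation: `convert_sketch` is a hypothetical name, it builds a plain dataclass with `dataclasses.make_dataclass` where the real function builds a tensorclass, and it does no casting, batching or device handling. It mirrors only the documented dispatch: a type yields a new class, an instance yields an instance (optionally of a given `dest_cls`), and anything else raises `TypeError`.

```python
import dataclasses
from typing import Any, Optional, Type


def convert_sketch(obj: Any, *, dest_cls: Optional[Type] = None) -> Any:
    """Hypothetical toy mirroring from_dataclass's type/instance dispatch."""
    if not dataclasses.is_dataclass(obj):
        # Documented behaviour: non-dataclass input raises TypeError.
        raise TypeError(f"expected a dataclass instance or type, got {type(obj).__name__}")
    if isinstance(obj, type):
        # Type in -> new class out; dest_cls has no effect here.
        fields = [(f.name, f.type) for f in dataclasses.fields(obj)]
        return dataclasses.make_dataclass(obj.__name__, fields)
    # Instance in -> instance out, using dest_cls if given, else a fresh class.
    cls = dest_cls if dest_cls is not None else convert_sketch(type(obj))
    values = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj) if f.init}
    return cls(**values)


@dataclasses.dataclass
class X:
    a: int
    b: float


X2 = convert_sketch(X)          # a new class, distinct from X
x2 = convert_sketch(X(0, 1.0))  # an instance of a freshly built class
print(isinstance(X2, type), X2 is not X)  # True True
print(type(x2).__name__, x2.a, x2.b)      # X 0 1.0
```

Ignoring `dest_cls` when a type is passed follows the documented rule that this argument is without effect if `obj` is a type.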
- Parameters:
obj (Any) – The dataclass instance or type to be converted. If a type is provided, a new class is returned.
- Keyword Arguments:
dest_cls (tensorclass, optional) – A tensorclass type to be used to map the data. If not provided, a new class is created. Without effect if obj is a type.
auto_batch_size (bool, optional) – If True, automatically determines a batch size and applies it to the resulting object. Defaults to False.
batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).
batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.
frozen (bool, optional) – If True, the resulting class or instance will be immutable. Defaults to False.
autocast (bool, optional) – If True, enables automatic type casting for the resulting class or instance. Defaults to False.
nocast (bool, optional) – If True, disables any type casting for the resulting class or instance. Defaults to False.
inplace (bool, optional) – If True, the dataclass type passed will be modified in-place. Defaults to False. Without effect if an instance is provided.
device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.
shadow (bool, optional) – Disables the validation of field names against TensorDict’s reserved attributes. Use with caution, as this may cause unintended consequences. Defaults to False.
- Returns:
A tensor-compatible class or instance derived from the provided dataclass.
- Raises:
TypeError – If the provided input is not a dataclass instance or type.
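The batch size chosen when auto_batch_size is True can be pictured as the longest leading shape shared by every field, optionally capped at batch_dims dimensions. The helper below is a hypothetical sketch of that rule on plain shape tuples, not tensordict code; it assumes every field has a known shape and handles a single, non-nested level only.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def infer_batch_size(shapes: Sequence[Shape], batch_dims: Optional[int] = None) -> Shape:
    """Hypothetical helper: longest leading shape shared by all fields."""
    common = []
    # zip(*shapes) walks dimension by dimension and stops at the shortest shape.
    for dims in zip(*shapes):
        if any(d != dims[0] for d in dims):
            break
        common.append(dims[0])
    if batch_dims is not None:
        # Cap the result at batch_dims leading dimensions.
        common = common[:batch_dims]
    return tuple(common)


# Shapes of two fields such as torch.zeros(3, 4, 5) and torch.zeros(3, 4).
print(infer_batch_size([(3, 4, 5), (3, 4)]))                # (3, 4)
print(infer_batch_size([(3, 4, 5), (3, 4)], batch_dims=1))  # (3,)
print(infer_batch_size([(3, 4), (2, 4)]))                   # ()
```

Because `zip` stops at the shortest shape, a scalar field (shape `()`) yields an empty batch size in this sketch.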
Examples
>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict.tensorclass import from_dataclass
>>>
>>> @dataclass
... class X:
...     a: int
...     b: torch.Tensor
...
>>> x = X(0, 0)
>>> x2 = from_dataclass(x)
>>> print(x2)
X(
    a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> X2 = from_dataclass(X, autocast=True)
>>> print(X2(a=0, b=0))
X(
    a=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
Warning
Whereas TensorDict.from_dataclass() will return a TensorDict instance by default, this function will return a tensorclass instance or type.