
from_dataclass

tensordict.from_dataclass(obj: Any, *, dest_cls: Optional[Type] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, device: Optional[device] = None)

Converts a dataclass instance or a type into a tensorclass instance or type, respectively.

This function takes a dataclass instance or a dataclass type and converts it into a tensor-compatible class, optionally applying various configurations such as auto-batching, immutability, and type casting.

Parameters:

obj (Any) – The dataclass instance or type to be converted. If a type is provided, a new class is returned.

Keyword Arguments:
  • dest_cls (tensorclass, optional) – A tensorclass type to be used to map the data. If not provided, a new class is created. Has no effect if obj is a type.

  • auto_batch_size (bool, optional) – If True, automatically determines the batch size and applies it to the resulting object. Defaults to False.

  • batch_dims (int, optional) – If auto_batch_size is True, defines how many dimensions the output tensordict should have. Defaults to None (full batch-size at each level).

  • batch_size (torch.Size, optional) – The batch size of the TensorDict. Defaults to None.

  • frozen (bool, optional) – If True, the resulting class or instance will be immutable. Defaults to False.

  • autocast (bool, optional) – If True, enables automatic type casting for the resulting class or instance. Defaults to False.

  • nocast (bool, optional) – If True, disables any type casting for the resulting class or instance. Defaults to False.

  • inplace (bool, optional) – If True, the dataclass type passed will be modified in-place. Defaults to False. Has no effect if an instance is provided.

  • device (torch.device, optional) – The device on which the TensorDict will be created. Defaults to None.

  • shadow (bool, optional) – Disables the validation of field names against TensorDict’s reserved attributes. Use with caution, as this may cause unintended consequences. Defaults to False.

Returns:

A tensor-compatible class or instance derived from the provided dataclass.

Raises:

TypeError – If the provided input is not a dataclass instance or type.

Examples

>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict.tensorclass import from_dataclass
>>>
>>> @dataclass
... class X:
...     a: int
...     b: torch.Tensor
...
>>> x = X(0, 0)
>>> x2 = from_dataclass(x)
>>> print(x2)
X(
    a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> X2 = from_dataclass(X, autocast=True)
>>> print(X2(a=0, b=0))
X(
    a=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Warning

Whereas TensorDict.from_dataclass() returns a TensorDict instance by default, this function returns a tensorclass instance or type.
