from_dict

tensordict.from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None)

Returns a TensorDict created from a dictionary or another TensorDict.

If batch_size is not specified, the returned TensorDict uses the largest batch size possible, i.e. the longest run of leading dimensions shared by all of its entries.

This function also works on nested dictionaries, and it can be applied to an existing nested tensordict to determine its batch size.
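The batch-size inference described above can be illustrated without the library. The following is a minimal sketch, assuming that "maximum batch size" means the longest prefix of leading dimensions common to every leaf; the helper names leaf_shapes and max_batch_size are hypothetical, shapes are written as plain tuples instead of tensors, and this is not the library's own implementation.

```python
from collections.abc import Mapping


def leaf_shapes(tree):
    """Yield the shape of every leaf, descending into nested mappings."""
    for value in tree.values():
        if isinstance(value, Mapping):
            yield from leaf_shapes(value)
        else:
            yield tuple(value)


def max_batch_size(tree):
    """Return the longest run of leading dimensions shared by all leaves.

    Hypothetical helper: it mirrors the documented behaviour, not the
    library's internal code.
    """
    shapes = list(leaf_shapes(tree))
    prefix = []
    # zip stops at the shortest shape, so the prefix can never be longer
    # than the leaf with the fewest dimensions.
    for dims in zip(*shapes):
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])
    return tuple(prefix)


# Shapes standing in for {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
shapes = {"a": (3,), "b": {"c": (3, 4)}}
root = max_batch_size(shapes)         # computed over every leaf
nested = max_batch_size(shapes["b"])  # computed over the leaves of "b" only
print(root, nested)
```

Under this assumption the root and the nested entry are resolved independently, which is why a nested tensordict can end up with more batch dimensions than its parent.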

Parameters:
  • input_dict (dictionary or TensorDict) – a dictionary or tensordict to use as a data source (nested keys compatible).

  • batch_size (iterable of int, optional) – a batch size for the tensordict.

  • device (torch.device or compatible type, optional) – a device for the TensorDict.

  • batch_dims (int, optional) – the number of leading dimensions to be considered for batch_size. Mutually exclusive with batch_size. Note that this is the maximum number of batch dims of the tensordict; a smaller number is tolerated.

  • names (list of str, optional) – the dimension names of the tensordict.
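How batch_size and batch_dims interact can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: resolve_batch_size is a hypothetical helper, shapes are plain tuples, and the use of ValueError is an assumption (the library may raise a different exception type).

```python
def resolve_batch_size(shapes, batch_size=None, batch_dims=None):
    """Pick a batch size for a collection of leaf shapes.

    Hypothetical helper illustrating the documented rules:
    - batch_size and batch_dims cannot both be given;
    - an explicit batch_size must match the leading dims of every leaf;
    - batch_dims is an upper bound on the number of inferred batch dims.
    """
    if batch_size is not None and batch_dims is not None:
        raise ValueError("batch_size and batch_dims are mutually exclusive")

    if batch_size is not None:
        batch_size = tuple(batch_size)
        for shape in shapes:
            if tuple(shape[: len(batch_size)]) != batch_size:
                raise ValueError(f"shape {shape} is incompatible with {batch_size}")
        return batch_size

    # Infer the longest shared prefix of leading dimensions.
    prefix = []
    for dims in zip(*shapes):
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])

    # batch_dims only caps the result; fewer dims are tolerated.
    if batch_dims is not None:
        prefix = prefix[:batch_dims]
    return tuple(prefix)


shapes = [(3, 4, 5), (3, 4)]
inferred = resolve_batch_size(shapes)               # longest shared prefix
capped = resolve_batch_size(shapes, batch_dims=1)   # at most one batch dim
loose = resolve_batch_size(shapes, batch_dims=5)    # cap above what is available
explicit = resolve_batch_size(shapes, batch_size=[3])
print(inferred, capped, loose, explicit)
```

Requesting more batch dims than the data supports simply yields the inferred prefix, which matches the note that a smaller number is tolerated.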

Examples

>>> import torch
>>> from tensordict import TensorDict, from_dict
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch size of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(from_dict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
