tensordict.nn.TensorDictParams

class tensordict.nn.TensorDictParams(parameters: TensorDictBase, *, no_convert=False, lock: bool = False)

Holds a TensorDictBase instance full of parameters.

This class exposes the contained parameters to a parent nn.Module such that iterating over the parameters of the module also iterates over the leaves of the tensordict.

Indexing works exactly as it does for the wrapped tensordict. The parameter names are registered within this module using flatten_keys("_"). Therefore, the result of named_parameters() and the content of the tensordict will differ slightly in terms of key names.

Any operation that sets a tensor in the tensordict is augmented with a conversion to torch.nn.Parameter.

Parameters:

parameters (TensorDictBase) – a tensordict to represent as parameters. Values will be converted to parameters unless no_convert=True.

Keyword Arguments:
  • no_convert (bool) – if True, no conversion to nn.Parameter will occur at construction and after (unless the no_convert attribute is changed). If no_convert is True and if non-parameters are present, they will be registered as buffers. Defaults to False.

  • lock (bool) – if True, the tensordict hosted by TensorDictParams will be locked. This can be useful to avoid unwanted modifications, but also restricts the operations that can be done over the object (and can have significant performance impact when unlock_() is required). Defaults to False.

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictParams
>>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
>>> params = TensorDict.from_module(module)
>>> params.lock_()
>>> p = TensorDictParams(params)
>>> print(p)
TensorDictParams(params=TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False))
>>> class CustomModule(nn.Module):
...     def __init__(self, params):
...         super().__init__()
...         self.params = params
>>> m = CustomModule(p)
>>> # the wrapper supports assignment, and values are turned into Parameters
>>> m.params['other'] = torch.randn(3)
>>> assert isinstance(m.params['other'], nn.Parameter)