tensordict.nn.TensorDictParams

class tensordict.nn.TensorDictParams(parameters: TensorDictBase, *, no_convert=False, lock: bool = False)
Holds a TensorDictBase instance full of parameters.
This class exposes the contained parameters to a parent nn.Module such that iterating over the parameters of the module also iterates over the leaves of the tensordict.
Indexing works exactly like indexing the wrapped tensordict. The parameter names will be registered within this module using flatten_keys("_"), so nested keys are joined into a single underscore-separated name. Therefore, the result of named_parameters() and the content of the tensordict will differ slightly in terms of key names. Any operation that sets a tensor in the tensordict is followed by a conversion to torch.nn.Parameter (unless no_convert=True).

Parameters:
    parameters (TensorDictBase) – a tensordict to represent as parameters. Values will be converted to parameters unless no_convert=True.

Keyword Arguments:
    no_convert (bool) – if True, no conversion to nn.Parameter will occur at construction or afterwards (unless the no_convert attribute is changed). If no_convert is True and non-parameters are present, they will be registered as buffers. Defaults to False.
    lock (bool) – if True, the tensordict hosted by TensorDictParams will be locked. This can be useful to avoid unwanted modifications, but it also restricts the operations that can be done on the object (and can have a significant performance impact when unlock_() is required). Defaults to False.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictParams
>>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
>>> params = TensorDict.from_module(module)
>>> params.lock_()
>>> p = TensorDictParams(params)
>>> print(p)
TensorDictParams(params=TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False))
>>> class CustomModule(nn.Module):
...     def __init__(self, params):
...         super().__init__()
...         self.params = params
>>> m = CustomModule(p)
>>> # the wrapper supports assignment, and assigned values are turned into nn.Parameter
>>> m.params['other'] = torch.randn(3)
>>> assert isinstance(m.params['other'], nn.Parameter)