ProbabilisticTensorDictSequential
- class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)
A sequence of TensorDictModules containing at least one ProbabilisticTensorDictModule.

This class extends TensorDictSequential and is typically configured with a sequence of modules where the final module is an instance of ProbabilisticTensorDictModule. However, it also supports configurations where one or more intermediate modules are instances of ProbabilisticTensorDictModule, while the last module may or may not be probabilistic. In all cases, it exposes the get_dist() method to recover the distribution object from the ProbabilisticTensorDictModule instances in the sequence.

Multiple probabilistic modules can co-exist in a single ProbabilisticTensorDictSequential. If return_composite is False (default), only the last one will produce a distribution and the others will be executed as regular TensorDictModule instances. However, if a ProbabilisticTensorDictModule is not the last module in the sequence and return_composite=False, a ValueError will be raised when trying to query the module. If return_composite=True, all intermediate ProbabilisticTensorDictModule instances will contribute to a single CompositeDistribution instance.

Resulting log-probabilities will be conditional probabilities if samples are interdependent: whenever

    Z = (z_1, ..., z_N)  with  z_1 ~ p_1(z_1)  and  z_i ~ p_i(z_i | z_1, ..., z_{i-1})  for i = 2, ..., N,

then the log-probability of Z will be

    log p(Z) = sum_{i=1}^{N} log p_i(z_i | z_1, ..., z_{i-1}).
- Parameters:
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule) – An ordered sequence of TensorDictModule instances, usually terminating in a ProbabilisticTensorDictModule, to be run sequentially. The modules can be instances of TensorDictModuleBase or any other callable that matches this signature. Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked, and thus will not affect the in_keys and out_keys attributes of the TensorDictSequential.
- Keyword Arguments:
partial_tolerant (bool, optional) – If True, the input tensordict can miss some of the input keys. If so, only the modules that can be executed given the keys that are present will be executed. Also, if the input tensordict is a lazy stack of tensordicts AND partial_tolerant is True AND the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. Defaults to False.
return_composite (bool, optional) – If True and multiple ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential instances are found, a CompositeDistribution instance will be used. Otherwise, only the last module will be used to build the distribution. Defaults to False.

Warning
The behaviour of return_composite will change in v0.9 and will default to True from then on.
- Raises:
ValueError – If the input sequence of modules is empty.
TypeError – If the final module is not an instance of ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential.
Examples
>>> # Typical usage: a single distribution is computed last in the sequence
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (ProbabilisticTensorDictModule as Prob,
...     ProbabilisticTensorDictSequential as Seq, TensorDictModule as Mod)
>>> torch.manual_seed(0)
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions are ignored when return_composite=False
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     return_composite=False,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions produce a CompositeDistribution when return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     return_composite=True,
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when
>>> # return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...         distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]),
...     return_composite=True,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- build_dist_from_params(tensordict: TensorDictBase) Distribution
Constructs a distribution from the input parameters without evaluating other modules in the sequence.
This method searches for the last ProbabilisticTensorDictModule in the sequence and uses it to build the distribution.
- Parameters:
tensordict (TensorDictBase) – The input tensordict containing the distribution parameters.
- Returns:
The constructed distribution object.
- Return type:
D.Distribution
- Raises:
RuntimeError – If no ProbabilisticTensorDictModule is found in the sequence.
- property default_interaction_type
Returns the default_interaction_type of the module using an iterative heuristic.
This property iterates over all modules in reverse order, attempting to retrieve the default_interaction_type attribute from any child module. The first non-None value encountered is returned. If no such value is found, the default interaction_type() is returned.
- property dist_params_keys: List[NestedKey]
Returns all the keys pointing at the distribution params.
- property dist_sample_keys: List[NestedKey]
Returns all the keys pointing at the distribution samples.
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) TensorDictBase
When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
- get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) Distribution
Returns the distribution resulting from passing the input tensordict through the sequence.
If return_composite is False (default), this method will only consider the last probabilistic module in the sequence. Otherwise, it will return a CompositeDistribution instance containing the distributions of all probabilistic modules.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.
- Keyword Arguments:
**kwargs – Additional keyword arguments passed to the underlying modules.
- Returns:
The resulting distribution object.
- Return type:
D.Distribution
- Raises:
RuntimeError – If no probabilistic module is found in the sequence.
Note
When return_composite is True, the distributions are conditioned on the previous samples in the sequence. This means that if a module depends on the output of a previous probabilistic module, its distribution will be conditional.
- get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase]
Returns the distribution parameters and output tensordict.
This method runs the deterministic part of the ProbabilisticTensorDictSequential module to obtain the distribution parameters. The interaction type is set to the current global interaction type if available; otherwise it defaults to the interaction type of the last module.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.
- Keyword Arguments:
**kwargs – Additional keyword arguments passed to the deterministic part of the module.
- Returns:
A tuple containing the distribution object and the output tensordict.
- Return type:
tuple[D.Distribution, TensorDictBase]
Note
The interaction type is temporarily set to the specified value during the execution of this method.
- log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs)
Returns the log-probability of the input tensordict.
If self.return_composite is True and the distribution is a CompositeDistribution, this method will return the log-probability of the entire composite distribution. Otherwise, it will only consider the last probabilistic module in the sequence.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.
- Keyword Arguments:
dist (torch.distributions.Distribution, optional) – The distribution object. If None, it will be computed using get_dist. Defaults to None.
- Returns:
The log-probability of the input tensordict.
- Return type:
torch.Tensor or TensorDictBase
Warning
In future releases (v0.9), the default values of aggregate_probabilities, inplace, and include_sum will change. To avoid warnings, it is recommended to explicitly pass these arguments to the log_prob method or set them in the constructor.