ProbabilisticTensorDictSequential

class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)

A sequence of TensorDictModules containing at least one ProbabilisticTensorDictModule.

This class extends TensorDictSequential and is typically configured with a sequence of modules where the final module is an instance of ProbabilisticTensorDictModule. However, it also supports configurations where one or more intermediate modules are instances of ProbabilisticTensorDictModule, while the last module may or may not be probabilistic. In all cases, it exposes the get_dist() method to recover the distribution object from the ProbabilisticTensorDictModule instances in the sequence.

Multiple probabilistic modules can co-exist in a single ProbabilisticTensorDictSequential. If return_composite is False (default), only the last probabilistic module will produce a distribution and the preceding ones will be executed as regular TensorDictModule instances; in this mode, the last module of the sequence must be probabilistic, and a ValueError will be raised when querying the module otherwise. If return_composite=True, all ProbabilisticTensorDictModule instances in the sequence will contribute to a single CompositeDistribution instance.

Resulting log-probabilities will be conditional probabilities if samples are interdependent: whenever Z = F(X, Y), the log-probability of Z will be log p(z | x, y). For instance, in the composite examples below, sample1 is computed from sample0, so its log-probability is the conditional log p(sample1 | sample0).

Parameters:

*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule) – An ordered sequence of TensorDictModule instances, usually terminating in a ProbabilisticTensorDictModule, to be run sequentially. The modules can be instances of TensorDictModuleBase or any other callable that reads from and writes to a tensordict. Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked, and thus will not affect the in_keys and out_keys attributes of the TensorDictSequential.

Keyword Arguments:
  • partial_tolerant (bool, optional) – If True, the input tensordict can be missing some of the input keys, in which case only the modules that can be executed with the keys that are present will be executed. Additionally, if the input tensordict is a lazy stack of tensordicts, partial_tolerant is True, and the stack lacks the required keys, TensorDictSequential will scan the sub-tensordicts for those that have the required keys, if any. Defaults to False.

  • return_composite (bool, optional) –

    If True and multiple ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential instances are found, a CompositeDistribution instance will be used. Otherwise, only the last module will be used to build the distribution. Defaults to False.

    Warning

    The behaviour of return_composite will change in v0.9 and default to True from there on.

Raises:

ValueError – If return_composite is False and the last module in the sequence is not a ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule as Prob,
...     ProbabilisticTensorDictSequential as Seq,
...     TensorDictModule as Mod,
... )
>>> torch.manual_seed(0)
>>> # Typical usage: a single distribution is computed last in the sequence
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions are ignored when return_composite=False
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     return_composite=False,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions produce a CompositeDistribution when return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     return_composite=True,
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when
>>> # return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]),
...     return_composite=True,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
build_dist_from_params(tensordict: TensorDictBase) → Distribution

Constructs a distribution from the input parameters without evaluating other modules in the sequence.

This method searches for the last ProbabilisticTensorDictModule in the sequence and uses it to build the distribution.

Parameters:

tensordict (TensorDictBase) – The input tensordict containing the distribution parameters.

Returns:

The constructed distribution object.

Return type:

D.Distribution

Raises:

RuntimeError – If no ProbabilisticTensorDictModule is found in the sequence.
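
A minimal sketch using the first module built in the Examples above (the "loc" key follows that example): since the distribution reads its parameters from "loc", a tensordict that already carries "loc" is enough, and the deterministic modules are not executed.

>>> # Sketch: build the distribution from pre-computed parameters only
>>> params = TensorDict(loc=torch.zeros(3))
>>> dist = module.build_dist_from_params(params)
>>> print(dist)
Normal(loc: torch.Size([3]), scale: torch.Size([3]))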

property default_interaction_type

Returns the default_interaction_type of the module using an iterative heuristic.

This property iterates over all modules in reverse order, attempting to retrieve the default_interaction_type attribute from any child module. The first non-None value encountered is returned. If no such value is found, the global default given by interaction_type() is returned.
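
For example (a sketch; InteractionType and set_interaction_type are importable from tensordict.nn):

>>> from tensordict.nn import InteractionType, set_interaction_type
>>> itype = module.default_interaction_type  # resolved from the child modules
>>> with set_interaction_type(InteractionType.RANDOM):
...     td = module(TensorDict(x=torch.ones(3)))  # draws a random sample instead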

property dist_params_keys: List[NestedKey]

Returns all the keys pointing at the distribution params.

property dist_sample_keys: List[NestedKey]

Returns all the keys pointing at the distribution samples.
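
For the composite module from the Examples above, a sketch of the expected values (the exact ordering is an assumption):

>>> # Sketch: sample keys are the out_keys of the probabilistic modules,
>>> # params keys are the entries the distributions read their parameters from
>>> print(module.dist_sample_keys)
['sample0', 'sample1']
>>> print(module.dist_params_keys)
['loc', 'loc2']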

forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) → TensorDictBase

Runs the modules in the sequence on the input tensordict, sampling from the distribution(s) along the way. When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
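
For example (a sketch following the Examples above):

>>> # Sketch: kwargs are packed into a TensorDict before the modules run
>>> td = module(x=torch.ones(3))  # equivalent to module(TensorDict(x=torch.ones(3)))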

get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) → Distribution

Returns the distribution resulting from passing the input tensordict through the sequence.

If return_composite is False (default), this method will only consider the last probabilistic module in the sequence.

Otherwise, it will return a CompositeDistribution instance containing the distributions of all probabilistic modules.

Parameters:
  • tensordict (TensorDictBase) – The input tensordict.

  • tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.

Keyword Arguments:

**kwargs – Additional keyword arguments passed to the underlying modules.

Returns:

The resulting distribution object.

Return type:

D.Distribution

Raises:

RuntimeError – If no probabilistic module is found in the sequence.

Note

When return_composite is True, the distributions are conditioned on the previous samples in the sequence. This means that if a module depends on the output of a previous probabilistic module, its distribution will be conditional.
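
As a sketch, the returned object can be sampled like any torch.distributions.Distribution; in the composite case, the sample is a tensordict with one entry per probabilistic module:

>>> # Sketch: sample from the distribution returned by get_dist
>>> dist = module.get_dist(TensorDict(x=torch.ones(3)))
>>> sample = dist.sample()  # a Tensor, or a TensorDict in the composite case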

get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) → tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase]

Returns the distribution parameters and output tensordict.

This method runs the deterministic part of the ProbabilisticTensorDictSequential module to obtain the distribution parameters. The interaction type is set to the current global interaction type if available, otherwise it defaults to the interaction type of the last module.

Parameters:
  • tensordict (TensorDictBase) – The input tensordict.

  • tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.

Keyword Arguments:

**kwargs – Additional keyword arguments passed to the deterministic part of the module.

Returns:

A tuple containing the distribution object and the output tensordict.

Return type:

tuple[D.Distribution, TensorDictBase]

Note

The interaction type is temporarily set to the specified value during the execution of this method.
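
A minimal sketch following the Examples above: the returned tensordict carries the parameter entries (here "loc") written by the deterministic modules.

>>> # Sketch: recover the distribution together with its parameters
>>> dist, td_out = module.get_dist_params(TensorDict(x=torch.ones(3)))
>>> assert "loc" in td_out.keys()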

log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs)

Returns the log-probability of the input tensordict.

If self.return_composite is True and the distribution is a CompositeDistribution, this method will return the log-probability of the entire composite distribution.

Otherwise, it will only consider the last probabilistic module in the sequence.

Parameters:
  • tensordict (TensorDictBase) – The input tensordict.

  • tensordict_out (TensorDictBase, optional) – The output tensordict. If None, a new tensordict will be created. Defaults to None.

Keyword Arguments:

dist (torch.distributions.Distribution, optional) – The distribution object. If None, it will be computed using get_dist. Defaults to None.

Returns:

The log-probability of the input tensordict.

Return type:

TensorDictBase or torch.Tensor

Warning

In future releases (v0.9), the default values of aggregate_probabilities, inplace, and include_sum will change. To avoid warnings, it is recommended to explicitly pass these arguments to the log_prob method or set them in the constructor.
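
A minimal sketch of the dist shortcut, which avoids building the distribution a second time:

>>> # Sketch: reuse a pre-computed distribution when evaluating log-probabilities
>>> data = module(TensorDict(x=torch.ones(3)))
>>> dist = module.get_dist(data)
>>> lp = module.log_prob(data, dist=dist)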
