
ProbabilisticTensorDictModule

class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)

A probabilistic TD Module.

ProbabilisticTensorDictModule is a non-parametric module representing a probability distribution. It reads the distribution parameters from an input TensorDict using the specified in_keys. The output is then computed according to a rule set by the default_interaction_type argument and the interaction_type() global function.

ProbabilisticTensorDictModule can be used to construct the distribution (through the get_dist() method) and/or to sample from this distribution (through a regular __call__() to the module).

A ProbabilisticTensorDictModule instance has two main features:

  • It reads and writes TensorDict objects.

  • It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed.

When the __call__ / forward method is called, a distribution is created and a value is computed from it (using the ‘mean’, ‘mode’ or ‘median’ attribute, or the ‘rsample’ or ‘sample’ method). The sampling step is skipped if the supplied TensorDict already contains all of the desired key-value pairs.

By default, the distribution class of ProbabilisticTensorDictModule is a Delta distribution, which makes ProbabilisticTensorDictModule a simple wrapper around a deterministic mapping function.

Parameters:
  • in_keys (NestedKey or list of NestedKey or dict) – key(s) that will be read from the input TensorDict and used to build the distribution. Importantly, if it’s a list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by the distribution class of interest, e.g. "loc" and "scale" for the Normal distribution and similar. If in_keys is a dictionary, the keys are the keywords of the distribution and the values are the keys in the tensordict that will be matched to the corresponding distribution keywords.

  • out_keys (NestedKey or list of NestedKey) – keys where the sampled values will be written. Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.

  • default_interaction_mode (str, optional) – Deprecated keyword-only argument. Please use default_interaction_type instead.

  • default_interaction_type (InteractionType, optional) –

    keyword-only argument. Default method to be used to retrieve the output value. Should be one of InteractionType: MODE, MEDIAN, MEAN or RANDOM (in which case the value is sampled randomly from the distribution). Default is MODE.

    Note

    When a sample is drawn, the ProbabilisticTensorDictModule instance will first look for the interaction type dictated by the interaction_type() global function. If this returns None (its default value), then the default_interaction_type of the ProbabilisticTensorDictModule instance will be used. Note that DataCollectorBase instances will use set_interaction_type to set the interaction type to tensordict.nn.InteractionType.RANDOM by default.

  • distribution_class (Type, optional) – keyword-only argument. A torch.distributions.Distribution class to be used for sampling. Default is tensordict.nn.distributions.Delta.

  • distribution_kwargs (dict, optional) – keyword-only argument. Keyword-argument pairs to be passed to the distribution.

  • return_log_prob (bool, optional) – keyword-only argument. If True, the log-probability of the distribution sample will be written in the tensordict with the key log_prob_key. Default is False.

  • log_prob_key (NestedKey, optional) – key where to write the log_prob if return_log_prob = True. Defaults to ‘sample_log_prob’.

  • cache_dist (bool, optional) – keyword-only argument. EXPERIMENTAL: if True, the parameters of the distribution (i.e. the output of the module) will be written to the tensordict along with the sample. Those parameters can be used to re-compute the original distribution later on (e.g. to compute the divergence between the distribution used to sample the action and the updated distribution in PPO). Default is False.

  • n_empirical_estimate (int, optional) – keyword-only argument. Number of samples to compute the empirical mean when it is not available. Defaults to 1000.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from torch.distributions import Normal, Independent
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     net, in_keys=["input", "hidden"], out_keys=["params"]
... )
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> def IndepNormal(**kwargs):
...     return Independent(Normal(**kwargs), 1)
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=IndepNormal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(
...     module, normal_params, prob_module
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> with params.to_module(td_module):
...     dist = td_module.get_dist(td)
>>> print(dist)
Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1)
>>> # we can also apply the module to the TensorDict with vmap
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase, tensordict_out: TensorDictBase | None = None, _requires_sample: bool = True) -> TensorDictBase

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_dist(tensordict: TensorDictBase) -> Distribution

Creates a torch.distributions.Distribution instance with the parameters provided in the input tensordict.

log_prob(tensordict)

Writes the log-probability of the distribution sample.
