tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: Optional[NestedKey] = None, entropy_key: Optional[NestedKey] = None)

A composite distribution that groups multiple distributions together using the TensorDict interface.

This class allows for operations such as log_prob_composite, entropy_composite, cdf, icdf, rsample, and sample to be performed on a collection of distributions, returning a TensorDict. The input TensorDict may be modified in-place.

Parameters:
  • params (TensorDictBase) – A nested key-tensor map where the root entries correspond to sample names, and the leaves are the distribution parameters. Entry names must match those specified in distribution_map.

  • distribution_map (Dict[NestedKey, Type[torch.distributions.Distribution]]) – Specifies the distribution types to be used. The names of the distributions should match the sample names in the TensorDict.
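The pairing between params and distribution_map can be pictured with the following stdlib-only sketch. It is not tensordict's implementation: plain nested dicts stand in for the TensorDict, tuples stand in for nested keys, and the stand-in classes Normal and Categorical as well as the helpers as_tuple, get_nested and build_components are hypothetical. The assumption illustrated is that the parameter leaves stored under each sample name are passed as keyword arguments to the matching distribution class, together with any per-sample entries of extra_kwargs.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple, Union

# Hypothetical stand-ins for torch.distributions classes: they only
# record the keyword arguments they are constructed with.
@dataclass
class Normal:
    loc: Any
    scale: Any

@dataclass
class Categorical:
    logits: Any

Key = Union[str, Tuple[str, ...]]

def as_tuple(key: Key) -> Tuple[str, ...]:
    # Normalize "cont" and ("nested", "disc") to the same tuple form.
    return (key,) if isinstance(key, str) else tuple(key)

def get_nested(tree: Dict[str, Any], key: Key) -> Any:
    # Walk a nested dict along the path described by a nested key.
    node = tree
    for part in as_tuple(key):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"no parameters found for sample {key!r}")
        node = node[part]
    return node

def build_components(params, distribution_map, extra_kwargs=None):
    # Hypothetical helper: one distribution per entry of distribution_map.
    # extra_kwargs is assumed to use the same key form as distribution_map.
    extra_kwargs = extra_kwargs or {}
    components = {}
    for name, dist_cls in distribution_map.items():
        dist_params = get_nested(params, name)  # e.g. {"loc": ..., "scale": ...}
        kwargs = extra_kwargs.get(name, {})
        components[as_tuple(name)] = dist_cls(**dist_params, **kwargs)
    return components

params = {
    "cont": {"loc": 0.0, "scale": 1.0},
    "nested": {"disc": {"logits": [0.1, 0.2, 0.3]}},
}
components = build_components(
    params, {"cont": Normal, ("nested", "disc"): Categorical}
)
print(components)
```

Under this assumption, the leaf names stored under each sample (here loc/scale and logits) must be valid constructor arguments of the distribution class chosen for that sample, and a sample name with no parameters fails early with a KeyError.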

Keyword Arguments:
  • name_map (Dict[NestedKey, NestedKey], optional) – A mapping of where each sample should be written. If not provided, the key names from distribution_map will be used.

  • extra_kwargs (Dict[NestedKey, Dict], optional) – A dictionary of additional keyword arguments for constructing the distributions.

  • aggregate_probabilities (bool, optional) –

    If True, the log_prob and entropy methods will sum the log-probabilities and entropies of the individual distributions and return a single tensor. If False, individual log-probabilities will be stored in the input TensorDict (for log_prob) or returned as leaves of the output TensorDict (for entropy). This can be overridden at runtime by passing the aggregate_probabilities argument to log_prob and entropy. Defaults to False.

    Warning

    This argument will be deprecated in v0.9, at which point tensordict.nn.probabilistic.composite_lp_aggregate() will default to False.

  • log_prob_key (NestedKey, optional) –

    The key where the aggregated log probability will be stored. Defaults to ‘sample_log_prob’.

    Note

    If tensordict.nn.probabilistic.composite_lp_aggregate() returns False, the log-probabilities will be written under (“path”, “to”, “leaf”, “<sample_name>_log_prob”), where (“path”, “to”, “leaf”, “<sample_name>”) is the NestedKey corresponding to the leaf tensor being sampled. In that case, the log_prob_key argument is ignored.

  • entropy_key (NestedKey, optional) –

    The key where the entropy will be stored. Defaults to ‘entropy’.

    Note

    If tensordict.nn.probabilistic.composite_lp_aggregate() returns False, the entropies will be written under (“path”, “to”, “leaf”, “<sample_name>_entropy”), where (“path”, “to”, “leaf”, “<sample_name>”) is the NestedKey corresponding to the leaf tensor being sampled. In that case, the entropy_key argument is ignored.
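The two output layouts controlled by aggregate_probabilities, log_prob_key and entropy_key can be sketched without tensordict. This is a hypothetical, scalar-only illustration: a flat dict keyed by nested-key tuples stands in for the TensorDict, and the helpers normal_log_prob, normal_entropy, log_softmax, categorical_log_prob, categorical_entropy and result_entries are not part of any library. Because the components are independent, the aggregated value is the sum of the per-sample values; otherwise each value is stored next to its sample, with a _log_prob or _entropy suffix appended to the last element of the key.

```python
import math
from typing import Dict, Tuple

Key = Tuple[str, ...]  # nested keys as tuples, e.g. ("nested", "disc")

# Hypothetical scalar helpers standing in for torch.distributions methods.
def normal_log_prob(x: float, loc: float, scale: float) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def normal_entropy(scale: float) -> float:
    return 0.5 * math.log(2 * math.pi * math.e) + math.log(scale)

def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]

def categorical_log_prob(k: int, logits) -> float:
    return log_softmax(logits)[k]

def categorical_entropy(logits) -> float:
    return -sum(math.exp(lp) * lp for lp in log_softmax(logits))

def result_entries(values: Dict[Key, float], *, aggregate: bool,
                   aggregate_key: Key, suffix: str) -> Dict[Key, float]:
    """Hypothetical helper: lay out per-sample values as the docs describe."""
    if aggregate:
        # Independent components: the joint value is the sum over samples.
        return {aggregate_key: sum(values.values())}
    # ("path", "to", "leaf", "<name>") -> ("path", "to", "leaf", "<name>_<suffix>")
    return {(*key[:-1], f"{key[-1]}_{suffix}"): v for key, v in values.items()}

sample = {("cont",): 0.5, ("nested", "disc"): 2}
log_probs = {
    ("cont",): normal_log_prob(0.5, loc=0.0, scale=1.0),
    ("nested", "disc"): categorical_log_prob(2, [0.0, 0.0, 0.0]),
}
entropies = {
    ("cont",): normal_entropy(scale=1.0),
    ("nested", "disc"): categorical_entropy([0.0, 0.0, 0.0]),
}

# Non-aggregated log_prob: per-sample entries sit next to the samples
# (this sketch writes into a copy; the real class may modify its input).
per_leaf = dict(sample)
per_leaf.update(result_entries(log_probs, aggregate=False,
                               aggregate_key=("sample_log_prob",),
                               suffix="log_prob"))
# Aggregated log_prob: a single entry under log_prob_key.
aggregated = result_entries(log_probs, aggregate=True,
                            aggregate_key=("sample_log_prob",),
                            suffix="log_prob")
# Non-aggregated entropy: a separate output keyed like the samples.
entropy_out = result_entries(entropies, aggregate=False,
                             aggregate_key=("entropy",), suffix="entropy")
print(sorted(per_leaf))
print(aggregated)
print(sorted(entropy_out))
```

The non-aggregated key layout matches the cont_log_prob and nested.disc_log_prob entries shown in the example output.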

Note

The batch size of the input TensorDict containing the parameters (params) determines the batch shape of the distribution. For example, the “sample_log_prob” entry resulting from a call to log_prob will have the shape of the parameters plus any additional batch dimensions.
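The per-leaf shapes can be checked with plain tuple arithmetic. This sketch only encodes the standard torch.distributions shape convention (a sample has shape sample_shape + batch_shape + event_shape, and log_prob returns sample_shape + batch_shape); the helper leaf_shapes is hypothetical, and the sketch does not model how CompositeDistribution reduces or sums the leaf log-probabilities.

```python
from typing import Tuple

Shape = Tuple[int, ...]

def leaf_shapes(sample_shape: Shape, batch_shape: Shape,
                event_shape: Shape = ()) -> Tuple[Shape, Shape]:
    """Hypothetical helper encoding the torch.distributions convention:
    sample().shape   == sample_shape + batch_shape + event_shape
    log_prob().shape == sample_shape + batch_shape
    """
    return sample_shape + batch_shape + event_shape, sample_shape + batch_shape

sample_shape = (4,)       # as in dist.sample((4,))
params_batch_size = (3,)  # batch size of the params TensorDict

# "cont": Normal with loc/scale of shape (3, 4) -> batch_shape (3, 4), event_shape ()
cont_sample, cont_lp = leaf_shapes(sample_shape, (3, 4))
# ("nested", "disc"): Categorical with logits of shape (3, 10); the last dim
# indexes the 10 classes, so batch_shape is (3,) and event_shape is ()
disc_sample, disc_lp = leaf_shapes(sample_shape, (3,))

print(cont_sample, cont_lp)  # (4, 3, 4) (4, 3, 4)
print(disc_sample, disc_lp)  # (4, 3) (4, 3)

# Every leaf shares the leading dims sample_shape + params_batch_size.
prefix = sample_shape + params_batch_size
assert all(s[:len(prefix)] == prefix
           for s in (cont_sample, cont_lp, disc_sample, disc_lp))
```

These leaf shapes agree with the cont and nested.disc entries printed in the example output.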

See also

ProbabilisticTensorDictModule and ProbabilisticTensorDictSequential to learn how to use this class as part of a model.

See also

set_composite_lp_aggregate to control the aggregation of the log-probabilities.

Examples

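>>> import torch
>>> from torch import distributions as d
>>> from tensordict import TensorDict
>>> from tensordict.nn.distributions import CompositeDistribution
>>> from tensordict.nn.probabilistic import set_composite_lp_aggregate
>>> params = TensorDict({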
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> with set_composite_lp_aggregate(False):
...     sample = dist.log_prob(sample)
...     print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)
