tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: dict | None = None, extra_kwargs=None)

A composition of distributions.

Groups several distributions behind the TensorDict interface. All methods (log_prob, cdf, icdf, rsample, sample, etc.) return a tensordict; if the input was a tensordict, it may be modified in place.

Parameters:
  • params (TensorDictBase) – a nested key-tensor map where the root entries point to the sample names, and the leaves are the distribution parameters. Entry names must match those of distribution_map.

  • distribution_map (Dict[NestedKey, Type[torch.distributions.Distribution]]) – indicates the distribution types to be used. The names of the distributions will match the names of the samples in the tensordict.

Keyword Arguments:
  • name_map (Dict[NestedKey, NestedKey]) – a dictionary indicating where each sample should be written. If not provided, the key names from distribution_map will be used.

  • extra_kwargs (Dict[NestedKey, Dict]) – a possibly incomplete dictionary of extra keyword arguments for the distributions to be built.

Note

In this distribution class, the batch size of the input tensordict holding the parameters (params) is taken as the batch_shape of the distribution. For instance, the "sample_log_prob" entry resulting from a call to log_prob will have the shape of params, plus any supplementary batch dimensions.
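The shape rule in this note can be illustrated with plain torch.distributions, independently of CompositeDistribution. The sketch below assumes parameters with a leading batch dimension of 3, as in a params tensordict of batch size [3], and a supplementary sample shape of (4,); the variable names are illustrative only.

```python
import torch
from torch import distributions as d

# Parameters share a leading batch dimension of 3.
normal = d.Normal(torch.zeros(3, 4), torch.ones(3, 4))
categorical = d.Categorical(logits=torch.zeros(3, 10))

# Drawing with sample shape (4,) prepends a supplementary dimension.
cont = normal.sample((4,))
disc = categorical.sample((4,))

# Log-probabilities keep the supplementary dimension plus the batch shape.
cont_lp = normal.log_prob(cont)
disc_lp = categorical.log_prob(disc)

print(cont.shape, cont_lp.shape)  # torch.Size([4, 3, 4]) torch.Size([4, 3, 4])
print(disc.shape, disc_lp.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
```

In both cases the supplementary dimension (4) comes first and the parameter batch dimension (3) follows it.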

Examples

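>>> import torch
>>> from torch import distributions as d
>>> from tensordict import TensorDict
>>> from tensordict.nn.distributions import CompositeDistribution
>>> params = TensorDict({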
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)
