tensordict.nn.distributions.CompositeDistribution
- class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: dict | None = None, extra_kwargs=None)
A composition of distributions.
Groups distributions together with the TensorDict interface. All methods (log_prob, cdf, icdf, rsample, sample, etc.) return a tensordict, possibly modified in-place if the input was a tensordict.
- Parameters:
params (TensorDictBase) – a nested key-tensor map where the root entries point to the sample names, and the leaves are the distribution parameters. Entry names must match those of distribution_map.
distribution_map (Dict[NestedKey, Type[torch.distributions.Distribution]]) – indicates the distribution types to be used. The names of the distributions will match the names of the samples in the tensordict.
- Keyword Arguments:
name_map (Dict[NestedKey, NestedKey]) – a dictionary representing where each sample should be written. If not provided, the key names from distribution_map will be used.
extra_kwargs (Dict[NestedKey, Dict]) – a possibly incomplete dictionary of extra keyword arguments for the distributions to be built.
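The interplay between distribution_map, name_map and extra_kwargs can be sketched in plain Python. This is an illustration of one plausible reading of the parameter descriptions above, not the library's implementation: resolve_entries is a hypothetical helper, the per-key fallback for a partially filled name_map is an assumption, and strings stand in for the distribution types so the snippet has no dependencies.

```python
from typing import Any, Dict, Optional, Tuple, Union

# A nested key is either a string or a tuple of strings.
NestedKey = Union[str, Tuple[str, ...]]


def resolve_entries(
    distribution_map: Dict[NestedKey, Any],
    name_map: Optional[Dict[NestedKey, NestedKey]] = None,
    extra_kwargs: Optional[Dict[NestedKey, Dict[str, Any]]] = None,
) -> Dict[NestedKey, Tuple[NestedKey, Dict[str, Any]]]:
    """Hypothetical helper: map each distribution key to (output key, kwargs).

    Assumption: a key missing from name_map is written under its own name,
    and a key missing from extra_kwargs receives no extra keyword arguments.
    """
    name_map = name_map or {}
    extra_kwargs = extra_kwargs or {}
    return {
        key: (name_map.get(key, key), extra_kwargs.get(key, {}))
        for key in distribution_map
    }


resolved = resolve_entries(
    # Strings stand in for d.Normal / d.Categorical in this sketch.
    distribution_map={"cont": "Normal", ("nested", "disc"): "Categorical"},
    name_map={"cont": ("sample", "cont")},
    extra_kwargs={("nested", "disc"): {"validate_args": False}},
)
for key, (out_key, kwargs) in resolved.items():
    print(key, "->", out_key, kwargs)
```

Under these assumptions, "cont" would be written to ("sample", "cont") with no extra arguments, while ("nested", "disc") keeps its own name and receives validate_args=False.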
Note
In this distribution class, the batch-size of the input tensordict containing the params (params) is indicative of the batch_shape of the distribution. For instance, the "sample_log_prob" entry resulting from a call to log_prob will be of the shape of the params (+ any supplementary batch dimension).
Examples
>>> import torch
>>> from torch import distributions as d
>>> from tensordict import TensorDict
>>> from tensordict.nn.distributions import CompositeDistribution
>>> # Root entries name the samples; leaves hold the distribution parameters.
>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)
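The batch-shape note can be illustrated with plain torch.distributions, without going through CompositeDistribution. The sketch below assumes that an aggregated log-probability sums each component's log_prob over the dimensions that lie beyond the sample shape plus the params batch size; reduce_to_batch is a hypothetical helper written for this illustration, and the 0.1 offset on the scale is only there to keep it strictly positive.

```python
import torch
from torch import distributions as d

# Parameters share a leading dimension of size 3, mirroring a params
# tensordict created with batch_size=[3].
batch_size = torch.Size([3])
normal = d.Normal(loc=torch.randn(3, 4), scale=torch.rand(3, 4) + 0.1)
categorical = d.Categorical(logits=torch.randn(3, 10))

# One supplementary batch dimension of size 4, as in dist.sample((4,)).
sample_shape = torch.Size([4])
cont = normal.sample(sample_shape)        # shape [4, 3, 4]
disc = categorical.sample(sample_shape)   # shape [4, 3]

cont_lp = normal.log_prob(cont)           # shape [4, 3, 4]
disc_lp = categorical.log_prob(disc)      # shape [4, 3]

# Number of dims kept after aggregation: sample dims + params batch dims.
target_ndim = len(sample_shape) + len(batch_size)


def reduce_to_batch(lp: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sum out every dim beyond the first target_ndim."""
    extra = lp.ndim - target_ndim
    if extra > 0:
        return lp.sum(dim=tuple(range(-extra, 0)))
    return lp


# Assumption: the aggregate is the sum of the reduced per-sample terms.
total_lp = reduce_to_batch(cont_lp) + reduce_to_batch(disc_lp)
print(cont_lp.shape, disc_lp.shape, total_lp.shape)
```

Here the combined log-probability has shape [4, 3]: the supplementary batch dimension followed by the batch size of the params, which matches the shape the note describes for the "sample_log_prob" entry.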