tensordict.nn.distributions.CompositeDistribution
- class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: NestedKey = 'sample_log_prob', entropy_key: NestedKey = 'entropy')
A composition of distributions.
Groups distributions together with the TensorDict interface. Methods (log_prob_composite, entropy_composite, cdf, icdf, rsample, sample, etc.) return a tensordict, possibly modified in place if the input was a tensordict.
- Parameters:
params (TensorDictBase) – a nested key-tensor map whose entries are keyed by sample name and whose leaves are the distribution parameters. Entry names must match those of distribution_map.
distribution_map (Dict[NestedKey, Type[torch.distributions.Distribution]]) – indicates the distribution types to be used. The names of the distributions will match the names of the samples in the tensordict.
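The pairing between params and distribution_map can be pictured with plain Python dicts. The sketch below is not the library's implementation: it assumes ordinary dicts with tuple keys as a stand-in for a TensorDict and its nested keys, and build_distributions is a hypothetical helper. It only illustrates that each key of distribution_map selects the params entry of the same name, whose leaves become the constructor's keyword arguments.

```python
import torch
from torch import distributions as d

# Hypothetical stand-in for a TensorDict: nested keys are written as tuples,
# and each entry holds the keyword arguments of one distribution.
params = {
    "cont": {"loc": torch.zeros(3, 4), "scale": torch.ones(3, 4)},
    ("nested", "disc"): {"logits": torch.zeros(3, 10)},
}
distribution_map = {"cont": d.Normal, ("nested", "disc"): d.Categorical}


def build_distributions(params, distribution_map):
    """Hypothetical helper: instantiate one distribution per matching key."""
    if set(params) != set(distribution_map):
        raise KeyError("params and distribution_map must have the same keys")
    return {key: cls(**params[key]) for key, cls in distribution_map.items()}


dists = build_distributions(params, distribution_map)
print(dists["cont"].batch_shape)              # torch.Size([3, 4])
print(dists[("nested", "disc")].batch_shape)  # torch.Size([3])
```

Each leaf distribution keeps its own batch_shape here; the Categorical consumes the last logits dimension as its categories, so its batch shape is only [3].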
- Keyword Arguments:
name_map (Dict[NestedKey, NestedKey]) – a dictionary indicating where each sample should be written. If not provided, the key names from distribution_map will be used.
extra_kwargs (Dict[NestedKey, Dict]) – a possibly incomplete dictionary of extra keyword arguments for the distributions to be built.
aggregate_probabilities (bool) – if True, the log_prob() and entropy() methods will sum the log-probabilities and entropies of the individual distributions and return a single tensor. If False, the individual log-probabilities will be registered in the input tensordict (for log_prob()) or returned as leaves of the output tensordict (for entropy()). This parameter can be overridden at runtime by passing the aggregate_probabilities argument to log_prob and entropy. Defaults to False.
log_prob_key (NestedKey, optional) – key where to write the log_prob. Defaults to ‘sample_log_prob’.
entropy_key (NestedKey, optional) – key where to write the entropy. Defaults to ‘entropy’.
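The extra_kwargs and name_map keyword arguments can be sketched in the same torch-only way. In this hedged illustration, build and sample_to are hypothetical helpers, plain dicts with tuple keys stand in for TensorDict, and the fallback for keys missing from name_map is an assumption rather than documented behavior.

```python
import torch
from torch import distributions as d

params = {
    "cont": {"loc": torch.zeros(3), "scale": torch.ones(3)},
    ("nested", "disc"): {"logits": torch.zeros(3, 10)},
}
distribution_map = {"cont": d.Normal, ("nested", "disc"): d.Categorical}
# "Possibly incomplete": only "cont" receives extra constructor kwargs.
extra_kwargs = {"cont": {"validate_args": True}}
# Hypothetical routing: write the "cont" sample under ("action", "cont").
name_map = {"cont": ("action", "cont")}


def build(params, distribution_map, extra_kwargs=None):
    """Hypothetical helper: merge each entry's params with its optional extra kwargs."""
    extra_kwargs = extra_kwargs or {}
    return {
        key: cls(**params[key], **extra_kwargs.get(key, {}))
        for key, cls in distribution_map.items()
    }


def sample_to(dists, name_map=None, sample_shape=torch.Size()):
    """Hypothetical helper: draw one sample per distribution and store it under its output key."""
    name_map = name_map or {}
    # Assumption: keys absent from name_map keep their distribution_map name.
    return {name_map.get(key, key): dist.sample(sample_shape) for key, dist in dists.items()}


dists = build(params, distribution_map, extra_kwargs)
out = sample_to(dists, name_map, torch.Size([2]))
print(sorted(out, key=str))  # [('action', 'cont'), ('nested', 'disc')]
```

Keeping extra_kwargs separate from params lets non-tensor constructor options (here validate_args) travel alongside the tensor leaves without being stored in the parameter map.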
Note
In this distribution class, the batch size of the input tensordict containing the parameters (params) determines the batch_shape of the distribution. For instance, the "sample_log_prob" entry resulting from a call to log_prob will have the shape of params, plus any supplementary batch dimensions (such as those introduced by a sample_shape).
Examples
>>> import torch
>>> from torch import distributions as d
>>> from tensordict import TensorDict
>>> from tensordict.nn.distributions import CompositeDistribution
>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)
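To make the Note concrete, the following torch-only sketch uses the same parameter shapes as the example and computes both the per-leaf log-probabilities (the aggregate_probabilities=False layout) and a single aggregated tensor. The reduction rule, which sums each leaf over the dimensions past the sample and params batch dimensions and then adds the leaves together, is an assumption consistent with the Note, not a copy of the library's code; aggregate is a hypothetical helper.

```python
import torch
from torch import distributions as d

torch.manual_seed(0)
params_batch_shape = torch.Size([3])  # batch size of the params tensordict in the example
sample_shape = torch.Size([4])

dists = {
    "cont": d.Normal(torch.randn(3, 4), torch.rand(3, 4) + 0.1),  # +0.1 keeps scale > 0
    ("nested", "disc"): d.Categorical(logits=torch.randn(3, 10)),
}
sample = {key: dist.sample(sample_shape) for key, dist in dists.items()}

# aggregate_probabilities=False: one log-prob tensor per leaf, each with its own shape.
per_leaf = {key: dists[key].log_prob(sample[key]) for key in dists}


def aggregate(per_leaf, n_batch_dims):
    """Hypothetical helper: reduce each leaf to the batch dims, then add the leaves."""
    total = torch.zeros(())
    for lp in per_leaf.values():
        # Collapse all trailing (non-batch) dims into one and sum it away; a leaf
        # with no trailing dims gets a size-1 dim, so the sum leaves it unchanged.
        total = total + lp.reshape(*lp.shape[:n_batch_dims], -1).sum(-1)
    return total


n_batch_dims = len(sample_shape) + len(params_batch_shape)
agg = aggregate(per_leaf, n_batch_dims)
print(per_leaf["cont"].shape)  # torch.Size([4, 3, 4])
print(agg.shape)               # torch.Size([4, 3])
```

The aggregated tensor has the sample dimensions followed by the params batch size ([4, 3]), which is the shape the Note describes for "sample_log_prob". With the real class, the aggregation can also be switched per call by passing aggregate_probabilities to log_prob or entropy.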