ProbabilisticTensorDictModule
- class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)
A probabilistic TensorDict module.
ProbabilisticTensorDictModule is a non-parametric module that embeds a probability-distribution constructor. It reads the distribution parameters from an input TensorDict using the specified in_keys and outputs a sample (loosely speaking) of the distribution.
The output “sample” is produced according to a rule specified by the default_interaction_type argument and the interaction_type() global function.
ProbabilisticTensorDictModule can be used to construct the distribution (through the get_dist() method) and/or to sample from this distribution (through a regular __call__() to the module).
A ProbabilisticTensorDictModule instance has two main features:
It reads and writes from and to TensorDict objects;
It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed.
When the __call__() and forward() methods are called, a distribution is created and a value is computed (depending on the interaction_type value, the ‘dist.mean’, ‘dist.mode’ or ‘dist.median’ attributes may be used, as well as the ‘dist.rsample’ or ‘dist.sample’ methods). The sampling step is skipped if the supplied TensorDict already contains all the desired key-value pairs.
By default, the ProbabilisticTensorDictModule distribution class is a Delta distribution, making ProbabilisticTensorDictModule a simple wrapper around a deterministic mapping function.
- Parameters:
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – key(s) that will be read from the input TensorDict and used to build the distribution. Importantly, if in_keys is a NestedKey or a list of NestedKeys, the leaf (last element) of those keys must match the keyword arguments used by the distribution class of interest, e.g. "loc" and "scale" for the Normal distribution and similar. If in_keys is a dictionary, its keys are the keyword arguments of the distribution and its values are the keys in the tensordict that will be matched to the corresponding distribution keys.
out_keys (NestedKey | List[NestedKey] | None) – key(s) where the sampled values will be written. Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.
- Keyword Arguments:
default_interaction_type (InteractionType, optional) – keyword-only argument. Default method to be used to retrieve the output value. Should be one of InteractionType: MODE, MEDIAN, MEAN or RANDOM (in which case the value is sampled randomly from the distribution). Default is MODE.
Note
When a sample is drawn, the ProbabilisticTensorDictModule instance will first look for the interaction mode dictated by the interaction_type() global function. If this returns None (its default value), the default_interaction_type of the ProbabilisticTensorDictModule instance is used. Note that DataCollectorBase instances use set_interaction_type to set the interaction type to tensordict.nn.InteractionType.RANDOM by default.
Note
In some cases, the mode, median or mean value may not be readily available through the corresponding attribute. To address this, ProbabilisticTensorDictModule will first attempt to get the value through a call to get_mode(), get_median() or get_mean(), if such a method exists.
distribution_class (Type or Callable[[Any], Distribution], optional) –
keyword-only argument. A torch.distributions.Distribution class to be used for sampling. Default is Delta.
Note
If the distribution class is of type CompositeDistribution, the out_keys can be inferred directly from the "distribution_map" or "name_map" keyword arguments provided through this class’ distribution_kwargs keyword argument, making out_keys optional in such cases.
distribution_kwargs (dict, optional) –
keyword-only argument. Keyword-argument pairs to be passed to the distribution.
Note
If your kwargs contain tensors that you would like to transfer to device with the module, or tensors whose dtype should be modified when calling module.to(dtype), you can wrap the kwargs in a TensorDictParams to do this automatically.
return_log_prob (bool, optional) – keyword-only argument. If True, the log-probability of the distribution sample will be written in the tensordict with the key log_prob_key. Default is False.
log_prob_keys (List[NestedKey], optional) – keys where the log_prob will be written if return_log_prob=True. Defaults to ‘<sample_key_name>_log_prob’, where <sample_key_name> is each of the out_keys.
Note
This is only available when composite_lp_aggregate() is set to False.
log_prob_key (NestedKey, optional) – key where the log_prob will be written if return_log_prob=True. Defaults to ‘sample_log_prob’ when composite_lp_aggregate() is set to True, or ‘<sample_key_name>_log_prob’ otherwise.
Note
When there is more than one sample, this is only available when composite_lp_aggregate() is set to True.
cache_dist (bool, optional) – keyword-only argument. EXPERIMENTAL: if True, the parameters of the distribution (i.e. the output of the module) will be written to the tensordict along with the sample. Those parameters can be used to re-compute the original distribution later on (e.g. to compute the divergence between the distribution used to sample the action and the updated distribution in PPO). Default is False.
n_empirical_estimate (int, optional) – keyword-only argument. Number of samples used to compute the empirical mean when it is not directly available. Defaults to 1000.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from torch.distributions import Normal, Independent
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     net, in_keys=["input", "hidden"], out_keys=["params"]
... )
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> def IndepNormal(**kwargs):
...     return Independent(Normal(**kwargs), 1)
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=IndepNormal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(
...     module, normal_params, prob_module
... )
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> with params.to_module(td_module):
...     dist = td_module.get_dist(td)
>>> print(dist)
Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1)
>>> # we can also apply the module to the TensorDict with vmap
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
- build_dist_from_params(tensordict: TensorDictBase) → Distribution
Creates a torch.distributions.Distribution instance with the parameters provided in the input tensordict.
- Parameters:
tensordict (TensorDictBase) – The input tensordict containing the distribution parameters.
- Returns:
A torch.distributions.Distribution instance created from the input tensordict.
- Raises:
TypeError – If the input tensordict does not match the distribution keywords.
- property dist_params_keys: List[NestedKey]
Returns all the keys pointing at the distribution params.
- property dist_sample_keys: List[NestedKey]
Returns all the keys pointing at the distribution samples.
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) → TensorDictBase
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_dist(tensordict: TensorDictBase) → Distribution
Creates a torch.distributions.Distribution instance with the parameters provided in the input tensordict.
- Parameters:
tensordict (TensorDictBase) – The input tensordict containing the distribution parameters.
- Returns:
A torch.distributions.Distribution instance created from the input tensordict.
- Raises:
TypeError – If the input tensordict does not match the distribution keywords.
- log_prob(tensordict, *, dist: Optional[Distribution] = None)
Computes the log-probability of the distribution sample.
- Parameters:
tensordict (TensorDictBase) – The input tensordict containing the distribution parameters.
dist (torch.distributions.Distribution, optional) – The distribution instance. Defaults to None. If None, the distribution will be computed using the get_dist method.
- Returns:
A tensor representing the log-probability of the distribution sample.