SafeProbabilisticModule
- class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)
tensordict.nn.ProbabilisticTensorDictModule subclass that accepts a TensorSpec as argument to control the output domain.
SafeProbabilisticModule is a non-parametric module representing a probability distribution. It reads the distribution parameters from an input TensorDict using the specified in_keys. The output is sampled according to a rule specified by the default_interaction_type argument and the interaction_type() global function.
SafeProbabilisticModule can be used to construct the distribution (through the get_dist() method) and/or to sample from this distribution (through a regular __call__() to the module).
A SafeProbabilisticModule instance has two main features:
- It reads and writes TensorDict objects.
- It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed.
When the __call__/forward method is called, a distribution is created and a value is computed (using the ‘mean’, ‘mode’ or ‘median’ attribute, or the ‘rsample’ or ‘sample’ method). The sampling step is skipped if the supplied TensorDict already contains all of the desired key-value pairs.
By default, the distribution class of SafeProbabilisticModule is a Delta distribution, making SafeProbabilisticModule a simple wrapper around a deterministic mapping function.
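The forward logic described above (build a distribution from the input entries, pick a value according to an interaction type, skip when the output is already present) can be illustrated with a small pure-Python sketch. It does not call torchrl: `ToyProbabilisticModule` is a hypothetical stand-in that reads parameters from a plain dict and uses `statistics.NormalDist` in place of a torch distribution.

```python
from statistics import NormalDist


class ToyProbabilisticModule:
    """Hypothetical, simplified analogue of SafeProbabilisticModule (not the torchrl API)."""

    def __init__(self, in_keys, out_key, default_interaction_type="mode", seed=0):
        self.in_keys = in_keys  # names of the distribution kwargs, read from the input dict
        self.out_key = out_key
        self.default_interaction_type = default_interaction_type
        self.seed = seed

    def get_dist(self, td):
        # Build the distribution from the parameters stored under in_keys.
        params = {key: td[key] for key in self.in_keys}
        return NormalDist(mu=params["loc"], sigma=params["scale"])

    def __call__(self, td):
        if self.out_key in td:
            # The output is already present: skip the sampling step entirely.
            return td
        dist = self.get_dist(td)
        kind = self.default_interaction_type
        if kind == "random":
            value = dist.samples(1, seed=self.seed)[0]
        else:
            # 'mode', 'median' and 'mean' are read-only attributes of NormalDist.
            value = getattr(dist, kind)
        td[self.out_key] = value  # written into the input dict, like a tensordict
        return td


module = ToyProbabilisticModule(in_keys=["loc", "scale"], out_key="action")
print(module({"loc": 1.5, "scale": 0.2})["action"])  # 1.5 (the mode of a normal is loc)
print(module({"loc": 1.5, "scale": 0.2, "action": -3.0})["action"])  # -3.0 (kept as-is)
```

Swapping the distribution for one that ignores `scale` and always returns `loc` would mimic the default Delta behaviour, where the module reduces to a deterministic mapping.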
- Parameters:
in_keys (NestedKey or list of NestedKey or dict) – key(s) that will be read from the input TensorDict and used to build the distribution. Importantly, if it is a NestedKey or a list of NestedKey, the leaf (last element) of those keys must match the keywords used by the distribution class of interest, e.g. "loc" and "scale" for the Normal distribution and similar. If in_keys is a dictionary, its keys are the keyword names of the distribution and its values are the keys in the tensordict that will be matched to the corresponding distribution keywords.
out_keys (NestedKey or list of NestedKey) – keys where the sampled values will be written. Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.
spec (TensorSpec) – specs of the first output tensor. Used when calling td_module.random() to generate random values in the target space.
safe (bool, optional) – if True, the value of the sample is checked against the spec passed to the module. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. As for the spec argument, this check only applies to the distribution sample, not to the other tensors returned by the input module. If the sample is out of bounds, it is projected back onto the desired space using the TensorSpec.project method. Default is False.
default_interaction_type (str, optional) – default method to be used to retrieve the output value. Should be one of: ‘mode’, ‘median’, ‘mean’ or ‘random’ (in which case the value is sampled randomly from the distribution). Default is ‘mode’.
Note: When a sample is drawn, the ProbabilisticTDModule instance will first look for the interaction mode dictated by the interaction_type() global function. If this returns None (its default value), then the default_interaction_type of the ProbabilisticTDModule instance will be used. Note that DataCollector instances use tensordict.nn.set_interaction_type() to set the interaction type to tensordict.nn.InteractionType.RANDOM by default.
distribution_class (Type, optional) – a torch.distributions.Distribution class to be used for sampling. Default is Delta.
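The precedence rule for the interaction type and the effect of safe=True can be summarised in a few lines. Both helpers are hypothetical and stand outside torchrl: the global setting is passed in explicitly rather than read from interaction_type(), and clamping is assumed as the projection for a box-shaped (bounded) space.

```python
def resolve_interaction_type(global_type, default_interaction_type="mode"):
    """Hypothetical helper: the global interaction type wins unless it is None."""
    return default_interaction_type if global_type is None else global_type


def safe_sample(sample, low, high, safe=True):
    """Hypothetical helper: project an out-of-bounds sample back into [low, high].

    Clamping is assumed here as the projection for a bounded space; only the
    distribution sample goes through this check.
    """
    if not safe:
        return sample
    return min(max(sample, low), high)


print(resolve_interaction_type(None))      # mode
print(resolve_interaction_type("random"))  # random (e.g. inside a DataCollector)
print(safe_sample(1.7, -1.0, 1.0))         # 1.0
print(safe_sample(1.7, -1.0, 1.0, safe=False))  # 1.7
```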
distribution_kwargs (dict, optional) – kwargs to be passed to the distribution.
return_log_prob (bool, optional) – if True, the log-probability of the distribution sample will be written in the tensordict under log_prob_key (‘sample_log_prob’ by default). Default is False.
log_prob_key (NestedKey, optional) – key where to write the log_prob if return_log_prob = True. Defaults to ‘sample_log_prob’.
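What return_log_prob=True adds to the output can be sketched without torch. The function below is hypothetical: it draws one sample from a normal distribution built from "loc" and "scale", writes it under the output key, and stores its log-density under a configurable log-probability key.

```python
import math
from statistics import NormalDist


def write_sample_with_log_prob(td, out_key="action", log_prob_key="sample_log_prob", seed=0):
    """Hypothetical sketch: sample from N(loc, scale) and store its log-probability."""
    dist = NormalDist(mu=td["loc"], sigma=td["scale"])
    sample = dist.samples(1, seed=seed)[0]
    td[out_key] = sample
    # statistics.NormalDist has no logpdf method, so take the log of the density.
    td[log_prob_key] = math.log(dist.pdf(sample))
    return td


td = write_sample_with_log_prob({"loc": 0.0, "scale": 1.0}, log_prob_key="my_log_prob")
print(sorted(td))  # ['action', 'loc', 'my_log_prob', 'scale']
```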
cache_dist (bool, optional) – EXPERIMENTAL: if True, the parameters of the distribution (i.e. the output of the module) will be written to the tensordict along with the sample. Those parameters can be used to re-compute the original distribution later on (e.g. to compute the divergence between the distribution used to sample the action and the updated distribution in PPO). Default is False.
n_empirical_estimate (int, optional) – number of samples used to compute the empirical mean when the distribution does not provide one. Default is 1000.
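The empirical-mean fallback controlled by n_empirical_estimate amounts to averaging a fixed number of draws. `empirical_mean` is a hypothetical helper illustrating the idea; it is not the torchrl implementation.

```python
import random
import statistics


def empirical_mean(sample_fn, n_empirical_estimate=1000, seed=0):
    """Hypothetical sketch: average `n_empirical_estimate` draws from `sample_fn`."""
    rng = random.Random(seed)
    return statistics.fmean(sample_fn(rng) for _ in range(n_empirical_estimate))


# A distribution known only through its sampler: Exponential(rate=2), true mean 0.5.
estimate = empirical_mean(lambda rng: rng.expovariate(2.0))
print(estimate)  # an estimate of the true mean 0.5
```

Larger values of n_empirical_estimate reduce the variance of the estimate (roughly as one over the square root of the sample count) at the cost of more sampling work.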
- random(tensordict: TensorDictBase) → TensorDictBase
Samples a random element in the target space, irrespective of any input.
If multiple output keys are present, only the first will be written in the input tensordict.
- Parameters:
tensordict (TensorDictBase) – tensordict where the output value should be written.
- Returns:
the original tensordict with a new/updated value for the output key.
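The behaviour of random() (ignore the inputs, draw a value in the target space, write only the first output key) can be sketched as follows. `random_fill` is hypothetical, and a uniform draw in [low, high] stands in for sampling from a bounded spec.

```python
import random


def random_fill(td, out_keys, low, high, seed=None):
    """Hypothetical sketch of `random()`: ignore the inputs, draw a value uniformly
    in [low, high] (standing in for a bounded spec), and write it under the FIRST
    output key only."""
    rng = random.Random(seed)
    td[out_keys[0]] = rng.uniform(low, high)
    return td


td = random_fill({"observation": 0.42}, ["action", "sample_log_prob"], -1.0, 1.0, seed=0)
print(sorted(td))  # ['action', 'observation']
```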
- random_sample(tensordict: TensorDictBase) → TensorDictBase
See SafeModule.random(...).