
ProbabilisticActor

class torchrl.modules.tensordict_module.ProbabilisticActor(*args, **kwargs)

General class for probabilistic actors in RL.

The Actor class comes with a default value for out_keys (["action"]). If a spec is provided but is not a CompositeSpec object, it is automatically wrapped as spec = CompositeSpec(action=spec).
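The two defaulting rules above can be sketched with plain Python stand-ins. BoundedSpecSketch, CompositeSpecSketch and resolve_actor_defaults are hypothetical names used only for illustration; they are not torchrl classes, and the sketch assumes the wrapping key is always "action", as the sentence above states.

```python
from dataclasses import dataclass, field


@dataclass
class BoundedSpecSketch:
    """Hypothetical stand-in for a leaf spec such as BoundedTensorSpec."""
    low: float
    high: float


@dataclass
class CompositeSpecSketch:
    """Hypothetical stand-in for CompositeSpec: maps an output key to a leaf spec."""
    specs: dict = field(default_factory=dict)


def resolve_actor_defaults(spec=None, out_keys=None):
    """Apply the two defaulting rules described above (sketch only)."""
    # Rule 1: out_keys falls back to ["action"].
    if out_keys is None:
        out_keys = ["action"]
    # Rule 2: a spec that is not already composite is wrapped under "action".
    if spec is not None and not isinstance(spec, CompositeSpecSketch):
        spec = CompositeSpecSketch({"action": spec})
    return spec, out_keys


spec, out_keys = resolve_actor_defaults(BoundedSpecSketch(low=-1.0, high=1.0))
print(out_keys)                   # ['action']
print(spec.specs["action"].high)  # 1.0
```

A spec that is already composite is passed through unchanged, so only the leaf-spec case triggers the wrapping.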

Parameters:
  • module (nn.Module) – a torch.nn.Module used to map the input to the output parameter space.

  • in_keys (str or iterable of str or dict) – key(s) that will be read from the input TensorDict and used to build the distribution. Importantly, if in_keys is a string or an iterable of strings, those keys must match the keywords used by the distribution class of interest, e.g. "loc" and "scale" for the Normal distribution. If in_keys is a dictionary, its keys are the keywords of the distribution and its values are the tensordict keys that are matched to the corresponding distribution keywords.

  • out_keys (str or iterable of str) – keys where the sampled values will be written. Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.

  • spec (TensorSpec, optional) – keyword-only argument containing the specs of the output tensor. If the module outputs multiple output tensors, spec characterizes the space of the first output tensor.

  • safe (bool) – keyword-only argument. If True, the value of the output is checked against the provided spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If the value is out of bounds, it is projected back onto the desired space using the TensorSpec.project method. Default is False.

  • default_interaction_type=InteractionType.RANDOM (str, optional) – keyword-only argument. Default method to be used to retrieve the output value. Should be one of: 'mode', 'median', 'mean' or 'random' (in which case the value is sampled randomly from the distribution). Default is 'mode'. Note: when a sample is drawn, the ProbabilisticTDModule instance first looks up the interaction type returned by the global interaction_type() function. If this returns None (its default value), the default_interaction_type of the ProbabilisticTDModule instance is used. DataCollector instances use set_interaction_type to set the interaction type to tensordict.nn.InteractionType.RANDOM by default.

  • distribution_class (Type, optional) – keyword-only argument. A torch.distributions.Distribution class to be used for sampling. Default is tensordict.nn.distributions.Delta.

  • distribution_kwargs (dict, optional) – keyword-only argument. Keyword-argument pairs to be passed to the distribution.

  • return_log_prob (bool, optional) – keyword-only argument. If True, the log-probability of the distribution sample is written to the tensordict under the key "sample_log_prob". Default is False.

  • cache_dist (bool, optional) – keyword-only argument. EXPERIMENTAL: if True, the parameters of the distribution (i.e. the output of the module) will be written to the tensordict along with the sample. Those parameters can be used to re-compute the original distribution later on (e.g. to compute the divergence between the distribution used to sample the action and the updated distribution in PPO). Default is False.

  • n_empirical_estimate (int, optional) – keyword-only argument. Number of samples used to compute an empirical estimate of the mean when the distribution does not provide it. Defaults to 1000.
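The following is a minimal sketch of how in_keys, out_keys, the interaction type, safe and return_log_prob fit together in one forward step. It assumes plain Python dicts in place of TensorDict and a hypothetical scalar Gaussian (ScalarNormal) in place of distribution_class. The helper actor_step is illustrative only: its global_interaction_type argument stands in for the global interaction_type() setting, and its low/high arguments stand in for a bounded spec. It is not torchrl's implementation.

```python
import math
import random

# Hypothetical sketch: dicts stand in for TensorDict, ScalarNormal stands in
# for distribution_class. Not torchrl's actual implementation.


class ScalarNormal:
    """Minimal Gaussian whose constructor keywords are "loc" and "scale"."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def point_estimate(self, kind):
        # For a Gaussian, mode, median and mean all coincide with loc.
        if kind in ("mode", "median", "mean"):
            return self.loc
        raise ValueError(f"unknown interaction type: {kind!r}")

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


def actor_step(td, in_keys, out_key="action", distribution_class=ScalarNormal,
               default_interaction_type="mode", global_interaction_type=None,
               low=None, high=None, safe=False, return_log_prob=False):
    # A list means tensordict keys equal the distribution keywords;
    # a dict maps distribution keyword -> tensordict key.
    if not isinstance(in_keys, dict):
        in_keys = {key: key for key in in_keys}
    dist = distribution_class(**{d_key: td[td_key] for d_key, td_key in in_keys.items()})

    # The global setting (a plain argument here) wins unless it is None.
    interaction = (global_interaction_type if global_interaction_type is not None
                   else default_interaction_type)

    if out_key in td:
        value = td[out_key]  # key already present: sampling is skipped
    else:
        if interaction == "random":
            value = dist.sample()
        else:
            value = dist.point_estimate(interaction)
        if safe and low is not None and high is not None:
            # Clamping projects onto a bounded box (stand-in for TensorSpec.project).
            value = min(max(value, low), high)
        td[out_key] = value

    if return_log_prob:
        td["sample_log_prob"] = dist.log_prob(value)
    return td


# Dict-style in_keys: distribution keyword -> tensordict key.
td = actor_step({"mu": 3.0, "sigma": 1.0}, in_keys={"loc": "mu", "scale": "sigma"},
                low=-1.0, high=1.0, safe=True, return_log_prob=True)
print(td["action"])  # 1.0 (the mode, 3.0, projected back into [-1, 1])
```

Keeping the interaction type as an explicit argument makes the override order visible: the global value is consulted first, and the per-module default only applies when that value is None.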

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, make_functional
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> action_spec = BoundedTensorSpec(shape=torch.Size([4]),
...    minimum=-1, maximum=1)
>>> module = NormalParamWrapper(torch.nn.Linear(4, 8))
>>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
>>> td_module = ProbabilisticActor(
...    module=tensordict_module,
...    spec=action_spec,
...    in_keys=["loc", "scale"],
...    distribution_class=TanhNormal,
...    )
>>> params = make_functional(td_module)
>>> td = td_module(td, params=params)
>>> td
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
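The example above uses TanhNormal. The mean of a tanh-squashed Gaussian, E[tanh(X)] with X ~ Normal(loc, scale), has no simple closed form, which is presumably the kind of case n_empirical_estimate is meant for. The sketch below shows what such an empirical estimate looks like; tanh_normal_empirical_mean is a hypothetical helper that uses Python's random module rather than torch, and it is not how torchrl computes the estimate.

```python
import math
import random


def tanh_normal_empirical_mean(loc, scale, n_empirical_estimate=1000, seed=None):
    """Monte Carlo estimate of E[tanh(X)] for X ~ Normal(loc, scale).

    Hypothetical helper: averages n_empirical_estimate squashed Gaussian draws.
    """
    rng = random.Random(seed)
    draws = (math.tanh(rng.gauss(loc, scale)) for _ in range(n_empirical_estimate))
    return sum(draws) / n_empirical_estimate


# tanh keeps every draw, and hence the estimate, inside (-1, 1), matching the
# [-1, 1] bounds of action_spec in the example above.
estimate = tanh_normal_empirical_mean(loc=0.5, scale=1.0, seed=0)
print(-1.0 < estimate < 1.0)  # True
```

The estimate generally differs from tanh(loc), because tanh is nonlinear; a larger n_empirical_estimate reduces the Monte Carlo error at the cost of more samples.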
