
torchrl.modules package

TensorDict modules: Actors, exploration, value models and generative models

TorchRL offers a series of module wrappers aimed at making it easy to build RL models from the ground up. These wrappers are exclusively based on tensordict.nn.TensorDictModule and tensordict.nn.TensorDictSequential. They can loosely be split into three categories: policies (actors), including exploration strategies; value models; and simulation models (in model-based contexts).

The main features are:

  • Integration of the specs in your model to ensure that the model output matches what your environment expects as input;

  • Probabilistic modules that can automatically sample from a chosen distribution and/or return the distribution of interest;

  • Custom containers for Q-Value learning, model-based agents and others.

TensorDictModules and SafeModules

TorchRL's SafeModule allows you to check that your model output matches what the environment expects. Use it whenever your model is to be reused across multiple environments, for instance, or when you want to make sure that the outputs (e.g. the action) always satisfy the bounds imposed by the environment. Here is an example of how to use this feature with the Actor class:

>>> from torch import nn
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import Actor
>>> env = GymEnv("Pendulum-v1")
>>> action_spec = env.action_spec
>>> model = nn.LazyLinear(action_spec.shape[-1])
>>> policy = Actor(model, in_keys=["observation"], spec=action_spec, safe=True)

The safe flag ensures that the output is always within the bounds of the action_spec domain: if the network output violates these bounds, it is projected (in an L1 manner) onto the desired domain.

Actor(*args, **kwargs)

General class for deterministic actors in RL.

MultiStepActorWrapper(*args, **kwargs)

A wrapper around a multi-action actor.

SafeModule(*args, **kwargs)

tensordict.nn.TensorDictModule subclass that accepts a TensorSpec as argument to control the output domain.

SafeSequential(*args, **kwargs)

A safe sequence of TensorDictModules.

TanhModule(*args, **kwargs)

A Tanh module for deterministic policies with bounded action space.

Exploration wrappers and modules

To explore the environment efficiently, TorchRL proposes a series of modules that override the action sampled by the policy with a noisier version. Their behavior is controlled by exploration_type(): if the exploration type is set to ExplorationType.RANDOM, exploration is active. In all other cases, the action written in the tensordict is simply the network output.

Note

Unlike other exploration modules, ConsistentDropoutModule uses the train/eval mode to comply with the regular Dropout API in PyTorch. The set_exploration_type() context manager will have no effect on this module.

AdditiveGaussianModule(*args, **kwargs)

Additive Gaussian PO module.

AdditiveGaussianWrapper(*args, **kwargs)

Additive Gaussian PO wrapper.

ConsistentDropoutModule(*args, **kwargs)

A TensorDictModule wrapper for ConsistentDropout.

EGreedyModule(*args, **kwargs)

Epsilon-Greedy exploration module.

EGreedyWrapper(*args, **kwargs)

[Deprecated] Epsilon-Greedy PO wrapper.

OrnsteinUhlenbeckProcessModule(*args, **kwargs)

Ornstein-Uhlenbeck exploration policy module.

OrnsteinUhlenbeckProcessWrapper(*args, **kwargs)

Ornstein-Uhlenbeck exploration policy wrapper.

Probabilistic actors

Some algorithms such as PPO require a probabilistic policy to be implemented. In TorchRL, these policies take the form of a model, followed by a distribution constructor.

Note

The choice between a probabilistic and a regular actor class depends on the algorithm being implemented. On-policy algorithms usually require a probabilistic actor, while off-policy algorithms usually use a deterministic actor with an extra exploration strategy. There are, however, many exceptions to this rule.

The model reads an input (typically some observation from the environment) and outputs the parameters of a distribution, while the distribution constructor reads these parameters and gets a random sample from the distribution and/or provides a torch.distributions.Distribution object.

>>> from tensordict.nn import NormalParamExtractor, TensorDictSequential, TensorDictModule
>>> from torchrl.modules import SafeProbabilisticModule
>>> from torchrl.envs import GymEnv
>>> from torch.distributions import Normal
>>> from torch import nn
>>>
>>> env = GymEnv("Pendulum-v1")
>>> action_spec = env.action_spec
>>> model = nn.Sequential(nn.LazyLinear(action_spec.shape[-1] * 2), NormalParamExtractor())
>>> # build the first module, which maps the observation onto the mean (loc) and standard deviation (scale) of the normal distribution
>>> model = TensorDictModule(model, in_keys=["observation"], out_keys=["loc", "scale"])
>>> # build the distribution constructor
>>> prob_module = SafeProbabilisticModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=Normal,
...     return_log_prob=True,
...     spec=action_spec,
... )
>>> policy = TensorDictSequential(model, prob_module)
>>> # execute a rollout
>>> env.rollout(3, policy)

To facilitate the construction of probabilistic policies, we provide a dedicated ProbabilisticActor:

>>> from torchrl.modules import ProbabilisticActor
>>> policy = ProbabilisticActor(
...     model,
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=Normal,
...     return_log_prob=True,
...     spec=action_spec,
... )

This removes the need to build the distribution constructor separately and chain it with the model in a sequence.

Outputs of this policy will contain "loc" and "scale" entries, an "action" sampled from the normal distribution, and the log-probability of this action.

ProbabilisticActor(*args, **kwargs)

General class for probabilistic actors in RL.

SafeProbabilisticModule(*args, **kwargs)

tensordict.nn.ProbabilisticTensorDictModule subclass that accepts a TensorSpec as argument to control the output domain.

SafeProbabilisticTensorDictSequential(*args, ...)

tensordict.nn.ProbabilisticTensorDictSequential subclass that accepts a TensorSpec as argument to control the output domain.

Q-Value actors

Q-Value actors are a type of policy that selects actions based on the maximum value (or “quality”) of a state-action pair. This value can be represented as a table or a function. For discrete action spaces with continuous states, it’s common to use a non-linear model like a neural network to represent this function.

QValueActor

The QValueActor class takes in a module and an action specification, and outputs the selected action and its corresponding value.

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> # Create a tensor dict with an observation
>>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
>>> # Define the action space
>>> action_spec = OneHot(4)
>>> # Create a linear module to output action values
>>> module = nn.Linear(3, 4)
>>> # Create a QValueActor instance
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> # Run the actor on the tensor dict
>>> qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

This will output a tensor dict with the selected action and its corresponding value.
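The greedy selection behind these entries can be illustrated in plain PyTorch. This is a sketch of the idea for a one-hot action space, not QValueActor's actual implementation. The helper greedy_q_selection is hypothetical. It returns a one-hot "action" and a "chosen_action_value" with the same layout as the printed tensordict.

```python
import torch
import torch.nn.functional as F


def greedy_q_selection(action_value: torch.Tensor):
    """Pick the highest-valued action for each batch element (illustrative sketch)."""
    # Index of the best action along the last (action) dimension
    idx = action_value.argmax(dim=-1)
    # One-hot encoding, matching the layout of a OneHot action spec
    action = F.one_hot(idx, num_classes=action_value.shape[-1])
    # Value of the selected action, keeping a trailing singleton dimension
    chosen_action_value = action_value.gather(-1, idx.unsqueeze(-1))
    return action, chosen_action_value


q = torch.tensor([[0.1, 0.5, -0.2, 0.3],
                  [1.0, 0.0, 2.0, -1.0]])
action, chosen_action_value = greedy_q_selection(q)
```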

Distributional Q-Learning

Distributional Q-learning is a variant of Q-learning that represents the value function as a probability distribution over possible values, rather than a single scalar value. This allows the agent to learn about the uncertainty in the environment and make more informed decisions. In TorchRL, distributional Q-learning is implemented using the DistributionalQValueActor class. This class takes in a module, an action specification, and a support vector, and outputs the selected action and its corresponding value distribution.

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> # Create a tensor dict with an observation
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # Define the action space
>>> action_spec = OneHot(4)
>>> # Define the number of bins for the value distribution
>>> nbins = 3
>>> # Create an MLP module to output logits for the value distribution
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> # Create a DistributionalQValueActor instance
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> # Run the actor on the tensor dict
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

This will output a tensor dict with the selected action and its corresponding value distribution.
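To make the selection rule concrete, the sketch below computes each action's expected value from its distribution over the support and then acts greedily. It assumes, as the printed [5, 3, 4] shape suggests, that "action_value" holds log-probabilities over the bins on dimension -2 and over actions on dimension -1. The helper distributional_greedy is hypothetical and is not TorchRL's implementation.

```python
import torch
import torch.nn.functional as F


def distributional_greedy(log_probs: torch.Tensor, support: torch.Tensor):
    """Greedy action from per-action value distributions (illustrative sketch).

    log_probs: [..., nbins, n_actions], log-probabilities normalized over the bins.
    support:   [nbins], the value associated with each bin.
    """
    probs = log_probs.exp()
    # Expected value of each action: sum over bins of p(bin) * support(bin)
    q = (probs * support.unsqueeze(-1)).sum(dim=-2)
    action = F.one_hot(q.argmax(dim=-1), num_classes=q.shape[-1])
    return action, q


support = torch.tensor([0.0, 1.0, 2.0])
# One state, 3 bins, 2 actions; each column sums to 1
probs = torch.tensor([[[0.5, 0.1],
                       [0.3, 0.2],
                       [0.2, 0.7]]])
action, q = distributional_greedy(probs.log(), support)
```

Here the first action has expected value 0.3 + 0.4 = 0.7 and the second 0.2 + 1.4 = 1.6, so the second action is selected.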

QValueActor(*args, **kwargs)

A Q-Value actor class.

QValueModule(*args, **kwargs)

Q-Value TensorDictModule for Q-value policies.

DistributionalQValueActor(*args, **kwargs)

A Distributional DQN actor class.

DistributionalQValueModule(*args, **kwargs)

Distributional Q-Value hook for Q-value policies.

Value operators and joined models

TorchRL provides a series of value operators that wrap value networks to smooth their interface with the rest of the library. The basic building block is torchrl.modules.tensordict_module.ValueOperator: given an input state (and possibly an action), it automatically writes a "state_value" (or "state_action_value") entry in the tensordict, depending on its inputs. As such, this class covers both value and quality networks.

Three classes are also proposed to group a policy and a value network together:

  • The ActorCriticOperator is a joined actor-quality network with shared parameters: it reads an observation, passes it through a common backbone, writes a hidden state, and feeds this hidden state to the policy; it then takes the hidden state and the action and provides the quality of the state-action pair.

  • The ActorValueOperator is a joined actor-value network with shared parameters: it reads an observation, passes it through a common backbone, writes a hidden state, and feeds this hidden state to the policy and value modules to output an action and a state value.

  • The ActorCriticWrapper is a joined actor and value network without shared parameters. It is mainly intended as a replacement for ActorValueOperator when a script needs to account for both options.

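>>> from torchrl.modules import ActorCriticWrapper, ActorValueOperator
>>> # make_actor, make_value and make_common are user-defined factories
>>> actor = make_actor()
>>> value = make_value()
>>> if shared_params:
...     # shared backbone: ActorValueOperator(common, policy, value)
...     common = make_common()
...     model = ActorValueOperator(common, actor, value)
... else:
...     # no shared backbone: ActorCriticWrapper(policy, value)
...     model = ActorCriticWrapper(actor, value)
>>> policy = model.get_policy_operator()  # will work in both cases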

ActorCriticOperator(*args, **kwargs)

Actor-critic operator.

ActorCriticWrapper(*args, **kwargs)

Actor-value operator without common module.

ActorValueOperator(*args, **kwargs)

Actor-value operator.

ValueOperator(*args, **kwargs)

General class for value functions in RL.

DecisionTransformerInferenceWrapper(*args, ...)

Inference Action Wrapper for the Decision Transformer.

Domain-specific TensorDict modules

These modules include dedicated solutions for MBRL or RLHF pipelines.

LMHeadActorValueOperator(*args, **kwargs)

Builds an Actor-Value operator from a Hugging Face-like *LMHeadModel.

WorldModelWrapper(*args, **kwargs)

World model wrapper.

Hooks

The Q-value hooks are used by the QValueActor and DistributionalQValueActor modules. These actors should generally be preferred over the raw hooks, as they are easier to create and use.

QValueHook(action_space[, var_nums, ...])

Q-Value hook for Q-value policies.

DistributionalQValueHook(action_space, support)

Distributional Q-Value hook for Q-value policies.

Models

TorchRL provides a series of useful “regular” (i.e., non-tensordict) nn.Module classes for RL usage.

Regular modules

BatchRenorm1d(num_features, *[, momentum, ...])

BatchRenorm Module (https://arxiv.org/abs/1702.03275).

ConsistentDropout([p])

Implements a Dropout variant with consistent dropout.

Conv3dNet(in_features, depth, num_cells, ...)

A 3D-convolutional neural network.

ConvNet(in_features, depth, num_cells, ...)

A convolutional neural network.

MLP(in_features, out_features, ...)

A multi-layer perceptron.

Squeeze2dLayer()

Squeezing layer for convolutional neural networks.

SqueezeLayer([dims])

Squeezing layer.

Algorithm-specific modules

These networks implement sub-networks that have been shown to be useful for specific algorithms, such as DQN, DDPG or Dreamer.

DTActor(state_dim, action_dim[, ...])

Decision Transformer Actor class.

DdpgCnnActor(action_dim[, conv_net_kwargs, ...])

DDPG Convolutional Actor class.

DdpgCnnQNet([conv_net_kwargs, ...])

DDPG Convolutional Q-value class.

DdpgMlpActor(action_dim[, mlp_net_kwargs, ...])

DDPG Actor class.

DdpgMlpQNet([mlp_net_kwargs_net1, ...])

DDPG Q-value MLP class.

DecisionTransformer(state_dim, action_dim[, ...])

Online Decision Transformer.

DistributionalDQNnet(*args, **kwargs)

Distributional Deep Q-Network softmax layer.

DreamerActor(out_features[, depth, ...])

Dreamer actor network.

DuelingCnnDQNet(out_features[, ...])

Dueling CNN Q-network.

GRUCell(input_size, hidden_size[, bias, ...])

A gated recurrent unit (GRU) cell that performs the same operation as nn.GRUCell but is fully coded in Python.

GRU(input_size, hidden_size[, num_layers, ...])

A PyTorch module for executing multiple steps of a multi-layer GRU.

GRUModule(*args, **kwargs)

An embedder for a GRU module.

LSTMCell(input_size, hidden_size[, bias, ...])

A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.

LSTM(input_size, hidden_size[, num_layers, ...])

A PyTorch module for executing multiple steps of a multi-layer LSTM.

LSTMModule(*args, **kwargs)

An embedder for an LSTM module.

ObsDecoder([channels, num_layers, ...])

Observation decoder network.

ObsEncoder([channels, num_layers, depth])

Observation encoder network.

OnlineDTActor(state_dim, action_dim[, ...])

Online Decision Transformer Actor class.

RSSMPosterior([hidden_dim, state_dim, scale_lb])

The posterior network of the RSSM.

RSSMPrior(action_spec[, hidden_dim, ...])

The prior network of the RSSM.

set_recurrent_mode([mode])

Context manager for setting RNNs recurrent mode.

recurrent_mode()

Returns the current recurrent mode.

Multi-agent-specific modules

These networks implement models that can be used in multi-agent contexts. They use vmap() to execute multiple networks all at once on the network inputs. Because the parameters are batched, initialization may differ from what is usually done with other PyTorch modules, see get_stateful_net() for more information.

MultiAgentNetBase(*, n_agents[, ...])

A base class for multi-agent networks.

MultiAgentMLP(n_agent_inputs, ...)

Multi-agent MLP.

MultiAgentConvNet(n_agents, centralized, ...)

Multi-agent CNN.

QMixer(state_shape, mixing_embed_dim, ...)

QMix mixer.

VDNMixer(n_agents, device)

Value-Decomposition Network mixer.

Exploration

Noisy linear layers are a popular way of exploring the environment without altering the actions directly: the stochasticity is instead built into the layer weights.

NoisyLinear(in_features, out_features[, ...])

Noisy Linear Layer.

NoisyLazyLinear(out_features[, bias, ...])

Noisy Lazy Linear Layer.

reset_noise(layer)

Resets the noise of noisy layers.

Planners

CEMPlanner(*args, **kwargs)

CEMPlanner Module.

MPCPlannerBase(*args, **kwargs)

MPCPlannerBase abstract Module.

MPPIPlanner(*args, **kwargs)

MPPI Planner Module.

Distributions

The following distributions are commonly used in RL scripts.

Delta(param[, atol, rtol, batch_shape, ...])

Delta distribution.

IndependentNormal(loc, scale[, upscale, ...])

Implements a Normal distribution with location scaling.

NormalParamWrapper(operator[, ...])

A wrapper for normal distribution parameters.

TanhNormal(loc, scale[, upscale, low, high, ...])

Implements a TanhNormal distribution with location scaling.

TruncatedNormal(loc, scale[, upscale, low, ...])

Implements a Truncated Normal distribution with location scaling.

TanhDelta(param[, low, high, event_dims, ...])

Implements a Tanh-transformed Delta distribution.

OneHotCategorical([logits, probs, grad_method])

One-hot categorical distribution.

MaskedCategorical([logits, probs, mask, ...])

MaskedCategorical distribution.

MaskedOneHotCategorical([logits, probs, ...])

One-hot version of the MaskedCategorical distribution.

Ordinal(scores)

A discrete distribution for learning to sample from finite ordered sets.

OneHotOrdinal(scores)

The one-hot version of the Ordinal distribution.

Utils

The utils include functions for custom mappings, as well as a tool to build TensorDictPrimer instances from a given module.

mappings(key)

Given an input string, returns a surjective function f(x): R -> R^+.

inv_softplus(bias)

Inverse softplus function.

biased_softplus(bias[, min_val])

A biased softplus module.

get_primers_from_module(module)

Get all tensordict primers from all submodules of a module.

VmapModule(*args, **kwargs)

A TensorDictModule wrapper to vmap over the input.
