
torchrl.objectives package

TorchRL provides a series of losses to use in your training scripts. The aim is to have losses that are easily reusable/swappable and that have a simple signature.

The main characteristics of TorchRL losses are:

  • They are stateful objects: they hold a copy of the trainable parameters, so that loss_module.parameters() returns everything needed to train the algorithm.

  • They follow the tensordict convention: their torch.nn.Module.forward() method receives a tensordict containing all the information needed to compute the loss values.

  • They output a tensordict.TensorDict instance in which each loss value is written under a key of the form "loss_<smth>", where smth is a string describing the loss. Additional keys in the tensordict may hold useful metrics to log during training.
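The following is a minimal sketch of a loss module that follows these three conventions, written in plain PyTorch. It is not TorchRL code. A Python dict stands in for a tensordict.TensorDict, and the class name ToyValueLoss and the "observation" and "target" keys are hypothetical.

```python
import torch
from torch import nn


class ToyValueLoss(nn.Module):
    """Hypothetical loss module: owns its network, reads a dict, writes "loss_*" keys."""

    def __init__(self, value_network: nn.Module):
        super().__init__()
        # Stateful: registering the network as a submodule means that
        # loss_module.parameters() yields every trainable parameter.
        self.value_network = value_network

    def forward(self, td: dict) -> dict:
        # Every input is read from the (dict-based) tensordict stand-in.
        value = self.value_network(td["observation"]).squeeze(-1)
        loss_value = (value - td["target"]).pow(2).mean()
        # Loss terms carry the "loss_" prefix; other keys are metrics to log.
        return {"loss_value": loss_value, "value_mean": value.detach().mean()}


loss_module = ToyValueLoss(nn.Linear(4, 1))
batch = {"observation": torch.randn(8, 4), "target": torch.randn(8)}
loss_vals = loss_module(batch)
# Only the "loss_*" entries are summed and backpropagated.
loss_val = sum(v for k, v in loss_vals.items() if k.startswith("loss_"))
loss_val.backward()
```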

Note

The losses are returned independently so that, for instance, users can apply a different optimizer to each set of parameters. Summing the losses is simply done via

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
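To illustrate the separate-optimizer use case, the sketch below assumes hypothetical "loss_actor" and "loss_critic" entries produced by two toy networks. Each loss term drives its own optimizer with its own learning rate. The networks, losses and learning rates are placeholders, not TorchRL defaults.

```python
import torch
from torch import nn

torch.manual_seed(0)
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
# One optimizer (and learning rate) per set of parameters.
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-4)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

obs = torch.randn(8, 4)
loss_vals = {
    "loss_actor": actor(obs).pow(2).mean(),
    "loss_critic": critic(obs).pow(2).mean(),
    "entropy": torch.tensor(0.5),  # a logged metric, not a loss term
}

actor_before = actor.weight.detach().clone()
critic_before = critic.weight.detach().clone()

# Each loss term is backpropagated and stepped with its own optimizer.
actor_optim.zero_grad()
loss_vals["loss_actor"].backward()
actor_optim.step()

critic_optim.zero_grad()
loss_vals["loss_critic"].backward()
critic_optim.step()
```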

Training value functions

TorchRL provides a range of value estimators such as TD(0), TD(1), TD(\(\lambda\)) and GAE. In a nutshell, a value estimator is a function of data (mostly rewards and done states) and of a state value (i.e., the value returned by a function that is fit to estimate state values). It gives a somewhat biased estimate of the discounted return following a state or a state-action pair, based on data and proxy maps. To learn more about value estimators, see Sutton and Barto's introduction to RL, in particular the chapters on value iteration and TD learning. These estimators are used in two contexts:

  • Training the value network. To learn the “true” state-value (or state-action-value) map, the value network needs a target value to be fitted to. The better the estimator (lower bias, lower variance), the better the value network, which in turn can significantly speed up policy training. Typically, the value network loss looks like:

    >>> value = value_network(states)
    >>> target_value = value_estimator(rewards, done, value_network(next_states)).detach()
    >>> value_net_loss = (value - target_value).pow(2).mean()
    
  • Computing an “advantage” signal for policy optimization. The advantage is the difference between the value estimate (from the estimator, i.e., from “real” data) and the output of the value network (i.e., the proxy for this value). A positive advantage signals that the policy performed better than expected, which suggests there is room for improvement if that trajectory is taken as an example. Conversely, a negative advantage signifies that the policy underperformed compared to what was expected.
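The two uses above can be made concrete with a hand-written one-step TD(0) estimator. This is a sketch, not the TorchRL implementation. The target is \(r + \gamma (1 - d) V(s')\), computed without gradient. The value loss regresses \(V(s)\) onto it, and the advantage is the target minus \(V(s)\). The network, the toy batch and \(\gamma = 0.99\) are arbitrary assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
gamma = 0.99
value_network = nn.Linear(3, 1)

# A toy batch of transitions (s, r, done, s').
states = torch.randn(5, 3)
next_states = torch.randn(5, 3)
rewards = torch.randn(5)
done = torch.tensor([False, False, True, False, True])

value = value_network(states).squeeze(-1)
with torch.no_grad():
    # TD(0) target: bootstrap from V(s') unless the episode ended.
    next_value = value_network(next_states).squeeze(-1)
    target_value = rewards + gamma * (~done).float() * next_value

# Use 1: regress the value network onto the fixed target.
value_net_loss = (value - target_value).pow(2).mean()
value_net_loss.backward()

# Use 2: advantage = data-based estimate minus the network's own prediction.
advantage = (target_value - value).detach()
```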

Things are not always as simple as in the example above, and the formula for the value estimator or the advantage may be more intricate. To help users flexibly switch between value estimators, we provide a simple API to change the estimator on the fly. Here is an example with DQN; all loss modules follow a similar structure:

>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> loss_module = DQNLoss(actor)
>>> kwargs = {"gamma": 0.9, "lmbda": 0.9}
>>> loss_module.make_value_estimator(ValueEstimators.TDLambda, **kwargs)

The ValueEstimators class enumerates the value estimators to choose from, which lets users rely on auto-completion to make their choice.
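The following plain-Python sketch shows what the TD(\(\lambda\)) estimator selected above computes. It uses the standard backward recursion \(G_t^\lambda = r_t + \gamma (1 - d_t)[(1 - \lambda) V(s_{t+1}) + \lambda G_{t+1}^\lambda]\), bootstrapped from \(V(s_T)\) after the last step. The helper td_lambda_returns is hypothetical and handles a single trajectory stored in lists. TorchRL's own estimators work on batched tensordicts and may differ in details.

```python
def td_lambda_returns(rewards, next_values, dones, gamma=0.9, lmbda=0.9):
    """Hypothetical helper: TD(lambda) returns for one trajectory.

    next_values[t] is V(s_{t+1}); dones[t] is True if the episode ended at step t.
    """
    returns = [0.0] * len(rewards)
    # After the last step, bootstrap entirely from the value estimate.
    g_next = next_values[-1]
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # Blend the one-step bootstrap with the longer lambda-return.
        mixed = (1 - lmbda) * next_values[t] + lmbda * g_next
        returns[t] = rewards[t] + gamma * not_done * mixed
        g_next = returns[t]
    return returns


rewards = [1.0, 1.0, 1.0]
next_values = [0.5, 0.5, 0.5]
dones = [False, False, False]
td0 = td_lambda_returns(rewards, next_values, dones, lmbda=0.0)
td1 = td_lambda_returns(rewards, next_values, dones, lmbda=1.0)
default = td_lambda_returns(rewards, next_values, dones)  # gamma=0.9, lmbda=0.9
```

Setting lmbda=0 recovers the TD(0) target, while lmbda=1 gives the discounted sum of rewards bootstrapped only at the end of the rollout, i.e. the TD(1) case.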

LossModule(*args, **kwargs)

A parent class for RL losses.

DQN

DQNLoss(*args, **kwargs)

The DQN Loss class.

DistributionalDQNLoss(*args, **kwargs)

A distributional DQN loss class.

DDPG

DDPGLoss(*args, **kwargs)

The DDPG Loss class.

SAC

SACLoss(*args, **kwargs)

TorchRL implementation of the SAC loss.

DiscreteSACLoss(*args, **kwargs)

Discrete SAC Loss module.

REDQ

REDQLoss(*args, **kwargs)

REDQ Loss module.

IQL

IQLLoss(*args, **kwargs)

TorchRL implementation of the IQL loss.

DiscreteIQLLoss(*args, **kwargs)

TorchRL implementation of the discrete IQL loss.

CQL

CQLLoss(*args, **kwargs)

TorchRL implementation of the continuous CQL loss.

DiscreteCQLLoss(*args, **kwargs)

TorchRL implementation of the discrete CQL loss.

DT

DTLoss(*args, **kwargs)

TorchRL implementation of the Decision Transformer loss.

OnlineDTLoss(*args, **kwargs)

TorchRL implementation of the Online Decision Transformer loss.

TD3

TD3Loss(*args, **kwargs)

TD3 Loss module.

PPO

PPOLoss(*args, **kwargs)

A parent PPO loss class.

ClipPPOLoss(*args, **kwargs)

Clipped PPO loss.

KLPENPPOLoss(*args, **kwargs)

KL Penalty PPO loss.

A2C

A2CLoss(*args, **kwargs)

TorchRL implementation of the A2C loss.

Reinforce

ReinforceLoss(*args, **kwargs)

Reinforce loss module.

Dreamer

DreamerActorLoss(*args, **kwargs)

Dreamer Actor Loss.

DreamerModelLoss(*args, **kwargs)

Dreamer Model Loss.

DreamerValueLoss(*args, **kwargs)

Dreamer Value Loss.

Multi-agent objectives

These objectives are specific to multi-agent algorithms.

QMixer

QMixerLoss(*args, **kwargs)

The QMixer loss class.

Returns

ValueEstimatorBase(*args, **kwargs)

An abstract parent class for value function modules.

TD0Estimator(*args, **kwargs)

Temporal Difference (TD(0)) estimate of the advantage function.

TD1Estimator(*args, **kwargs)

\(\infty\)-Temporal Difference (TD(1)) estimate of the advantage function.

TDLambdaEstimator(*args, **kwargs)

TD(\(\lambda\)) estimate of the advantage function.

GAE(*args, **kwargs)

A class wrapper around the generalized advantage estimate functional.
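The following plain-Python sketch computes the generalized advantage estimate for a single trajectory using the usual recursion \(\delta_t = r_t + \gamma (1 - d_t) V(s_{t+1}) - V(s_t)\), \(A_t = \delta_t + \gamma \lambda (1 - d_t) A_{t+1}\). The helper gae and its list-based arguments are hypothetical, whereas the GAE class operates on tensordicts. Value targets are taken as \(A_t + V(s_t)\), a common convention assumed here.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Hypothetical helper: generalized advantage estimates for one trajectory."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error.
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        # Exponentially weighted sum of future TD errors, cut at episode ends.
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Value targets: advantage plus the baseline.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


adv0, _ = gae([1.0, 1.0], [0.5, 0.5], [0.5, 0.5], [False, True], gamma=0.5, lmbda=0.0)
adv1, tgt1 = gae([1.0, 1.0], [0.5, 0.5], [0.5, 0.5], [False, True], gamma=0.5, lmbda=1.0)
```

With lmbda=1 and a terminating episode, the value targets reduce to the discounted return (here 1 + 0.5 * 1 = 1.5 at the first step), which makes a convenient sanity check.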

functional.td0_return_estimate(gamma, ...[, ...])

TD(0) discounted return estimate of a trajectory.

functional.td0_advantage_estimate(gamma, ...)

TD(0) advantage estimate of a trajectory.

functional.td1_return_estimate(gamma, ...[, ...])

TD(1) return estimate.

functional.vec_td1_return_estimate(gamma, ...)

Vectorized TD(1) return estimate.

functional.td1_advantage_estimate(gamma, ...)

TD(1) advantage estimate.

functional.vec_td1_advantage_estimate(gamma, ...)

Vectorized TD(1) advantage estimate.

functional.td_lambda_return_estimate(gamma, ...)

TD(\(\lambda\)) return estimate.

functional.vec_td_lambda_return_estimate(...)

Vectorized TD(\(\lambda\)) return estimate.

functional.td_lambda_advantage_estimate(...)

TD(\(\lambda\)) advantage estimate.

functional.vec_td_lambda_advantage_estimate(...)

Vectorized TD(\(\lambda\)) advantage estimate.

functional.generalized_advantage_estimate(...)

Generalized advantage estimate of a trajectory.

functional.vec_generalized_advantage_estimate(...)

Vectorized generalized advantage estimate of a trajectory.

functional.reward2go(reward, done, gamma, *)

Compute the discounted cumulative sum of rewards given multiple trajectories and the episode ends.
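The following plain-Python sketch applies the same idea to a flat list of steps, where dones[t] marks the last step of an episode. The helper reward_to_go is hypothetical and, unlike functional.reward2go, works on Python lists rather than tensors. A trajectory that is cut off without a done flag is simply treated as ending there, with no bootstrapping.

```python
def reward_to_go(rewards, dones, gamma):
    """Hypothetical helper: discounted reward-to-go, reset at each episode end."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            # Step t closes its episode: nothing is carried over from later steps.
            running = 0.0
        running = rewards[t] + gamma * running
        out[t] = running
    return out


rtg = reward_to_go([1.0, 1.0, 1.0, 1.0], [False, True, False, True], gamma=0.5)
```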

Utils

distance_loss(v1, v2, loss_function[, ...])

Computes a distance loss between two tensors.
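The following sketch shows what such a helper could look like. It assumes the loss is selected by a string among "l1", "l2" and "smooth_l1" and that the element-wise values are returned unreduced. The function distance_loss_sketch is a hypothetical stand-in built on torch.nn.functional, not the TorchRL function itself.

```python
import torch
import torch.nn.functional as F


def distance_loss_sketch(v1, v2, loss_function="l2"):
    """Hypothetical stand-in: element-wise distance between two tensors."""
    if v1.shape != v2.shape:
        raise ValueError(f"shape mismatch: {v1.shape} vs {v2.shape}")
    if loss_function == "l2":
        return F.mse_loss(v1, v2, reduction="none")
    if loss_function == "l1":
        return F.l1_loss(v1, v2, reduction="none")
    if loss_function == "smooth_l1":
        return F.smooth_l1_loss(v1, v2, reduction="none")
    raise NotImplementedError(f"unknown loss_function {loss_function!r}")


a = torch.zeros(3)
b = torch.tensor([0.5, 1.0, 3.0])
l2 = distance_loss_sketch(a, b, "l2")
l1 = distance_loss_sketch(a, b, "l1")
smooth = distance_loss_sketch(a, b, "smooth_l1")
```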

hold_out_net(network)

Context manager to hold a network out of a computational graph.
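The following minimal sketch assumes that "holding out" a network means temporarily turning off requires_grad on its parameters. The context manager hold_out_net_sketch and the toy actor/critic pair are hypothetical. A typical use is a DDPG-style actor update, where the gradient must flow through the critic into the action without updating the critic itself.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def hold_out_net_sketch(network: nn.Module):
    """Hypothetical stand-in: freeze a network's parameters inside the block."""
    params = list(network.parameters())
    previous = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        yield
    finally:
        # Restore the original flags even if the block raised.
        for p, flag in zip(params, previous):
            p.requires_grad_(flag)


actor = nn.Linear(3, 2)
critic = nn.Linear(5, 1)
obs = torch.randn(4, 3)
with hold_out_net_sketch(critic):
    action = actor(obs)
    q = critic(torch.cat([obs, action], dim=-1))
    actor_loss = -q.mean()
actor_loss.backward()
```

Because autograd records requires_grad when the forward pass runs, the critic stays excluded from this graph even though its flags are restored on exit.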

hold_out_params(params)

Context manager to hold a list of parameters out of a computational graph.

next_state_value(tensordict[, operator, ...])

Computes the next state value (without gradient) to compute a target value.
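For a DQN-style target, the computation amounts to \(r + \gamma (1 - d) \max_a Q(s', a)\), evaluated without gradient. The sketch below hard-codes the max over actions, although the operator argument suggests this reduction is configurable in TorchRL. It uses a hypothetical helper dqn_target_sketch with a toy Q network.

```python
import torch
from torch import nn


def dqn_target_sketch(q_target_net, rewards, next_obs, dones, gamma=0.99):
    """Hypothetical stand-in: bootstrapped DQN target, computed without gradient."""
    with torch.no_grad():
        # Greedy next-state value: max over the action dimension.
        next_value = q_target_net(next_obs).max(dim=-1).values
        return rewards + gamma * (~dones).float() * next_value


# Toy Q network whose output is [1, 2, 3] for every observation.
q_net = nn.Linear(2, 3)
with torch.no_grad():
    q_net.weight.zero_()
    q_net.bias.copy_(torch.tensor([1.0, 2.0, 3.0]))

targets = dqn_target_sketch(
    q_net, torch.tensor([1.0, 1.0]), torch.randn(2, 2), torch.tensor([False, True]), gamma=0.5
)
```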

SoftUpdate(loss_module, *[, eps, tau])

A soft-update class for target network update in Double DQN/DDPG.
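The following sketch shows the Polyak update that soft updates perform, written with a hypothetical helper soft_update_sketch: \(\theta_{target} \leftarrow (1 - \tau)\,\theta_{target} + \tau\,\theta_{source}\). We assume that eps and tau are two parametrizations of the same rule, related by eps = 1 - tau. Check the SoftUpdate docstring before relying on that.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def soft_update_sketch(target: nn.Module, source: nn.Module, tau: float = 0.005):
    """Hypothetical stand-in: target <- (1 - tau) * target + tau * source."""
    for p_target, p_source in zip(target.parameters(), source.parameters()):
        p_target.mul_(1.0 - tau).add_(p_source, alpha=tau)


source = nn.Linear(2, 2)
target = copy.deepcopy(source)  # the target network starts as an exact copy
with torch.no_grad():
    # Make the two networks differ so that the update is visible.
    for p in source.parameters():
        p.fill_(1.0)
    for p in target.parameters():
        p.fill_(0.0)
soft_update_sketch(target, source, tau=0.1)
```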

HardUpdate(loss_module, *[, ...])

A hard-update class for target network update in Double DQN/DDPG (by contrast with soft updates).

ValueEstimators(value)

Value function enumerator for custom-built estimators.

default_value_kwargs(value_type)

Default value function keyword argument generator.
