
A2CLoss

class torchrl.objectives.A2CLoss(*args, **kwargs)

TorchRL implementation of the A2C loss.

A2C (Advantage Actor Critic) is a model-free, online RL algorithm that uses parallel rollouts of n steps to update the policy, relying on the REINFORCE estimator to compute the gradient. It also adds an entropy term to the objective function to improve exploration.

For more details regarding A2C, refer to: “Asynchronous Methods for Deep Reinforcement Learning”, https://arxiv.org/abs/1602.01783v2
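The update described above can be illustrated with a framework-free sketch. It assumes a single rollout of n steps plus a bootstrapped value for the state reached after the last step, computes n-step advantages A_t = R_t - V(s_t), and forms the policy-gradient and entropy terms whose names mirror the loss_objective and loss_entropy outputs. The function names are hypothetical and this is not TorchRL's implementation, which obtains advantages from its value estimator and operates on tensors.

```python
from typing import List, Sequence, Tuple


def n_step_advantages(
    rewards: Sequence[float],
    values: Sequence[float],
    bootstrap_value: float,
    dones: Sequence[bool],
    gamma: float = 0.99,
) -> List[float]:
    """Return A_t = R_t - V(s_t), where R_t is the n-step discounted return."""
    advantages = [0.0] * len(rewards)
    ret = bootstrap_value  # V(s_n): value of the state reached after the rollout
    for t in reversed(range(len(rewards))):
        if dones[t]:
            ret = 0.0  # the episode ended at step t: do not bootstrap past it
        ret = rewards[t] + gamma * ret
        advantages[t] = ret - values[t]
    return advantages


def a2c_policy_losses(
    log_probs: Sequence[float],
    advantages: Sequence[float],
    entropies: Sequence[float],
    entropy_coef: float = 0.01,
) -> Tuple[float, float]:
    """Return (loss_objective, loss_entropy), both averaged over the batch."""
    n = len(log_probs)
    # REINFORCE-style term: raise the log-probability of actions with positive advantage
    loss_objective = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / n
    # Minimizing -entropy_coef * H favours more exploratory (higher-entropy) policies
    loss_entropy = -entropy_coef * sum(entropies) / n
    return loss_objective, loss_entropy


advantages = n_step_advantages(
    [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 2.0, [False, True, False], gamma=0.5
)
print(advantages)  # [1.5, 1.0, 2.0]
```

In a tensor implementation the advantages are treated as constants (detached) when differentiating loss_objective, so that this term only produces gradients for the policy.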

Parameters:
  • actor_network (ProbabilisticTensorDictSequential) – policy operator.

  • critic_network (ValueOperator) – value operator.

  • entropy_bonus (bool) – if True, an entropy bonus will be added to the loss to favour exploratory policies.

  • samples_mc_entropy (int) – if the distribution retrieved from the policy operator does not have a closed-form expression for the entropy, a Monte-Carlo estimate is used instead. samples_mc_entropy controls how many samples are drawn to compute this estimate. Defaults to 1.

  • entropy_coef (float) – the weight of the entropy loss.

  • critic_coef (float) – the weight of the critic loss.

  • loss_critic_type (str) – loss function for the value discrepancy. Can be one of “l1”, “l2” or “smooth_l1”. Defaults to "smooth_l1".

  • separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e. gradients are propagated to shared parameters for both policy and critic losses.

  • advantage_key (str) – [Deprecated, use set_keys(advantage_key=advantage_key) instead] the input tensordict key where the advantage is expected to be written. Defaults to "advantage".

  • value_target_key (str) – [Deprecated, use set_keys() instead] the input tensordict key where the target state value is expected to be written. Defaults to "value_target".

  • functional (bool, optional) – whether modules should be functionalized. Functionalizing permits features like meta-RL, but makes it impossible to use distributed models (DDP, FSDP, …) and comes with a small overhead. Defaults to True.

  • reduction (str, optional) – specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Defaults to "mean".

  • clip_value (float, optional) – if provided, it is used to compute a clipped version of the value prediction with respect to the value estimate stored in the input, and the clipped prediction is used to calculate the value loss. Clipping limits the impact of extreme value predictions, which helps stabilize training by preventing large updates. It has no effect if the input value estimate was computed by the current version of the value estimator. Defaults to None.
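To make the interplay of critic_coef, loss_critic_type, reduction and clip_value concrete, here is a plain-Python sketch of a critic loss. It assumes the usual definitions of the three distances (smooth_l1 with a threshold of 1, the default beta of torch.nn.functional.smooth_l1_loss) and a PPO-style clipping rule in which the more pessimistic of the clipped and unclipped losses is kept. Both are assumptions about the details, and critic_loss is a hypothetical helper rather than a TorchRL function.

```python
from typing import List, Optional, Sequence, Union


def distance(pred: float, target: float, kind: str) -> float:
    """Per-sample distance selected by loss_critic_type."""
    diff = pred - target
    if kind == "l1":
        return abs(diff)
    if kind == "l2":
        return diff * diff
    if kind == "smooth_l1":
        # Quadratic near zero, linear beyond |diff| = 1 (threshold assumed to be 1)
        return 0.5 * diff * diff if abs(diff) < 1.0 else abs(diff) - 0.5
    raise ValueError(f"unknown loss_critic_type: {kind!r}")


def apply_reduction(values: List[float], reduction: str) -> Union[float, List[float]]:
    """Mimic the "none" | "mean" | "sum" options of the reduction argument."""
    if reduction == "none":
        return values
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    raise ValueError(f"unknown reduction: {reduction!r}")


def critic_loss(
    values: Sequence[float],
    old_values: Sequence[float],
    targets: Sequence[float],
    critic_coef: float = 1.0,
    loss_critic_type: str = "smooth_l1",
    reduction: str = "mean",
    clip_value: Optional[float] = None,
) -> Union[float, List[float]]:
    per_sample = []
    for v, v_old, target in zip(values, old_values, targets):
        loss = distance(v, target, loss_critic_type)
        if clip_value is not None:
            # Keep the new prediction within clip_value of the estimate stored in the input
            delta = max(-clip_value, min(clip_value, v - v_old))
            clipped_loss = distance(v_old + delta, target, loss_critic_type)
            # Keep the more pessimistic (larger) of the two losses
            loss = max(loss, clipped_loss)
        per_sample.append(critic_coef * loss)
    return apply_reduction(per_sample, reduction)
```

When old_values equal values, delta is zero and the clipped loss equals the unclipped one, which is why clip_value has no effect if the input estimate was produced by the current critic.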

Examples

>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.a2c import A2CLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     in_keys=["loc", "scale"],
...     spec=spec,
...     distribution_class=TanhNormal)
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
...     module=module,
...     in_keys=["observation"])
>>> loss = A2CLoss(actor, value, loss_critic_type="l2")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
...         "observation": torch.randn(*batch, n_obs),
...         "action": action,
...         ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...         ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...         ("next", "reward"): torch.randn(*batch, 1),
...         ("next", "observation"): torch.randn(*batch, n_obs),
...     }, batch)
>>> loss(data)
TensorDict(
    fields={
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This class is also compatible with non-tensordict based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are: ["action", "next_reward", "next_done", "next_terminated"] + the in_keys of the actor and critic, including their "next_"-prefixed versions such as next_observation. The return value is a tuple of tensors in the following order: ["loss_objective"] + ["loss_critic"] if critic_coef is not None + ["entropy", "loss_entropy"] if entropy_bonus is True and critic_coef is not None.

Examples

>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.a2c import A2CLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     in_keys=["loc", "scale"],
...     spec=spec,
...     distribution_class=TanhNormal)
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
...     module=module,
...     in_keys=["observation"])
>>> loss = A2CLoss(actor, value, loss_critic_type="l2")
>>> batch = [2, ]
>>> loss_obj, loss_critic, entropy, loss_entropy = loss(
...     observation = torch.randn(*batch, n_obs),
...     action = spec.rand(batch),
...     next_done = torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated = torch.zeros(*batch, 1, dtype=torch.bool),
...     next_reward = torch.randn(*batch, 1),
...     next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()

The output keys can also be filtered using the A2CLoss.select_out_keys() method.

Examples

>>> loss.select_out_keys('loss_objective', 'loss_critic')
>>> loss_obj, loss_critic = loss(
...     observation = torch.randn(*batch, n_obs),
...     action = spec.rand(batch),
...     next_done = torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated = torch.zeros(*batch, 1, dtype=torch.bool),
...     next_reward = torch.randn(*batch, 1),
...     next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()

forward(tensordict: TensorDictBase) → TensorDictBase

It is designed to read an input TensorDict and return another tensordict with loss keys named “loss*”.

Splitting the loss into its components lets the trainer log the individual loss values throughout training. Other scalars present in the output tensordict are logged too.

Parameters:
  • tensordict – an input tensordict with the values required to compute the loss.

Returns:

A new tensordict with no batch dimension containing various loss scalars which will be named “loss*”. It is essential that the losses are returned with this name as they will be read by the trainer before backpropagation.
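A trainer can rely on this naming convention by summing every entry whose key starts with "loss" before backpropagating, while logging all entries. The sketch below uses a plain dict in place of the output TensorDict; split_losses is a hypothetical helper and the input values are arbitrary numbers chosen for illustration.

```python
from typing import Dict, Tuple


def split_losses(output: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    """Sum the "loss*" entries for backpropagation; log every scalar."""
    total = sum(value for key, value in output.items() if key.startswith("loss"))
    to_log = dict(output)  # entries such as "entropy" are logged but not optimized
    return total, to_log


total, to_log = split_losses(
    {"loss_objective": 0.5, "loss_critic": 0.25, "loss_entropy": -0.125, "entropy": 1.25}
)
print(total)  # 0.625
```

With a real A2CLoss output, the same key filter can be applied to the items of the returned tensordict before calling backward() on the total.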

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)

Value-function constructor.

If a non-default value function is wanted, it must be built using this method.

Parameters:
  • value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.

  • **hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.

Examples

>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
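As an illustration of what a value estimator computes, the sketch below implements generalized advantage estimation (GAE) in plain Python. It assumes GAE is the selected estimator (for instance via ValueEstimators.GAE) and borrows the hyperparameter names gamma and lmbda; the gae function itself is hypothetical and only follows the standard recursion A_t = δ_t + γλ(1 - d_t)A_{t+1}, with δ_t = r_t + γ(1 - d_t)V(s_{t+1}) - V(s_t). It treats every done flag as a termination, a simplification of the separate done/terminated handling a library may perform. The value targets for the critic are A_t + V(s_t).

```python
from typing import List, Sequence, Tuple


def gae(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
    lmbda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Return (advantages, value_targets) for one trajectory segment."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; no bootstrapping across an episode boundary
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        # Exponentially weighted sum of future TD errors
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [adv + v for adv, v in zip(advantages, values)]
    return advantages, value_targets
```

Setting lmbda=0 reduces each advantage to the one-step TD error, while lmbda=1 recovers the discounted return (bootstrapped at the end of the segment) minus the value estimate.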
