
TD3BCLoss

class torchrl.objectives.TD3BCLoss(*args, **kwargs)

TD3+BC Loss Module.

Implementation of the TD3+BC loss presented in the paper “A Minimalist Approach to Offline Reinforcement Learning” <https://arxiv.org/pdf/2106.06860>.

This class incorporates two loss functions, executed sequentially within the forward method:

  1. qvalue_loss()

  2. actor_loss()

Users may also call these functions directly, in the same order.
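Calling the two losses directly makes it possible to reproduce the usual TD3 schedule, in which the critic is updated on every step and the actor only every few steps (the original TD3 uses a delay of 2). The framework-free sketch below only records that ordering; `run_schedule` and `policy_delay` are hypothetical names used for illustration and are not part of the TD3BCLoss API.

```python
from typing import List


def run_schedule(num_steps: int, policy_delay: int = 2) -> List[str]:
    """Return the order of (hypothetical) critic/actor updates over `num_steps` steps."""
    log: List[str] = []
    for step in range(1, num_steps + 1):
        # qvalue_loss() is computed and the critic is updated on every step
        log.append(f"{step}:critic")
        # actor_loss() comes after the critic update, and only every `policy_delay` steps
        if step % policy_delay == 0:
            log.append(f"{step}:actor")
    return log


print(run_schedule(4))
# ['1:critic', '2:critic', '2:actor', '3:critic', '4:critic', '4:actor']
```

In a real training loop each "critic" entry would correspond to an optimizer step on the qvalue_loss() output, and each "actor" entry to an optimizer step on the actor_loss() output.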

Parameters:
  • actor_network (TensorDictModule) – the actor to be trained

  • qvalue_network (TensorDictModule) –

    a single Q-value network or a list of Q-value networks. If a single instance of qvalue_network is provided, it will be duplicated num_qvalue_nets times. If a list of modules is passed, their parameters will be stacked unless they share the same identity (in which case the original parameter will be expanded).

    Warning

    When a list of parameters is passed, it will not be compared against the policy parameters and all the parameters will be considered as untied.

Keyword Arguments:
  • bounds (tuple of float, optional) – the bounds of the action space. Exclusive with action_spec. Either this or action_spec must be provided.

  • action_spec (TensorSpec, optional) – the action spec. Exclusive with bounds. Either this or bounds must be provided.

  • num_qvalue_nets (int, optional) – Number of Q-value networks to be trained. Default is 2.

  • policy_noise (float, optional) – Standard deviation for the target policy action noise. Default is 0.2.

  • noise_clip (float, optional) – Clipping range value for the sampled target policy action noise. Default is 0.5.

  • alpha (float, optional) – Weight for the behavioral cloning loss. Defaults to 2.5.

  • priority_key (str, optional) – Key where to write the priority value for prioritized replay buffers. Default is “td_error”.

  • loss_function (str, optional) – loss function to be used for the Q-value. Can be one of "smooth_l1", "l2" or "l1". Default is "smooth_l1".

  • delay_actor (bool, optional) – whether to separate the target actor networks from the actor networks used for data collection. Default is True.

  • delay_qvalue (bool, optional) – Whether to separate the target Q value networks from the Q value networks used for data collection. Default is True.

  • spec (TensorSpec, optional) – the action tensor spec. If not provided and the target entropy is "auto", it will be retrieved from the actor.

  • separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e., gradients are propagated to shared parameters for both policy and critic losses.

  • reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Default: "mean".
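The policy_noise and noise_clip arguments control TD3-style target policy smoothing. The pure-Python sketch below follows the smoothing rule from the TD3 paper for a single action vector, assuming the action bounds (from bounds or action_spec) are the range the noisy target action is clamped to. The function name is hypothetical, and TorchRL's own implementation works on batched tensors and may scale the noise differently.

```python
import random
from typing import List, Optional, Sequence


def smoothed_target_action(
    target_action: Sequence[float],
    low: Sequence[float],
    high: Sequence[float],
    policy_noise: float = 0.2,
    noise_clip: float = 0.5,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """TD3-style target policy smoothing for a single action vector (illustrative sketch)."""
    rng = rng if rng is not None else random.Random()
    smoothed = []
    for a, lo, hi in zip(target_action, low, high):
        # Gaussian noise with standard deviation `policy_noise`, clipped to [-noise_clip, noise_clip]
        eps = max(-noise_clip, min(noise_clip, rng.gauss(0.0, policy_noise)))
        # the perturbed action is clamped back into the action bounds
        smoothed.append(max(lo, min(hi, a + eps)))
    return smoothed


print(smoothed_target_action([0.9, -0.2], low=[-1.0, -1.0], high=[1.0, 1.0], rng=random.Random(0)))
```

Clipping the noise keeps the smoothed target action close to the target policy's output, while the final clamp keeps it a valid action.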

Examples

>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
...     module=module,
...     spec=spec)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
...      "observation": torch.randn(*batch, n_obs),
...      "action": action,
...      ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...      ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...      ("next", "reward"): torch.randn(*batch, 1),
...      ("next", "observation"): torch.randn(*batch, n_obs),
...  }, batch)
>>> loss(data)
TensorDict(
    fields={
        bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This class is also compatible with non-tensordict-based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are: ["action", "next_reward", "next_done", "next_terminated"] + the in_keys of the actor and Q-value networks. The return value is a tuple of tensors in the following order: ["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value"].
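The keyword names above correspond to the tensordict keys, with nested keys such as ("next", "reward") flattened by joining their parts with an underscore (next_reward), as the following example's argument names suggest. The helper below is a hypothetical illustration of that naming rule, not a TorchRL function.

```python
from typing import Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def flat_kwarg_name(key: NestedKey, separator: str = "_") -> str:
    """Hypothetical helper: keyword name for a (possibly nested) tensordict key."""
    if isinstance(key, str):
        return key
    # nested keys such as ("next", "reward") are joined into a single identifier
    return separator.join(key)


keys = ["observation", "action", ("next", "reward"), ("next", "done"),
        ("next", "terminated"), ("next", "observation")]
print([flat_kwarg_name(k) for k in keys])
# ['observation', 'action', 'next_reward', 'next_done', 'next_terminated', 'next_observation']
```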

Examples

>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
...     module=module,
...     spec=spec)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue = loss(
...         observation=torch.randn(*batch, n_obs),
...         action=action,
...         next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...         next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...         next_reward=torch.randn(*batch, 1),
...         next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()

actor_loss(tensordict) → Tuple[Tensor, dict]

Compute the actor loss.

The actor loss should be computed after qvalue_loss(), and actor updates are usually delayed by 1-3 critic updates.

Parameters:

tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.

Returns: a differentiable tensor with the actor loss, along with a metadata dictionary containing the detached “bc_loss” used in the combined actor loss, the detached “state_action_value_actor” used to calculate the lambda value, and the lambda value “lmbd” itself.
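The TD3+BC paper defines the actor objective as maximizing λ·Q(s, π(s)) − (π(s) − a)², with λ = α / mean|Q(s, π(s))| computed over the batch. The framework-free sketch below computes the three quantities named above (loss_actor, bc_loss, lmbd) for one batch. It uses plain Python floats instead of tensors, the function name is hypothetical, and details such as which critic's values enter λ (the paper uses the first one) may differ in TorchRL.

```python
from typing import Sequence, Tuple


def td3bc_actor_loss(
    q_values: Sequence[float],
    policy_actions: Sequence[Sequence[float]],
    dataset_actions: Sequence[Sequence[float]],
    alpha: float = 2.5,
) -> Tuple[float, float, float]:
    """Return (loss_actor, bc_loss, lmbd) for one batch, following the TD3+BC objective."""
    n = len(q_values)
    # lmbd normalises the Q term by the mean absolute Q(s, pi(s)); in PyTorch it is detached
    lmbd = alpha / (sum(abs(q) for q in q_values) / n)
    # behavioural cloning term: squared error averaged over samples and action dimensions
    sq_errors = [
        (p - a) ** 2
        for p_vec, a_vec in zip(policy_actions, dataset_actions)
        for p, a in zip(p_vec, a_vec)
    ]
    bc_loss = sum(sq_errors) / len(sq_errors)
    # maximise lmbd * Q while staying close to the dataset actions
    loss_actor = -lmbd * (sum(q_values) / n) + bc_loss
    return loss_actor, bc_loss, lmbd


print(td3bc_actor_loss([1.0, 3.0], [[0.0, 0.5], [1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]))
# (-2.1875, 0.3125, 1.25)
```

Because λ divides by the mean absolute Q value, rescaling all Q values (for example multiplying them by 10) leaves loss_actor unchanged, which is why the single weight alpha balances the two terms across tasks with different reward scales.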

forward(tensordict: TensorDictBase = None) → TensorDictBase

The forward method.

Successively computes actor_loss() and qvalue_loss(), and returns a tensordict containing these values. To see which keys are expected in the input tensordict and which keys are written to the output, check the class’s “in_keys” and “out_keys” attributes.

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)

Value-function constructor.

If a non-default value function is desired, it must be built with this method.

Parameters:
  • value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.

  • **hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.

Examples

>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # update the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> # explicitly select the TD(1) estimator
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # change the gamma value while keeping the current estimator type
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
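The gamma hyperparameter above is the discount used when building value targets. As a rough illustration of how a one-step TD(0) target differs from a TD(1) target, which accumulates discounted rewards up to the end of the trajectory, here is a framework-free sketch of the textbook definitions. It assumes a single trajectory segment and is not TorchRL's estimator code; the function names are hypothetical.

```python
from typing import List, Sequence


def td0_targets(rewards: Sequence[float], next_values: Sequence[float],
                terminated: Sequence[bool], gamma: float) -> List[float]:
    # one-step target: r_t + gamma * V(s_{t+1}), without bootstrapping after termination
    return [r + gamma * (0.0 if t else v) for r, v, t in zip(rewards, next_values, terminated)]


def td1_targets(rewards: Sequence[float], next_values: Sequence[float],
                terminated: Sequence[bool], gamma: float) -> List[float]:
    # discounted return over the whole (single) trajectory, bootstrapped from the last next value
    targets = [0.0] * len(rewards)
    running = 0.0 if terminated[-1] else next_values[-1]
    for i in reversed(range(len(rewards))):
        if terminated[i]:
            running = 0.0  # nothing to bootstrap from after a terminal step
        running = rewards[i] + gamma * running
        targets[i] = running
    return targets


rewards, next_values, terminated = [1.0, 1.0, 1.0], [10.0, 10.0, 10.0], [False, False, False]
print(td0_targets(rewards, next_values, terminated, gamma=0.5))  # [6.0, 6.0, 6.0]
print(td1_targets(rewards, next_values, terminated, gamma=0.5))  # [3.0, 4.0, 6.0]
```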

qvalue_loss(tensordict) → Tuple[Tensor, dict]

Compute the q-value loss.

The q-value loss should be computed before the actor_loss().

Parameters:

tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.

Returns: a differentiable tensor with the qvalue loss, along with a metadata dictionary containing the detached “td_error” to be used for prioritized sampling, the detached “next_state_value”, the detached “pred_value”, and the detached “target_value”.
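For intuition about next_state_value, target_value and the loss_function option, the sketch below computes a TD3-style clipped double-Q loss on plain Python floats: the next-state value is the minimum over the target critics, the target is a one-step TD target using a discount gamma (a value-estimator hyperparameter set through make_value_estimator()), and each critic's prediction is compared to it with the chosen distance. The function names are hypothetical, and the reduction (mean over the batch, sum over critics) is one common convention that may not match TorchRL exactly.

```python
from typing import List, Sequence, Tuple


def distance(x: float, loss_function: str) -> float:
    """Elementwise distance for the three documented loss_function options."""
    if loss_function == "l1":
        return abs(x)
    if loss_function == "l2":
        return x * x
    if loss_function == "smooth_l1":
        # Huber-style loss with threshold 1.0 (PyTorch's default beta)
        return 0.5 * x * x if abs(x) < 1.0 else abs(x) - 0.5
    raise ValueError(f"unknown loss_function: {loss_function!r}")


def td3_qvalue_loss(
    pred_values: Sequence[Sequence[float]],     # pred_values[j][i] = Q_j(s_i, a_i)
    next_target_qs: Sequence[Sequence[float]],  # target critic j at (s'_i, smoothed a'_i)
    rewards: Sequence[float],
    terminated: Sequence[bool],
    gamma: float = 0.99,
    loss_function: str = "smooth_l1",
) -> Tuple[float, List[float], List[float]]:
    """Return (loss_qvalue, next_state_value, target_value) for one batch."""
    batch = len(rewards)
    # clipped double-Q: take the minimum over the target critics
    next_state_value = [min(q[i] for q in next_target_qs) for i in range(batch)]
    target_value = [
        r + gamma * (0.0 if t else v)
        for r, v, t in zip(rewards, next_state_value, terminated)
    ]
    # average over the batch for each critic, then sum over critics
    loss = sum(
        sum(distance(pred[i] - target_value[i], loss_function) for i in range(batch)) / batch
        for pred in pred_values
    )
    return loss, next_state_value, target_value


print(td3_qvalue_loss([[1.0, 0.0], [2.0, 0.0]], [[4.0, 2.0], [3.0, 6.0]],
                      rewards=[0.0, 1.0], terminated=[False, True], gamma=0.5))
# (0.625, [3.0, 2.0], [1.5, 1.0])
```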
