# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from dataclasses import dataclass
import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
[docs]class TD3BCLoss(LossModule):
r"""TD3+BC Loss Module.
Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to
Offline Reinforcement Learning" <https://arxiv.org/pdf/2106.06860>`_.
This class incorporates two loss functions, executed sequentially within the `forward` method:
1. :meth:`~.qvalue_loss`
2. :meth:`~.actor_loss`
Users also have the option to call these functions directly in the same order if preferred.
actor_network (TensorDictModule): the actor to be trained
qvalue_network (TensorDictModule): a single Q-value network or a list of
Q-value networks.
If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
times. If a list of modules is passed, their
parameters will be stacked unless they share the same identity (in which case
the original parameter will be expanded).
.. warning:: When a list of parameters is passed, it will __not__ be compared against the policy parameters
and all the parameters will be considered as untied.
Keyword Args:
bounds (tuple of float, optional): the bounds of the action space.
Exclusive with ``action_spec``. Either this or ``action_spec`` must
be provided.
action_spec (TensorSpec, optional): the action spec.
Exclusive with ``bounds``. Either this or ``bounds`` must be provided.
num_qvalue_nets (int, optional): Number of Q-value networks to be
trained. Default is ``2``.
policy_noise (:obj:`float`, optional): Standard deviation for the target
policy action noise. Default is ``0.2``.
noise_clip (:obj:`float`, optional): Clipping range value for the sampled
target policy action noise. Default is ``0.5``.
alpha (:obj:`float`, optional): Weight for the behavioral cloning loss.
Defaults to ``2.5``.
priority_key (str, optional): Key where to write the priority value
for prioritized replay buffers. Default is ``"td_error"``.
loss_function (str, optional): loss function to be used for the Q-value.
Can be one of ``"smooth_l1"``, ``"l2"``,
``"l1"``. Default is ``"smooth_l1"``.
delay_actor (bool, optional): whether to separate the target actor
networks from the actor networks used for
data collection. Default is ``True``.
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used
for data collection. Default is ``True``.
spec (TensorSpec, optional): the action tensor spec. If not provided
and the target entropy is ``"auto"``, it will be retrieved from
the actor.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
... spec=spec)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
This class is compatible with non-tensordict based modules too and can be
used without recurring to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value"]``.
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
... spec=spec)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward=torch.randn(*batch, 1),
... next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()
@dataclass
class _AcceptedKeys:
    """Maintains default values for all configurable tensordict keys.

    This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
    default values.

    Attributes:
        action (NestedKey): The input tensordict key where the action is expected.
            Defaults to ``"action"``.
        state_action_value (NestedKey): The input tensordict key where the state action value is expected.
            Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
        priority (NestedKey): The input tensordict key where the target priority is written to.
            Defaults to ``"td_error"``.
        reward (NestedKey): The input tensordict key where the reward is expected.
            Will be used for the underlying value estimator. Defaults to ``"reward"``.
        done (NestedKey): The key in the input TensorDict that indicates
            whether a trajectory is done. Will be used for the underlying value estimator.
            Defaults to ``"done"``.
        terminated (NestedKey): The key in the input TensorDict that indicates
            whether a trajectory is terminated. Will be used for the underlying value estimator.
            Defaults to ``"terminated"``.
    """

    action: NestedKey = "action"
    state_action_value: NestedKey = "state_action_value"
    priority: NestedKey = "td_error"
    reward: NestedKey = "reward"
    done: NestedKey = "done"
    terminated: NestedKey = "terminated"
# Key configuration: `tensor_keys` is annotated with the key container type and
# `default_keys` points at the class that supplies the default key names.
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
# Estimator that `make_value_estimator` falls back to when called without a `value_type`.
default_value_estimator = ValueEstimators.TD0
# NOTE(review): the entries and the closing bracket of this list are not present
# in this listing — confirm against the upstream file.
out_keys = [
# Type annotations for the networks and their (target) parameter containers.
actor_network: TensorDictModule
qvalue_network: TensorDictModule
actor_network_params: TensorDictParams
qvalue_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_qvalue_network_params: TensorDictParams
def __init__(
actor_network: TensorDictModule,
qvalue_network: TensorDictModule | list[TensorDictModule],
action_spec: TensorSpec = None,
bounds: tuple[float] | None = None,
num_qvalue_nets: int = 2,
policy_noise: float = 0.2,
noise_clip: float = 0.5,
alpha: float = 2.5,
loss_function: str = "smooth_l1",
delay_actor: bool = True,
delay_qvalue: bool = True,
priority_key: str = None,
separate_losses: bool = False,
reduction: str = None,
) -> None:
# Stores the TD3+BC hyperparameters on the instance, validates the action
# bounds and registers them as the `min_action` / `max_action` buffers.
# NOTE(review): this listing has lost its indentation and several lines
# (e.g. `self`, `else:` branches, closing brackets) — verify every branch
# against the upstream file before relying on it.
# An unset reduction falls back to "mean".
if reduction is None:
reduction = "mean"
# Cleared here; the `in_keys` accessor tests this attribute against None.
self._in_keys = None
self.delay_actor = delay_actor
self.delay_qvalue = delay_qvalue
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
# NOTE(review): as listed, this assignment always overrides the one above;
# an `else:` appears to be missing — confirm.
policy_params = None
# Pick a device from the module's own parameters.
# NOTE(review): as listed, `device = None` always runs after the loop;
# a `break` / `else:` pair appears to be missing — confirm.
for p in self.parameters():
device = p.device
device = None
self.num_qvalue_nets = num_qvalue_nets
self.loss_function = loss_function
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.alpha = alpha
# XOR: exactly one of `action_spec` and `bounds` must be supplied.
if not ((action_spec is not None) ^ (bounds is not None)):
raise ValueError(
"One of 'bounds' and 'action_spec' must be provided, "
f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
elif action_spec is not None:
# A Composite spec is indexed with the configured action key to reach the
# action entry. NOTE(review): parts of this branch are absent here — confirm.
if isinstance(action_spec, Composite):
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
action_container_shape = action_spec[
action_container_shape = action_spec.shape
action_spec = action_spec[self.tensor_keys.action][
(0,) * len(action_container_shape)
# Only a Bounded spec is accepted, since its low/high limits are read below.
if not isinstance(action_spec, Bounded):
raise ValueError(
f"action_spec is not of type Bounded but {type(action_spec)}."
low = action_spec.space.low
high = action_spec.space.high
# NOTE(review): presumably the `bounds` branch (an `else:` is not shown) — confirm.
low, high = bounds
# Promote non-tensor bounds to tensors; `high` follows `low`'s device and dtype.
if not isinstance(low, torch.Tensor):
low = torch.tensor(low)
if not isinstance(high, torch.Tensor):
high = torch.tensor(high, device=low.device, dtype=low.dtype)
# Reject bounds where any lower limit exceeds its upper limit.
if (low > high).any():
raise ValueError("Got a low bound higher than a high bound.")
if device is not None:
low = low.to(device)
high = high.to(device)
# Registered as buffers; `qvalue_loss` clamps the noisy target action to
# [min_action, max_action].
self.register_buffer("max_action", high)
self.register_buffer("min_action", low)
self.reduction = reduction
def _make_vmap(self):
    """Wrap the Q-value and actor networks with ``_vmap_func``.

    Both wrappers are built with ``randomness=self.vmap_randomness`` and stored
    on ``self._vmap_qvalue_network00`` and ``self._vmap_actor_network00``.
    """
    self._vmap_qvalue_network00 = _vmap_func(
        self.qvalue_network, randomness=self.vmap_randomness
    )
    self._vmap_actor_network00 = _vmap_func(
        self.actor_network, randomness=self.vmap_randomness
    )
def _forward_value_estimator_keys(self, **kwargs) -> None:
# Acts only once a value estimator has been created on the instance.
# NOTE(review): the statements guarded by this check are not present in this
# listing — confirm against the upstream file.
if self._value_estimator is not None:
def _set_in_keys(self):
# Builds the list of keys this loss reads and caches it on `self._in_keys`.
# Reward, done and terminated are addressed under the "next" entry, as are
# the actor network's input keys.
keys = [
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*[("next", key) for key in self.actor_network.in_keys],
# NOTE(review): the list literal is not closed in this listing and may have
# lost entries — confirm against the upstream file.
# Passing through `set` drops duplicates; the resulting order is not guaranteed.
self._in_keys = list(set(keys))
@property
def in_keys(self):
    """Keys read from the input tensordict.

    The list is built on first access by ``_set_in_keys`` and cached on
    ``self._in_keys``; later accesses return the cached value.
    """
    if self._in_keys is None:
        self._set_in_keys()
    return self._in_keys


@in_keys.setter
def in_keys(self, values):
    """Replace the cached list of input keys with ``values``."""
    self._in_keys = values
def _cached_detach_qvalue_network_params(self):
    """Return ``self.qvalue_network_params`` detached from the autograd graph."""
    return self.qvalue_network_params.detach()
def _cached_stack_actor_params(self):
    """Stack the actor parameters and the target actor parameters along a new dim 0.

    Index 0 of the result holds ``self.actor_network_params`` and index 1 holds
    ``self.target_actor_network_params``.
    """
    return torch.stack(
        [self.actor_network_params, self.target_actor_network_params], 0
    )
[docs] def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
"""Compute the actor loss.

The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates.

Args:
    tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
        are required for this to be computed.

Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"`
    used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda
    value, and the lambda value `"lmbd"` itself.
"""
# NOTE(review): this listing has lost its indentation and several lines
# (closing brackets, call arguments) — verify against the upstream file.
# Keep only the entries the actor consumes; keys that are absent are tolerated.
tensordict_actor_grad = tensordict.select(
*self.actor_network.in_keys, strict=False
# Run the actor with its own (non-target) parameters.
with self.actor_network_params.to_module(self.actor_network):
tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
actor_loss_td = tensordict_actor_grad.select(
*self.qvalue_network.in_keys, strict=False
self.num_qvalue_nets, *tensordict_actor_grad.batch_size
) # for actor loss
# NOTE(review): the expressions that produce the Q-values and the arguments
# of the MSE call are not present here — confirm.
state_action_value_actor = (
bc_loss = torch.nn.functional.mse_loss(
# Weighting factor: alpha over the mean absolute Q-value of index 0, detached
# so that no gradient flows through the factor itself.
lmbd = self.alpha / state_action_value_actor[0].abs().mean().detach()
# Negated, scaled Q-value plus the behavior-cloning term.
loss_actor = -lmbd * state_action_value_actor[0] + bc_loss
metadata = {
"state_action_value_actor": state_action_value_actor[0].detach(),
"bc_loss": bc_loss.detach(),
"lmbd": lmbd,
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, metadata
[docs] def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
"""Compute the q-value loss.

The q-value loss should be computed before the :meth:`~.actor_loss`.

Args:
    tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
        are required for this to be computed.

Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
    the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`.
"""
# NOTE(review): this listing has lost its indentation and several lines
# (closing brackets, call arguments) — verify against the upstream file.
# Rebind to a clone so later writes do not target the caller's object
# (presumably a non-recursive clone — TODO confirm the meaning of `False`).
tensordict = tensordict.clone(False)
act = tensordict.get(self.tensor_keys.action)
# computing early for reprod
# Gaussian noise scaled by `policy_noise`, clipped to [-noise_clip, noise_clip].
noise = (torch.randn_like(act) * self.policy_noise).clamp(
-self.noise_clip, self.noise_clip
# Gradients are disabled for the target computation (extent of the block is
# not recoverable from this listing — confirm).
with torch.no_grad():
next_td_actor = step_mdp(tensordict).select(
*self.actor_network.in_keys, strict=False
) # next_observation ->
# The next action comes from the actor evaluated with its target parameters.
with self.target_actor_network_params.to_module(self.actor_network):
next_td_actor = self.actor_network(next_td_actor)
# Add the clipped noise, then clamp into the registered action bounds.
next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
self.min_action, self.max_action
next_val_td = next_td_actor.select(
*self.qvalue_network.in_keys, strict=False
self.num_qvalue_nets, *next_td_actor.batch_size
) # for next value estimation
next_target_q1q2 = (
# min over the next target qvalues
next_target_qvalue = next_target_q1q2.min(0)[0]
# set next target qvalues
("next", self.tensor_keys.state_action_value),
qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
# predicted current qvalues
current_qvalue = (
# compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done))
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
# Squared difference between the current and target values.
td_error = (current_qvalue - target_value).pow(2)
loss_qval = distance_loss(
# NOTE(review): `td_error` is stored without an explicit `.detach()` here,
# unlike the other entries; `forward` detaches it after popping it.
metadata = {
"td_error": td_error,
"next_state_value": next_target_qvalue.detach(),
"pred_value": current_qvalue.detach(),
"target_value": target_value.detach(),
loss_qval = _reduce(loss_qval, reduction=self.reduction)
return loss_qval, metadata
[docs] @dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""The forward method.

Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns
a tensordict with these values.
To see what keys are expected in the input tensordict and what keys are expected as output, check the
class's `"in_keys"` and `"out_keys"` attributes.
"""
# NOTE(review): this listing has lost its indentation and several lines
# (the wrappers around the priority write and the output dict) — verify
# against the upstream file.
# Second reference to the same input object (plain assignment, not a copy).
tensordict_save = tensordict
# The actor loss is evaluated first, then the Q-value loss.
loss_actor, metadata_actor = self.actor_loss(tensordict)
loss_qval, metadata_value = self.qvalue_loss(tensordict_save)
# Priority value: the TD error is removed from the metadata, detached and
# reduced with a max over dim 0.
self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
# Both losses must have the same shape after reduction.
if not loss_qval.shape == loss_actor.shape:
raise RuntimeError(
f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
td_out = TensorDict(
"loss_actor": loss_actor,
"loss_qvalue": loss_qval,
return td_out
[docs] def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
if value_type is None:
value_type = self.default_value_estimator
self.value_type = value_type
hp = dict(default_value_kwargs(value_type))
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
# we do not need a value network bc the next state value is already passed
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(value_network=None, **hp)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(value_network=None, **hp)
elif value_type == ValueEstimators.GAE:
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
raise NotImplementedError(f"Unknown value type {value_type}")
tensor_keys = {
"value": self.tensor_keys.state_action_value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,