Source code for torchrl.objectives.td3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_ERROR,
_reduce,
_vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
class TD3Loss(LossModule):
"""TD3 Loss module.
Args:
actor_network (TensorDictModule): the actor to be trained
        qvalue_network (TensorDictModule): a single Q-value network that will
            be replicated as many times as needed.
Keyword Args:
bounds (tuple of float, optional): the bounds of the action space.
Exclusive with action_spec. Either this or ``action_spec`` must
be provided.
action_spec (TensorSpec, optional): the action spec.
Exclusive with bounds. Either this or ``bounds`` must be provided.
num_qvalue_nets (int, optional): Number of Q-value networks to be
            trained. Default is ``2``.
policy_noise (float, optional): Standard deviation for the target
policy action noise. Default is ``0.2``.
noise_clip (float, optional): Clipping range value for the sampled
target policy action noise. Default is ``0.5``.
        priority_key (str, optional): [Deprecated, use ``.set_keys(priority=priority_key)``
            instead] Key where to write the priority value for prioritized
            replay buffers. Default is ``"td_error"``.
        loss_function (str, optional): loss function to be used for the Q-value.
            Can be one of ``"smooth_l1"``, ``"l2"`` or ``"l1"``.
            Default is ``"smooth_l1"``.
delay_actor (bool, optional): whether to separate the target actor
networks from the actor networks used for
data collection. Default is ``True``.
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used
for data collection. Default is ``True``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.td3 import TD3Loss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
... spec=spec)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
fields={
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict based modules too and can be
    used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3 import TD3Loss
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
... spec=spec)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward=torch.randn(*batch, 1),
... next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()
"""
@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
Attributes:
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
state_action_value (NestedKey): The input tensordict key where the state action value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
            priority (NestedKey): The input tensordict key where the target priority is written.
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
terminated (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is terminated. Will be used for the underlying value estimator.
Defaults to ``"terminated"``.
"""
action: NestedKey = "action"
state_action_value: NestedKey = "state_action_value"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
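    # Usage sketch (hypothetical key names): these defaults can be remapped on an
    # instantiated loss with set_keys, e.g.
    #   loss.set_keys(action="action_cont", reward="my_reward")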
default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
"loss_qvalue",
"pred_value",
"state_action_value_actor",
"next_state_value",
"target_value",
]
def __init__(
self,
actor_network: TensorDictModule,
qvalue_network: TensorDictModule,
*,
        action_spec: Optional[TensorSpec] = None,
        bounds: Optional[Tuple[float, float]] = None,
        num_qvalue_nets: int = 2,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        loss_function: str = "smooth_l1",
        delay_actor: bool = True,
        delay_qvalue: bool = True,
        gamma: Optional[float] = None,
        priority_key: Optional[str] = None,
        separate_losses: bool = False,
        reduction: Optional[str] = None,
) -> None:
if reduction is None:
reduction = "mean"
super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)
self.delay_actor = delay_actor
self.delay_qvalue = delay_qvalue
self.convert_to_functional(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.convert_to_functional(
qvalue_network,
"qvalue_network",
num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=policy_params,
)
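        # convert_to_functional registers the (expanded) critic parameters as
        # self.qvalue_network_params and, since create_target_params is set from
        # delay_qvalue, a detached copy as self.target_qvalue_network_params;
        # the vmapped calls below operate on these stacked parameter tensordicts.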
for p in self.parameters():
device = p.device
break
else:
device = None
self.num_qvalue_nets = num_qvalue_nets
self.loss_function = loss_function
self.policy_noise = policy_noise
self.noise_clip = noise_clip
if not ((action_spec is not None) ^ (bounds is not None)):
raise ValueError(
"One of 'bounds' and 'action_spec' must be provided, "
f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
)
elif action_spec is not None:
if isinstance(action_spec, CompositeSpec):
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = action_spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = action_spec.shape
action_spec = action_spec[self.tensor_keys.action][
(0,) * len(action_container_shape)
]
if not isinstance(action_spec, BoundedTensorSpec):
raise ValueError(
f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}."
)
low = action_spec.space.low
high = action_spec.space.high
else:
low, high = bounds
if not isinstance(low, torch.Tensor):
low = torch.tensor(low)
if not isinstance(high, torch.Tensor):
high = torch.tensor(high, device=low.device, dtype=low.dtype)
if (low > high).any():
raise ValueError("Got a low bound higher than a high bound.")
if device is not None:
low = low.to(device)
high = high.to(device)
self.register_buffer("max_action", high)
self.register_buffer("min_action", low)
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self._vmap_actor_network00 = _vmap_func(
self.actor_network, randomness=self.vmap_randomness
)
self.reduction = reduction
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.state_action_value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
terminated=self.tensor_keys.terminated,
)
self._set_in_keys()
def _set_in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.qvalue_network.in_keys,
]
self._in_keys = list(set(keys))
@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys
@in_keys.setter
def in_keys(self, values):
self._in_keys = values
@property
@_cache_values
def _cached_detach_qvalue_network_params(self):
return self.qvalue_network_params.detach()
@property
@_cache_values
def _cached_stack_actor_params(self):
return torch.stack(
[self.actor_network_params, self.target_actor_network_params], 0
)
def actor_loss(self, tensordict):
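        # TD3 actor objective: run the trainable actor on the current inputs,
        # score the produced action with the first Q-network using detached
        # critic parameters, and minimize -Q1(s, pi(s)).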
tensordict_actor_grad = tensordict.select(
*self.actor_network.in_keys, strict=False
)
with self.actor_network_params.to_module(self.actor_network):
tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
actor_loss_td = tensordict_actor_grad.select(
*self.qvalue_network.in_keys, strict=False
).expand(
self.num_qvalue_nets, *tensordict_actor_grad.batch_size
) # for actor loss
state_action_value_actor = (
self._vmap_qvalue_network00(
actor_loss_td,
self._cached_detach_qvalue_network_params,
)
.get(self.tensor_keys.state_action_value)
.squeeze(-1)
)
loss_actor = -(state_action_value_actor[0])
metadata = {
"state_action_value_actor": state_action_value_actor.detach(),
}
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, metadata
def value_loss(self, tensordict):
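        # TD3 critic objective (clipped double-Q learning): the target is
        #   y = r + gamma * (1 - done) * min_i Q_target_i(s', clamp(pi_target(s') + noise))
        # computed through the configured value estimator (TD0 by default); every
        # Q-network in the ensemble regresses onto the same target y.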
tensordict = tensordict.clone(False)
act = tensordict.get(self.tensor_keys.action)
        # compute the target-policy smoothing noise early, for reproducibility
noise = (torch.randn_like(act) * self.policy_noise).clamp(
-self.noise_clip, self.noise_clip
)
with torch.no_grad():
next_td_actor = step_mdp(tensordict).select(
*self.actor_network.in_keys, strict=False
            )  # keep only the actor's input keys for the next step (e.g. the next observation)
with self.target_actor_network_params.to_module(self.actor_network):
next_td_actor = self.actor_network(next_td_actor)
next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
self.min_action, self.max_action
)
next_td_actor.set(
self.tensor_keys.action,
next_action,
)
next_val_td = next_td_actor.select(
*self.qvalue_network.in_keys, strict=False
).expand(
self.num_qvalue_nets, *next_td_actor.batch_size
) # for next value estimation
next_target_q1q2 = (
self._vmap_qvalue_network00(
next_val_td,
self.target_qvalue_network_params,
)
.get(self.tensor_keys.state_action_value)
.squeeze(-1)
)
# min over the next target qvalues
next_target_qvalue = next_target_q1q2.min(0)[0]
# set next target qvalues
tensordict.set(
("next", self.tensor_keys.state_action_value),
next_target_qvalue.unsqueeze(-1),
)
qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
self.num_qvalue_nets,
*tensordict.batch_size,
)
        # predicted current Q-values
current_qvalue = (
self._vmap_qvalue_network00(
qval_td,
self.qvalue_network_params,
)
.get(self.tensor_keys.state_action_value)
.squeeze(-1)
)
# compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done))
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
td_error = (current_qvalue - target_value).pow(2)
loss_qval = distance_loss(
current_qvalue,
target_value.expand_as(current_qvalue),
loss_function=self.loss_function,
).sum(0)
metadata = {
"td_error": td_error,
"next_state_value": next_target_qvalue.detach(),
"pred_value": current_qvalue.detach(),
"target_value": target_value.detach(),
}
loss_qval = _reduce(loss_qval, reduction=self.reduction)
return loss_qval, metadata
    @dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict_save = tensordict
loss_actor, metadata_actor = self.actor_loss(tensordict)
loss_qval, metadata_value = self.value_loss(tensordict_save)
tensordict_save.set(
self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
)
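        # The per-sample TD error was just written back into the input tensordict
        # under tensor_keys.priority ("td_error" by default). Sketch, assuming a
        # TensorDictPrioritizedReplayBuffer ``rb`` that produced ``tensordict``:
        #   rb.update_tensordict_priority(tensordict)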
if not loss_qval.shape == loss_actor.shape:
raise RuntimeError(
f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
)
td_out = TensorDict(
source={
"loss_actor": loss_actor,
"loss_qvalue": loss_qval,
**metadata_actor,
**metadata_value,
},
batch_size=[],
)
return td_out
    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
if value_type is None:
value_type = self.default_value_estimator
self.value_type = value_type
hp = dict(default_value_kwargs(value_type))
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
hp.update(hyperparams)
        # we do not need a value network because the next state value is already passed
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(value_network=None, **hp)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(value_network=None, **hp)
elif value_type == ValueEstimators.GAE:
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
else:
raise NotImplementedError(f"Unknown value type {value_type}")
tensor_keys = {
"value": self.tensor_keys.state_action_value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,
}
self._value_estimator.set_keys(**tensor_keys)
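    # Usage sketch: switch from the default TD0 target to a TD(lambda) target
    # (gamma/lmbda values below are illustrative):
    #   from torchrl.objectives.utils import ValueEstimators
    #   loss.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)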