# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey, unravel_key
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_ERROR,
_reduce,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
class DDPGLoss(LossModule):
"""The DDPG Loss class.
Args:
actor_network (TensorDictModule): a policy operator.
value_network (TensorDictModule): a Q value operator.
loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
data collection. Default is ``False``.
delay_value (bool, optional): whether to separate the target value networks from the value networks used
during training. Default is ``True``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.ddpg import DDPGLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> value = ValueOperator(
... module=module,
... in_keys=["observation", "action"])
>>> loss = DDPGLoss(actor, value)
>>> batch = [2, ]
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": spec.rand(batch),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
fields={
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
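The constructor options above can be customized; for instance, a purely illustrative call
using a smooth L1 loss, a delayed actor and no reduction:

>>> loss = DDPGLoss(
...     actor,
...     value,
...     loss_function="smooth_l1",
...     delay_actor=True,
...     reduction="none",
... )
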
This class is compatible with non-tensordict based modules too and can be
used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network.
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]``
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.ddpg import DDPGLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> value = ValueOperator(
... module=module,
... in_keys=["observation", "action"])
>>> loss = DDPGLoss(actor, value)
>>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss(
... observation=torch.randn(n_obs),
... action=spec.rand(),
... next_done=torch.zeros(1, dtype=torch.bool),
... next_terminated=torch.zeros(1, dtype=torch.bool),
... next_observation=torch.randn(n_obs),
... next_reward=torch.randn(1))
>>> loss_actor.backward()

The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
method.

Examples:
>>> loss.select_out_keys('loss_actor', 'loss_value')
>>> loss_actor, loss_value = loss(
... observation=torch.randn(n_obs),
... action=spec.rand(),
... next_done=torch.zeros(1, dtype=torch.bool),
... next_terminated=torch.zeros(1, dtype=torch.bool),
... next_observation=torch.randn(n_obs),
... next_reward=torch.randn(1))
>>> loss_actor.backward()
"""
@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
Attributes:
state_action_value (NestedKey): The input tensordict key where the
state action value is expected. Will be used for the underlying
value estimator as value key. Defaults to ``"state_action_value"``.
priority (NestedKey): The input tensordict key where the target
priority is written to. Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
terminated (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is terminated. Will be used for the underlying value estimator.
Defaults to ``"terminated"``.
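Examples:
A minimal, illustrative sketch of overriding the defaults (``actor`` and ``value``
as in the class-level examples):
>>> loss = DDPGLoss(actor, value)
>>> loss.set_keys(state_action_value="q_value", priority="priority")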
"""
state_action_value: NestedKey = "state_action_value"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator: ValueEstimators = ValueEstimators.TD0
out_keys = [
"loss_actor",
"loss_value",
"pred_value",
"target_value",
"pred_value_max",
"target_value_max",
]
actor_network: TensorDictModule
value_network: TensorDictModule
actor_network_params: TensorDictParams
value_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_value_network_params: TensorDictParams
def __init__(
self,
actor_network: TensorDictModule,
value_network: TensorDictModule,
*,
loss_function: str = "l2",
delay_actor: bool = False,
delay_value: bool = True,
gamma: float | None = None,
separate_losses: bool = False,
reduction: str | None = None,
) -> None:
self._in_keys = None
if reduction is None:
reduction = "mean"
super().__init__()
self.delay_actor = delay_actor
self.delay_value = delay_value
actor_critic = ActorCriticWrapper(actor_network, value_network)
params = TensorDict.from_module(actor_critic)
params_meta = params.apply(
self._make_meta_params, device=torch.device("meta"), filter_empty=False
)
# keep a stateless copy of the actor-critic wrapper (parameters on the "meta"
# device); the actual parameters are swapped in functionally at call time
with params_meta.to_module(actor_critic):
self.__dict__["actor_critic"] = deepcopy(actor_critic)
self.convert_to_functional(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.convert_to_functional(
value_network,
"value_network",
create_target_params=self.delay_value,
compare_against=policy_params,
)
self.actor_critic.module[0] = self.actor_network
self.actor_critic.module[1] = self.value_network
self.actor_in_keys = actor_network.in_keys
self.value_exclusive_keys = set(self.value_network.in_keys) - (
set(self.actor_in_keys) | set(self.actor_network.out_keys)
)
self.loss_function = loss_function
self.reduction = reduction
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.state_action_value,
reward=self._tensor_keys.reward,
done=self._tensor_keys.done,
terminated=self._tensor_keys.terminated,
)
self._set_in_keys()
def _set_in_keys(self):
in_keys = {
unravel_key(("next", self.tensor_keys.reward)),
unravel_key(("next", self.tensor_keys.done)),
unravel_key(("next", self.tensor_keys.terminated)),
*self.actor_in_keys,
*[unravel_key(("next", key)) for key in self.actor_in_keys],
*self.value_network.in_keys,
*[unravel_key(("next", key)) for key in self.value_network.in_keys],
}
self._in_keys = sorted(in_keys, key=str)
@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys
@in_keys.setter
def in_keys(self, values):
self._in_keys = values
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
"""Computes the DDPG losses given a tensordict sampled from the replay buffer.
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
a priority to items in the tensordict.
Args:
tensordict (TensorDictBase): a tensordict with the ``("next", "reward")``, ``("next", "done")`` and
``("next", "terminated")`` keys, along with the in_keys of the actor and value networks.
Returns:
a tensordict containing the DDPG losses (``"loss_actor"`` and ``"loss_value"``) along with the value metadata.
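Examples:
A minimal sketch, reusing ``loss`` and ``data`` from the class-level example:
>>> loss_vals = loss(data)
>>> priority = data.get("td_error")  # priority written in-place for prioritized replay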
"""
loss_value, metadata = self.loss_value(tensordict)
loss_actor, metadata_actor = self.loss_actor(tensordict)
metadata.update(metadata_actor)
td_out = TensorDict(
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
batch_size=[],
)
self._clear_weakrefs(
tensordict,
td_out,
"value_network_params",
"target_value_network_params",
"target_actor_network_params",
"actor_network_params",
)
return td_out
def loss_actor(
self,
tensordict: TensorDictBase,
) -> tuple[torch.Tensor, dict]:
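"""Computes the DDPG actor loss.
The actor is run on the (detached) input observations and its action is scored
by the value network using detached parameters; the loss is the negative
state-action value. Returns the loss tensor and an (empty) metadata dictionary.
"""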
td_copy = tensordict.select(
*self.actor_in_keys, *self.value_exclusive_keys, strict=False
).detach()
with self.actor_network_params.to_module(self.actor_network):
td_copy = self.actor_network(td_copy)
with self._cached_detached_value_params.to_module(self.value_network):
td_copy = self.value_network(td_copy)
loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
metadata = {}
loss_actor = _reduce(loss_actor, self.reduction)
self._clear_weakrefs(
tensordict,
loss_actor,
"value_network_params",
"target_value_network_params",
"target_actor_network_params",
"actor_network_params",
)
return loss_actor, metadata
def loss_value(
self,
tensordict: TensorDictBase,
) -> tuple[torch.Tensor, dict]:
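"""Computes the DDPG value loss.
The value network scores the observed state-action pairs, the value estimator
computes the TD target from the target parameters, and the loss is the distance
between the two. Returns the loss tensor and a metadata dictionary containing
the TD error and value statistics.
"""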
# value loss
td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
with self.value_network_params.to_module(self.value_network):
self.value_network(td_copy)
pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
target_value = self.value_estimator.value_estimate(
tensordict, target_params=self._cached_target_params
).squeeze(-1)
loss_value = distance_loss(
pred_val, target_value, loss_function=self.loss_function
)
# squared TD error, detached and written back to the input tensordict as the
# priority signal for prioritized replay buffers
td_error = (pred_val - target_value).pow(2)
td_error = td_error.detach()
if tensordict.device is not None:
td_error = td_error.to(tensordict.device)
tensordict.set(
self.tensor_keys.priority,
td_error,
inplace=True,
)
with torch.no_grad():
metadata = {
"td_error": td_error,
"pred_value": pred_val,
"target_value": target_value,
"target_value_max": target_value.max(),
"pred_value_max": pred_val.max(),
}
loss_value = _reduce(loss_value, self.reduction)
self._clear_weakrefs(
tensordict,
"value_network_params",
"target_value_network_params",
"target_actor_network_params",
"actor_network_params",
)
return loss_value, metadata
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
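"""Builds the value estimator used to compute the TD target.
A minimal, illustrative call selecting a TD(lambda) estimator:
>>> from torchrl.objectives.utils import ValueEstimators
>>> loss.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)
"""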
if value_type is None:
value_type = self.default_value_estimator
self.value_type = value_type
hp = dict(default_value_kwargs(value_type))
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
hp.update(hyperparams)
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
elif value_type == ValueEstimators.GAE:
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(
value_network=self.actor_critic, **hp
)
else:
raise NotImplementedError(f"Unknown value type {value_type}")
tensor_keys = {
"value": self.tensor_keys.state_action_value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,
}
self._value_estimator.set_keys(**tensor_keys)
@property
@_cache_values
def _cached_target_params(self):
target_params = TensorDict(
{
"module": {
"0": self.target_actor_network_params,
"1": self.target_value_network_params,
}
},
batch_size=self.target_actor_network_params.batch_size,
device=self.target_actor_network_params.device,
)
return target_params
@property
@_cache_values
def _cached_detached_value_params(self):
return self.value_network_params.detach()
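

# Illustrative usage sketch (not part of the module API): assuming a replay
# buffer ``rb``, an optimizer built over ``loss_module.parameters()`` and a
# ``torchrl.objectives.SoftUpdate`` target updater, a typical update step is:
#
#     data = rb.sample()
#     loss_vals = loss_module(data)
#     loss = loss_vals["loss_actor"] + loss_vals["loss_value"]
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()
#     target_net_updater.step()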