Shortcuts

Source code for torchrl.objectives.ddpg

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey, unravel_key

from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _cache_values,
    _GAMMA_LMBDA_DEPREC_ERROR,
    _reduce,
    default_value_kwargs,
    distance_loss,
    ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


class DDPGLoss(LossModule):
    """The DDPG Loss class.

    Args:
        actor_network (TensorDictModule): a policy operator.
        value_network (TensorDictModule): a Q value operator.

    Keyword Args:
        loss_function (str): loss function for the value discrepancy. Can be
            one of "l1", "l2" or "smooth_l1". Defaults to ``"l2"``.
        delay_actor (bool, optional): whether to separate the target actor
            networks from the actor networks used for data collection.
            Default is ``False``.
        delay_value (bool, optional): whether to separate the target value
            networks from the value networks used for data collection.
            Default is ``True``.
        gamma (float, optional): deprecated. Passing any value other than
            ``None`` raises a :class:`TypeError`.
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
            parameters for both policy and critic losses.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.ddpg import DDPGLoss
        >>> from tensordict import TensorDict
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> value = ValueOperator(
        ...     module=module,
        ...     in_keys=["observation", "action"])
        >>> loss = DDPGLoss(actor, value)
        >>> batch = [2, ]
        >>> data = TensorDict({
        ...     "observation": torch.randn(*batch, n_obs),
        ...     "action": spec.rand(batch),
        ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...     ("next", "reward"): torch.randn(*batch, 1),
        ...     ("next", "observation"): torch.randn(*batch, n_obs),
        ... }, batch)
        >>> loss(data)
        TensorDict(
            fields={
                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)

    This class is compatible with non-tensordict based modules too and can be
    used without resorting to any tensordict-related primitive. In this case,
    the expected keyword arguments are:
    ``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network.
    The return value is a tuple of tensors in the following order:
    ``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]``

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.ddpg import DDPGLoss
        >>> _ = torch.manual_seed(42)
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> value = ValueOperator(
        ...     module=module,
        ...     in_keys=["observation", "action"])
        >>> loss = DDPGLoss(actor, value)
        >>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss(
        ...     observation=torch.randn(n_obs),
        ...     action=spec.rand(),
        ...     next_done=torch.zeros(1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(1, dtype=torch.bool),
        ...     next_observation=torch.randn(n_obs),
        ...     next_reward=torch.randn(1))
        >>> loss_actor.backward()

    The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
    method.

    Examples:
        >>> loss.select_out_keys('loss_actor', 'loss_value')
        >>> loss_actor, loss_value = loss(
        ...     observation=torch.randn(n_obs),
        ...     action=spec.rand(),
        ...     next_done=torch.zeros(1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(1, dtype=torch.bool),
        ...     next_observation=torch.randn(n_obs),
        ...     next_reward=torch.randn(1))
        >>> loss_actor.backward()

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            state_action_value (NestedKey): The input tensordict key where the
                state action value is expected. Will be used for the underlying
                value estimator as value key. Defaults to ``"state_action_value"``.
            priority (NestedKey): The input tensordict key where the target
                priority is written to. Defaults to ``"td_error"``.
            reward (NestedKey): The input tensordict key where the reward is expected.
                Will be used for the underlying value estimator. Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Will be used for the underlying value estimator.
                Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Will be used for the underlying value estimator.
                Defaults to ``"terminated"``.

        """

        state_action_value: NestedKey = "state_action_value"
        priority: NestedKey = "td_error"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator: ValueEstimators = ValueEstimators.TD0
    out_keys = [
        "loss_actor",
        "loss_value",
        "pred_value",
        "target_value",
        "pred_value_max",
        "target_value_max",
    ]

    actor_network: TensorDictModule
    value_network: TensorDictModule
    actor_network_params: TensorDictParams
    value_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams
    target_value_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: TensorDictModule,
        value_network: TensorDictModule,
        *,
        loss_function: str = "l2",
        delay_actor: bool = False,
        delay_value: bool = True,
        gamma: float | None = None,
        separate_losses: bool = False,
        reduction: str | None = None,
    ) -> None:
        # Fail fast on the deprecated argument, before any of the networks
        # passed in are wrapped or converted.
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

        self._in_keys = None
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self.delay_actor = delay_actor
        self.delay_value = delay_value

        actor_critic = ActorCriticWrapper(actor_network, value_network)
        params = TensorDict.from_module(actor_critic)
        params_meta = params.apply(
            self._make_meta_params, device=torch.device("meta"), filter_empty=False
        )
        # Deep-copy the wrapper while its parameters are temporarily swapped for
        # their "meta"-device counterparts. Writing through ``__dict__``
        # bypasses ``__setattr__`` for this private copy.
        with params_meta.to_module(actor_critic):
            self.__dict__["actor_critic"] = deepcopy(actor_critic)

        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=self.delay_actor,
        )
        if separate_losses:
            # we want to make sure there are no duplicates in the params: the
            # params of critic must be refs to actor if they're shared
            policy_params = list(actor_network.parameters())
        else:
            policy_params = None
        self.convert_to_functional(
            value_network,
            "value_network",
            create_target_params=self.delay_value,
            compare_against=policy_params,
        )
        # Slots 0 and 1 of the wrapper hold the actor and the critic; the same
        # layout is used by ``_cached_target_params``.
        self.actor_critic.module[0] = self.actor_network
        self.actor_critic.module[1] = self.value_network

        self.actor_in_keys = actor_network.in_keys
        # Inputs the critic needs that the actor neither reads nor produces.
        self.value_exclusive_keys = set(self.value_network.in_keys) - (
            set(self.actor_in_keys) | set(self.actor_network.out_keys)
        )

        self.loss_function = loss_function
        self.reduction = reduction

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        """Pushes the configured tensordict keys to the value estimator, if any, and refreshes ``in_keys``."""
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                value=self._tensor_keys.state_action_value,
                reward=self._tensor_keys.reward,
                done=self._tensor_keys.done,
                terminated=self._tensor_keys.terminated,
            )
        self._set_in_keys()

    def _set_in_keys(self):
        """Rebuilds the list of input keys.

        The result gathers the ``"next"`` reward/done/terminated keys and the
        in_keys of both networks, at the root and under ``"next"``. Duplicates
        are removed and the keys are sorted by their string representation.
        """
        in_keys = {
            unravel_key(("next", self.tensor_keys.reward)),
            unravel_key(("next", self.tensor_keys.done)),
            unravel_key(("next", self.tensor_keys.terminated)),
            *self.actor_in_keys,
            *[unravel_key(("next", key)) for key in self.actor_in_keys],
            *self.value_network.in_keys,
            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
        }
        self._in_keys = sorted(in_keys, key=str)

    @property
    def in_keys(self):
        """The input keys of the loss, computed on first access."""
        if self._in_keys is None:
            self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        """Computes the DDPG losses given a tensordict sampled from the replay buffer.

        This function will also write the priority key (``"td_error"`` by
        default) in the input tensordict. It can be used by prioritized replay
        buffers to assign a priority to items in the tensordict.

        Args:
            tensordict (TensorDictBase): a tensordict with the ``"reward"``,
                ``"done"`` and ``"terminated"`` entries under ``"next"`` (key
                names configurable through ``set_keys``) and the in_keys of
                the actor and value networks.

        Returns:
            a :class:`~tensordict.TensorDict` with an empty batch size holding
            ``"loss_actor"``, ``"loss_value"`` and the metadata entries
            returned by :meth:`loss_value` and :meth:`loss_actor`.

        """
        loss_value, metadata = self.loss_value(tensordict)
        loss_actor, metadata_actor = self.loss_actor(tensordict)
        metadata.update(metadata_actor)
        td_out = TensorDict(
            source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
            batch_size=[],
        )
        self._clear_weakrefs(
            tensordict,
            td_out,
            "value_network_params",
            "target_value_network_params",
            "target_actor_network_params",
            "actor_network_params",
        )
        return td_out

    def loss_actor(
        self,
        tensordict: TensorDictBase,
    ) -> tuple[torch.Tensor, dict]:
        """Computes the actor loss: the negated state-action value of the actor's own action.

        The inputs are detached and the critic runs with detached parameters.

        Args:
            tensordict (TensorDictBase): the input data. Only the actor in_keys
                and the critic-exclusive keys are read.

        Returns:
            a tuple ``(loss_actor, metadata)`` where ``loss_actor`` is reduced
            according to ``self.reduction`` and ``metadata`` is an empty dict.

        """
        td_copy = tensordict.select(
            *self.actor_in_keys, *self.value_exclusive_keys, strict=False
        ).detach()
        with self.actor_network_params.to_module(self.actor_network):
            td_copy = self.actor_network(td_copy)
        with self._cached_detached_value_params.to_module(self.value_network):
            td_copy = self.value_network(td_copy)
        loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
        metadata = {}
        loss_actor = _reduce(loss_actor, self.reduction)
        self._clear_weakrefs(
            tensordict,
            loss_actor,
            "value_network_params",
            "target_value_network_params",
            "target_actor_network_params",
            "actor_network_params",
        )
        return loss_actor, metadata

    def loss_value(
        self,
        tensordict: TensorDictBase,
    ) -> tuple[torch.Tensor, dict]:
        """Computes the critic loss between the predicted and the target state-action value.

        As a side effect, the detached squared difference between prediction
        and target is written in ``tensordict`` under the priority key
        (``"td_error"`` by default).

        Args:
            tensordict (TensorDictBase): the input data. It is passed as is to
                the value estimator to compute the target value.

        Returns:
            a tuple ``(loss_value, metadata)`` where ``loss_value`` is reduced
            according to ``self.reduction`` and ``metadata`` holds the
            ``"td_error"``, ``"pred_value"``, ``"target_value"``,
            ``"target_value_max"`` and ``"pred_value_max"`` entries.

        """
        # value loss
        td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
        with self.value_network_params.to_module(self.value_network):
            self.value_network(td_copy)
        pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)

        target_value = self.value_estimator.value_estimate(
            tensordict, target_params=self._cached_target_params
        ).squeeze(-1)

        loss_value = distance_loss(
            pred_val, target_value, loss_function=self.loss_function
        )
        # The priority is always the squared error, whatever ``loss_function`` is.
        td_error = (pred_val - target_value).pow(2)
        td_error = td_error.detach()
        if tensordict.device is not None:
            td_error = td_error.to(tensordict.device)
        tensordict.set(
            self.tensor_keys.priority,
            td_error,
            inplace=True,
        )
        with torch.no_grad():
            # The metadata key is the literal "td_error", independently of the
            # configured priority key.
            metadata = {
                "td_error": td_error,
                "pred_value": pred_val,
                "target_value": target_value,
                "target_value_max": target_value.max(),
                "pred_value_max": pred_val.max(),
            }
        loss_value = _reduce(loss_value, self.reduction)
        self._clear_weakrefs(
            tensordict,
            "value_network_params",
            "target_value_network_params",
            "target_actor_network_params",
            "actor_network_params",
        )
        return loss_value, metadata

    def make_value_estimator(
        self, value_type: ValueEstimators | None = None, **hyperparams
    ):
        """Builds the value estimator used to compute the target value.

        Args:
            value_type (ValueEstimators, optional): the estimator to build.
                Defaults to ``self.default_value_estimator``.
            **hyperparams: keyword arguments for the estimator. They override
                the defaults given by ``default_value_kwargs(value_type)`` and
                ``self.gamma`` (when that attribute exists).

        Raises:
            NotImplementedError: if ``value_type`` is ``ValueEstimators.GAE``
                or is not one of the supported estimators.

        """
        if value_type is None:
            value_type = self.default_value_estimator
        self.value_type = value_type
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        if value_type == ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} it not implemented for loss {type(self)}."
            )
        elif value_type == ValueEstimators.TDLambda:
            self._value_estimator = TDLambdaEstimator(
                value_network=self.actor_critic, **hp
            )
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

        tensor_keys = {
            "value": self.tensor_keys.state_action_value,
            "reward": self.tensor_keys.reward,
            "done": self.tensor_keys.done,
            "terminated": self.tensor_keys.terminated,
        }
        self._value_estimator.set_keys(**tensor_keys)

    @property
    @_cache_values
    def _cached_target_params(self):
        """Target actor and value parameters, nested under ``"module"`` as entries ``"0"`` and ``"1"``."""
        target_params = TensorDict(
            {
                "module": {
                    "0": self.target_actor_network_params,
                    "1": self.target_value_network_params,
                }
            },
            batch_size=self.target_actor_network_params.batch_size,
            device=self.target_actor_network_params.device,
        )
        return target_params

    @property
    @_cache_values
    def _cached_detached_value_params(self):
        """A detached version of the value network parameters, used by the actor loss."""
        return self.value_network_params.detach()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources