# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
from dataclasses import dataclass
from numbers import Number

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import Tensor

from torchrl.data.tensor_specs import Composite
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _cache_values,
    _GAMMA_LMBDA_DEPREC_ERROR,
    _reduce,
    _vmap_func,
    default_value_kwargs,
    distance_loss,
    ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

class REDQLoss(LossModule):
    """REDQ Loss module.

    REDQ ("Randomized Ensembled Double Q-Learning: Learning Fast Without a
    Model") generalizes the idea of using an ensemble of Q-value functions to
    train a SAC-like algorithm.

    Args:
        actor_network (TensorDictModule): the actor to be trained
        qvalue_network (TensorDictModule): a single Q-value network or a list
            of Q-value networks.
            If a single instance of `qvalue_network` is provided, it will be
            duplicated ``num_qvalue_nets`` times. If a list of modules is
            passed, their parameters will be stacked unless they share the
            same identity (in which case the original parameter will be
            expanded).

            .. warning:: When a list of parameters is passed, it will __not__
                be compared against the policy parameters and all the
                parameters will be considered as untied.

    Keyword Args:
        num_qvalue_nets (int, optional): Number of Q-value networks to be
            trained. Default is ``10``.
        sub_sample_len (int, optional): number of Q-value networks to be
            subsampled to evaluate the next state value.
            Default is ``2``.
        loss_function (str, optional): loss function to be used for the
            Q-value. Can be one of ``"smooth_l1"``, ``"l2"``, ``"l1"``.
            Default is ``"smooth_l1"``.
        alpha_init (:obj:`float`, optional): initial entropy multiplier.
            Default is ``1.0``.
        min_alpha (:obj:`float`, optional): min value of alpha.
            Default is ``0.1``.
        max_alpha (:obj:`float`, optional): max value of alpha.
            Default is ``10.0``.
        action_spec (TensorSpec, optional): the action tensor spec. If not
            provided and the target entropy is ``"auto"``, it will be
            retrieved from the actor.
        fixed_alpha (bool, optional): whether alpha should be trained to match
            a target entropy. Default is ``False``.
        target_entropy (Union[str, Number], optional): Target entropy for the
            stochastic policy. Default is "auto".
        delay_qvalue (bool, optional): Whether to separate the target Q value
            networks from the Q value networks used for data collection.
            Default is ``True``.
        gSDE (bool, optional): Knowing if gSDE is used is necessary to create
            random noise variables.
            Default is ``False``.
        gamma (:obj:`float`, optional): [Deprecated] must be left to ``None``;
            any other value raises a :class:`TypeError`.
        priority_key (str, optional): [Deprecated, use .set_keys() instead]
            Key where to write the priority value for prioritized replay
            buffers. Default is ``"td_error"``.
        separate_losses (bool, optional): if ``True``, shared parameters
            between policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
            parameters for both policy and critic losses.
        reduction (str, optional): Specifies the reduction to apply to the
            output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no
            reduction will be applied, ``"mean"``: the sum of the output will
            be divided by the number of elements in the output, ``"sum"``: the
            output will be summed. Default: ``"mean"``.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
        >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
        >>> from torchrl.modules.tensordict_module.common import SafeModule
        >>> from torchrl.objectives.redq import REDQLoss
        >>> from tensordict import TensorDict
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
        >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
        >>> actor = ProbabilisticActor(
        ...     module=module,
        ...     in_keys=["loc", "scale"],
        ...     spec=spec,
        ...     distribution_class=TanhNormal)
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> qvalue = ValueOperator(
        ...     module=module,
        ...     in_keys=['observation', 'action'])
        >>> loss = REDQLoss(actor, qvalue)
        >>> batch = [2, ]
        >>> action = spec.rand(batch)
        >>> data = TensorDict({
        ...         "observation": torch.randn(*batch, n_obs),
        ...         "action": action,
        ...         ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...         ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...         ("next", "reward"): torch.randn(*batch, 1),
        ...         ("next", "observation"): torch.randn(*batch, n_obs),
        ...     }, batch)
        >>> loss(data)
        TensorDict(
            fields={
                action_log_prob_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                next.state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)

    This class is compatible with non-tensordict based modules too and can be
    used without recurring to any tensordict-related primitive. In this case,
    the expected keyword arguments are:
    ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of
    the actor and qvalue network.
    The return value is a tuple of tensors in the following order:
    ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy",
    "state_action_value_actor", "action_log_prob_actor", "next.state_value",
    "target_value",]``.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
        >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
        >>> from torchrl.modules.tensordict_module.common import SafeModule
        >>> from torchrl.objectives.redq import REDQLoss
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
        >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
        >>> actor = ProbabilisticActor(
        ...     module=module,
        ...     in_keys=["loc", "scale"],
        ...     spec=spec,
        ...     distribution_class=TanhNormal)
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> qvalue = ValueOperator(
        ...     module=module,
        ...     in_keys=['observation', 'action'])
        >>> loss = REDQLoss(actor, qvalue)
        >>> batch = [2, ]
        >>> action = spec.rand(batch)
        >>> # filter output keys to "loss_actor", and "loss_qvalue"
        >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
        >>> loss_actor, loss_qvalue = loss(
        ...     observation=torch.randn(*batch, n_obs),
        ...     action=action,
        ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
        ...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
        ...     next_reward=torch.randn(*batch, 1),
        ...     next_observation=torch.randn(*batch, n_obs))
        >>> loss_actor.backward()

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using
        '.set_keys(key_name=key_value)' and their default values.

        Attributes:
            value (NestedKey): The input tensordict key where the state value
                is expected. Will be used for the underlying value estimator.
                Defaults to ``"state_value"``.
            action (NestedKey): The input tensordict key where the action is
                expected. Defaults to ``"action"``.
            sample_log_prob (NestedKey): The input tensordict key where the
                sample log probability is expected.
                Defaults to ``"sample_log_prob"`` when
                :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
                `"action_log_prob"` otherwise.
            priority (NestedKey): The input tensordict key where the target
                priority is written to. Defaults to ``"td_error"``.
            state_action_value (NestedKey): The input tensordict key where the
                state action value is expected.
                Defaults to ``"state_action_value"``.
            reward (NestedKey): The input tensordict key where the reward is
                expected. Will be used for the underlying value estimator.
                Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Will be used for the underlying
                value estimator. Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that
                indicates whether a trajectory is terminated. Will be used for
                the underlying value estimator. Defaults to ``"terminated"``.
        """

        action: NestedKey = "action"
        value: NestedKey = "state_value"
        sample_log_prob: NestedKey | None = None
        priority: NestedKey = "td_error"
        state_action_value: NestedKey = "state_action_value"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"

        def __post_init__(self):
            # Resolve the log-prob key lazily so that it follows the
            # composite_lp_aggregate setting active at construction time.
            if self.sample_log_prob is None:
                if composite_lp_aggregate(nowarn=True):
                    self.sample_log_prob = "sample_log_prob"
                else:
                    self.sample_log_prob = "action_log_prob"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    delay_actor: bool = False
    default_value_estimator = ValueEstimators.TD0
    out_keys = [
        "loss_actor",
        "loss_qvalue",
        "loss_alpha",
        "alpha",
        "entropy",
        "state_action_value_actor",
        "action_log_prob_actor",
        "next.state_value",
        "target_value",
    ]

    actor_network: TensorDictModule
    qvalue_network: TensorDictModule
    actor_network_params: TensorDictParams
    qvalue_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams
    target_qvalue_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: TensorDictModule,
        qvalue_network: TensorDictModule | list[TensorDictModule],
        *,
        num_qvalue_nets: int = 10,
        sub_sample_len: int = 2,
        loss_function: str = "smooth_l1",
        alpha_init: float = 1.0,
        min_alpha: float = 0.1,
        max_alpha: float = 10.0,
        action_spec=None,
        fixed_alpha: bool = False,
        target_entropy: str | Number = "auto",
        delay_qvalue: bool = True,
        gSDE: bool = False,
        gamma: float | None = None,
        priority_key: str | None = None,
        separate_losses: bool = False,
        reduction: str | None = None,
    ):
        """Set up the functional networks and the entropy-coefficient state.

        See the class docstring for the meaning of each argument.

        Raises:
            TypeError: if ``gamma`` is not ``None`` (deprecated argument).
        """
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self._in_keys = None
        self._set_deprecated_ctor_keys(priority_key=priority_key)

        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=self.delay_actor,
        )

        # let's make sure that actor_network has `return_log_prob` set to True
        self.actor_network.return_log_prob = True
        if separate_losses:
            # we want to make sure there are no duplicates in the params: the
            # params of critic must be refs to actor if they're shared
            policy_params = list(actor_network.parameters())
        else:
            policy_params = None
        self.delay_qvalue = delay_qvalue
        self.convert_to_functional(
            qvalue_network,
            "qvalue_network",
            num_qvalue_nets,
            create_target_params=self.delay_qvalue,
            compare_against=policy_params,
        )
        self.num_qvalue_nets = num_qvalue_nets
        # Clamp the sub-sample size to [1, num_qvalue_nets - 1]; the outer
        # max() keeps it at 1 when only a single Q-value network is used.
        self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1))
        self.loss_function = loss_function

        try:
            device = next(self.parameters()).device
        except (AttributeError, StopIteration):
            # StopIteration: next() on an empty parameter iterator. Without
            # catching it the default-device fallback below is unreachable
            # for parameter-free modules.
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

        self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
        self.register_buffer(
            "min_log_alpha", torch.tensor(min_alpha, device=device).log()
        )
        self.register_buffer(
            "max_log_alpha", torch.tensor(max_alpha, device=device).log()
        )
        self.fixed_alpha = fixed_alpha
        if fixed_alpha:
            # A buffer is not returned by parameters(), so log_alpha stays
            # constant under optimisation.
            self.register_buffer(
                "log_alpha", torch.tensor(math.log(alpha_init), device=device)
            )
        else:
            self.register_parameter(
                "log_alpha",
                torch.nn.Parameter(
                    torch.tensor(
                        math.log(alpha_init), device=device, requires_grad=True
                    )
                ),
            )

        # The target entropy is resolved lazily by the `target_entropy`
        # property; only its ingredients are stored here.
        self._target_entropy = target_entropy
        self._action_spec = action_spec
        self.target_entropy_buffer = None
        self.reduction = reduction

        self.gSDE = gSDE
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

        self._make_vmap()

    def _make_vmap(self):
        """Create the vmapped Q-value network and actor ``get_dist_params`` callables."""
        self._vmap_qvalue_network00 = _vmap_func(
            self.qvalue_network, randomness=self.vmap_randomness
        )
        self._vmap_getdist = _vmap_func(
            self.actor_network, func="get_dist_params", randomness=self.vmap_randomness
        )

    @property
    def target_entropy(self):
        """Target entropy, computed on first access and cached in a buffer.

        When the configured target entropy is ``"auto"``, the value is minus
        the number of elements of the action spec's shape beyond the shape of
        its container spec. The result is stored in the
        ``target_entropy_buffer`` buffer and returned from there afterwards.

        Raises:
            RuntimeError: if the target entropy is ``"auto"`` and no action
                spec was given nor found on the actor (``actor_network.spec``).
        """
        target_entropy = self.target_entropy_buffer
        if target_entropy is None:
            # Drop the plain ``None`` attribute so the name can be re-used by
            # register_buffer below.
            delattr(self, "target_entropy_buffer")
            target_entropy = self._target_entropy
            action_spec = self._action_spec
            actor_network = self.actor_network
            device = next(self.parameters()).device
            if target_entropy == "auto":
                action_spec = (
                    action_spec
                    if action_spec is not None
                    else getattr(actor_network, "spec", None)
                )
                if action_spec is None:
                    raise RuntimeError(
                        "Cannot infer the dimensionality of the action. Consider providing "
                        "the target entropy explicitly or provide the spec of the "
                        "action tensor in the actor network."
                    )
                if not isinstance(action_spec, Composite):
                    action_spec = Composite({self.tensor_keys.action: action_spec})
                if (
                    isinstance(self.tensor_keys.action, tuple)
                    and len(self.tensor_keys.action) > 1
                ):
                    # Nested action key: the container is the parent entry.
                    action_container_shape = action_spec[
                        self.tensor_keys.action[:-1]
                    ].shape
                else:
                    action_container_shape = action_spec.shape
                target_entropy = -float(
                    action_spec[self.tensor_keys.action]
                    .shape[len(action_container_shape) :]
                    .numel()
                )
            self.register_buffer(
                "target_entropy_buffer", torch.tensor(target_entropy, device=device)
            )
            return self.target_entropy_buffer
        return target_entropy

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        """Propagate the configured keys to the value estimator and refresh ``in_keys``."""
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                value=self._tensor_keys.value,
                reward=self.tensor_keys.reward,
                done=self.tensor_keys.done,
                terminated=self.tensor_keys.terminated,
            )
        self._set_in_keys()

    @property
    def alpha(self):
        """Entropy coefficient: ``exp`` of the clamped ``log_alpha``, computed without grad."""
        with torch.no_grad():
            return self.log_alpha.clamp(self.min_log_alpha, self.max_log_alpha).exp()

    def _set_in_keys(self):
        """Rebuild ``_in_keys`` from the tensor keys and the networks' in_keys."""
        keys = [
            self.tensor_keys.action,
            self.tensor_keys.sample_log_prob,
            ("next", self.tensor_keys.reward),
            ("next", self.tensor_keys.done),
            ("next", self.tensor_keys.terminated),
            *self.actor_network.in_keys,
            *[("next", key) for key in self.actor_network.in_keys],
            *self.qvalue_network.in_keys,
        ]
        # De-duplicate: actor and Q-value networks may share input keys.
        self._in_keys = list(set(keys))

    @property
    def in_keys(self):
        """Input keys of the loss, built on first access."""
        if self._in_keys is None:
            self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values

    @property
    @_cache_values
    def _cached_detach_qvalue_network_params(self):
        """Detached view of the Q-value network parameters."""
        return self.qvalue_network_params.detach()

    def _qvalue_params_cat(self, selected_q_params):
        """Concatenate the three Q-value parameter sets along dim 0.

        The order (detached params, sub-sampled params, trainable params)
        matches the order in which ``forward`` concatenates the actor-loss,
        next-value and Q-value-loss inputs.
        """
        qvalue_params = torch.cat(
            [
                self._cached_detach_qvalue_network_params,
                selected_q_params,
                self.qvalue_network_params,
            ],
            0,
        )
        return qvalue_params

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Compute the REDQ actor, Q-value and alpha losses.

        Args:
            tensordict (TensorDictBase): input data; the entries listed in
                ``in_keys`` are read from it.

        Returns:
            A TensorDict with empty batch size holding the entries listed in
            ``out_keys``. Entries whose name starts with ``"loss_"`` are
            reduced according to ``self.reduction``.

        Raises:
            RuntimeError: if the Q-value and actor losses differ in shape.

        .. note:: The priority (max over the Q-value ensemble of the squared
            TD error) is written in-place to the input ``tensordict`` under
            the ``priority`` key.
        """
        obs_keys = self.actor_network.in_keys
        tensordict_select = tensordict.select(
            "next", *obs_keys, self.tensor_keys.action, strict=False
        )
        # We need to copy bc select does not copy sub-tds
        tensordict_select = tensordict_select.copy()

        # Random subset of target Q-value nets used for the next-state value.
        selected_models_idx = torch.randperm(self.num_qvalue_nets)[
            : self.sub_sample_len
        ].sort()[0]
        selected_q_params = self.target_qvalue_network_params[selected_models_idx]

        actor_params = torch.stack(
            [self.actor_network_params, self.target_actor_network_params], 0
        )

        tensordict_actor_grad = tensordict_select.select(
            *obs_keys, strict=False
        )  # to avoid overwriting keys
        next_td_actor = step_mdp(tensordict_select).select(
            *self.actor_network.in_keys, strict=False
        )  # next_observation ->
        # Index 0: current observations, index 1: next observations.
        tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)

        with set_exploration_type(ExplorationType.RANDOM):
            if self.gSDE:
                tensordict_actor.set(
                    "_eps_gSDE",
                    torch.zeros(tensordict_actor.shape, device=tensordict_actor.device),
                )
            # vmap doesn't support sampling, so we take it out from the vmap
            td_params = self._vmap_getdist(
                tensordict_actor,
                actor_params,
            )
            sample_key = self.tensor_keys.action
            sample_key_lp = self.tensor_keys.sample_log_prob
            tensordict_actor_dist = self.actor_network.build_dist_from_params(td_params)
            tensordict_actor.set(sample_key, tensordict_actor_dist.rsample())
            tensordict_actor.set(
                sample_key_lp,
                tensordict_actor_dist.log_prob(tensordict_actor.get(sample_key)),
            )

        # repeat tensordict_actor to match the qvalue size
        _actor_loss_td = (
            tensordict_actor[0]
            .select(*self.qvalue_network.in_keys)
            .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size)
        )  # for actor loss
        # Select once and reuse for both the expand source and its batch size.
        qval_inputs = tensordict_select.select(*self.qvalue_network.in_keys)
        _qval_td = qval_inputs.expand(
            self.num_qvalue_nets,
            *qval_inputs.batch_size,
        )  # for qvalue loss
        _next_val_td = (
            tensordict_actor[1]
            .select(*self.qvalue_network.in_keys)
            .expand(self.sub_sample_len, *tensordict_actor[1].batch_size)
        )  # for next value estimation
        tensordict_qval = torch.cat(
            [
                _actor_loss_td,
                _next_val_td,
                _qval_td,
            ],
            0,
        )

        # cat params
        tensordict_qval = self._vmap_qvalue_network00(
            tensordict_qval,
            self._qvalue_params_cat(selected_q_params),
        )

        state_action_value = tensordict_qval.get(
            self.tensor_keys.state_action_value
        ).squeeze(-1)
        # Undo the concatenation above, in the same order.
        (
            state_action_value_actor,
            next_state_action_value_qvalue,
            state_action_value_qvalue,
        ) = state_action_value.split(
            [self.num_qvalue_nets, self.sub_sample_len, self.num_qvalue_nets],
            dim=0,
        )
        sample_log_prob = tensordict_actor.get(
            self.tensor_keys.sample_log_prob
        ).squeeze(-1)
        (
            action_log_prob_actor,
            next_action_log_prob_qvalue,
        ) = sample_log_prob.unbind(0)

        loss_actor = -(state_action_value_actor - self.alpha * action_log_prob_actor)

        next_state_value = (
            next_state_action_value_qvalue - self.alpha * next_action_log_prob_qvalue
        )
        # Minimum over the sub-sampled Q-value networks.
        next_state_value = next_state_value.min(0)[0]

        tensordict_select.set(
            ("next", self.tensor_keys.value), next_state_value.unsqueeze(-1)
        )
        target_value = self.value_estimator.value_estimate(tensordict_select).squeeze(
            -1
        )
        pred_val = state_action_value_qvalue
        td_error = (pred_val - target_value).pow(2)
        loss_qval = distance_loss(
            pred_val,
            target_value.expand_as(pred_val),
            loss_function=self.loss_function,
        )

        tensordict.set(self.tensor_keys.priority, td_error.detach().max(0)[0])

        loss_alpha = self._loss_alpha(sample_log_prob)
        if loss_qval.shape != loss_actor.shape:
            raise RuntimeError(
                f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
            )
        td_out = TensorDict(
            {
                "loss_actor": loss_actor,
                "loss_qvalue": loss_qval,
                "loss_alpha": loss_alpha,
                "alpha": self.alpha.detach(),
                "entropy": -sample_log_prob.detach().mean(),
                "state_action_value_actor": state_action_value_actor.detach(),
                "action_log_prob_actor": action_log_prob_actor.detach(),
                "next.state_value": next_state_value.detach(),
                "target_value": target_value.detach(),
            },
            [],
        )
        # Only the loss entries are reduced; other entries are left as is.
        td_out = td_out.named_apply(
            lambda name, value: _reduce(value, reduction=self.reduction)
            if name.startswith("loss_")
            else value,
        )
        self._clear_weakrefs(
            tensordict,
            td_out,
            "actor_network_params",
            "qvalue_network_params",
            "target_actor_network_params",
            "target_qvalue_network_params",
        )
        return td_out

    def _loss_alpha(self, log_pi: Tensor) -> Tensor:
        """Compute the entropy-coefficient loss from the sampled log-probabilities.

        Args:
            log_pi (Tensor): log-probabilities of the sampled actions.

        Returns:
            ``-exp(log_alpha) * (log_pi.detach() + target_entropy)``, or zeros
            shaped like ``log_pi`` if the target entropy is ``None``.

        Raises:
            RuntimeError: if grad mode is enabled and ``log_pi`` does not
                require gradients.
        """
        if torch.is_grad_enabled() and not log_pi.requires_grad:
            raise RuntimeError(
                "expected log_pi to require gradient for the alpha loss)"
            )
        if self.target_entropy is not None:
            # we can compute this loss even if log_alpha is not a parameter
            alpha_loss = -self._safe_log_alpha.exp() * (
                log_pi.detach() + self.target_entropy
            )
        else:
            # placeholder
            alpha_loss = torch.zeros_like(log_pi)
        return alpha_loss

    @property
    def _safe_log_alpha(self):
        """``log_alpha`` clamped in value while keeping an identity gradient.

        The returned value equals the clamped ``log_alpha``; the clamp and the
        detach are computed under ``no_grad`` so that gradients flow only
        through the leading ``log_alpha`` term.
        """
        log_alpha = self.log_alpha
        with torch.no_grad():
            log_alpha_clamp = log_alpha.clamp(self.min_log_alpha, self.max_log_alpha)
            log_alpha_det = log_alpha.detach()
        return log_alpha - log_alpha_det + log_alpha_clamp

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        """Build the value estimator used to compute the Q-value target.

        Args:
            value_type (ValueEstimators, optional): estimator type. Defaults
                to ``self.default_value_estimator`` when ``None``.
            **hyperparams: overrides for the estimator's default keyword
                arguments (as given by ``default_value_kwargs``).

        Raises:
            NotImplementedError: for ``ValueEstimators.GAE`` and for unknown
                value types.
        """
        if value_type is None:
            value_type = self.default_value_estimator
        self.value_type = value_type
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        # Explicit hyperparameters take precedence over the defaults.
        hp.update(hyperparams)
        # we do not need a value network bc the next state value is already passed
        if value_type == ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(value_network=None, **hp)
        elif value_type == ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(value_network=None, **hp)
        elif value_type == ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} it not implemented for loss {type(self)}."
            )
        elif value_type == ValueEstimators.TDLambda:
            self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")
        tensor_keys = {
            "value": self.tensor_keys.value,
            "reward": self.tensor_keys.reward,
            "done": self.tensor_keys.done,
            "terminated": self.tensor_keys.terminated,
        }
        self._value_estimator.set_keys(**tensor_keys)


