Source code for torchrl.objectives.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import re
import warnings
from enum import Enum
from typing import Iterable, List, Optional, Union

import torch
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.nn import TensorDictModule
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import dropout

try:
    from torch import vmap
except ImportError as err:
    try:
        from functorch import vmap
    except ImportError as err_ft:
        raise err_ft from err
from torchrl.envs.utils import step_mdp

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

_GAMMA_LMBDA_DEPREC_ERROR = (
    "Passing gamma / lambda parameters through the loss constructor "
    "is a deprecated feature. To customize your value function, "
    "run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`."
)

RANDOM_MODULE_LIST = (dropout._DropoutNd,)


class ValueEstimators(Enum):
    """Value function enumerator for custom-built estimators.

    Allows for a flexible usage of various value functions when the loss module
    allows it.

    Examples:
        >>> dqn_loss = DQNLoss(actor)
        >>> dqn_loss.make_value_estimator(ValueEstimators.TD0, gamma=0.9)

    """

    TD0 = "Bootstrapped TD (1-step return)"
    TD1 = "TD(1) (infinity-step return)"
    TDLambda = "TD(lambda)"
    GAE = "Generalized advantage estimate"
    VTrace = "V-trace"

def default_value_kwargs(value_type: ValueEstimators):
    """Default value function keyword argument generator.

    Args:
        value_type (Enum.value): the value function type, from the
            :class:`~torchrl.objectives.utils.ValueEstimators` class.

    Examples:
        >>> kwargs = default_value_kwargs(ValueEstimators.TDLambda)
        {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}

    """
    if value_type == ValueEstimators.TD1:
        return {"gamma": 0.99, "differentiable": True}
    elif value_type == ValueEstimators.TD0:
        return {"gamma": 0.99, "differentiable": True}
    elif value_type == ValueEstimators.GAE:
        return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}
    elif value_type == ValueEstimators.TDLambda:
        return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}
    elif value_type == ValueEstimators.VTrace:
        return {"gamma": 0.99, "differentiable": True}
    else:
        raise NotImplementedError(f"Unknown value type {value_type}.")

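# Usage sketch (illustrative, not part of the module): fetch the defaults for an
# estimator and override one entry before building the value function.
# ``loss_module`` stands for any already-constructed loss module.
#
#     >>> from torchrl.objectives.utils import ValueEstimators, default_value_kwargs
#     >>> kwargs = default_value_kwargs(ValueEstimators.GAE)
#     >>> kwargs["gamma"] = 0.98  # override the default discount
#     >>> loss_module.make_value_estimator(ValueEstimators.GAE, **kwargs)
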
class _context_manager:
    def __init__(self, value=True):
        self.value = value
        self.prev = []

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_context

def distance_loss(
    v1: torch.Tensor,
    v2: torch.Tensor,
    loss_function: str,
    strict_shape: bool = True,
) -> torch.Tensor:
    """Computes a distance loss between two tensors.

    Args:
        v1 (Tensor): a tensor with a shape compatible with v2
        v2 (Tensor): a tensor with a shape compatible with v1
        loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss
            function is to be used.
        strict_shape (bool): if False, v1 and v2 are allowed to have a different shape.
            Default is ``True``.

    Returns:
        A tensor of the shape v1.view_as(v2) or v2.view_as(v1) with values equal to
        the distance loss between the two.

    """
    if v1.shape != v2.shape and strict_shape:
        raise RuntimeError(
            f"The input tensors have shapes {v1.shape} and {v2.shape} which are incompatible."
        )

    if loss_function == "l2":
        value_loss = F.mse_loss(
            v1,
            v2,
            reduction="none",
        )

    elif loss_function == "l1":
        value_loss = F.l1_loss(
            v1,
            v2,
            reduction="none",
        )

    elif loss_function == "smooth_l1":
        value_loss = F.smooth_l1_loss(
            v1,
            v2,
            reduction="none",
        )
    else:
        raise NotImplementedError(f"Unknown loss {loss_function}")

    return value_loss

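# Usage sketch (illustrative): the loss is returned element-wise (``reduction="none"``),
# so callers apply their own reduction afterwards, typically with ``_reduce`` below.
#
#     >>> pred = torch.randn(4, 1)
#     >>> target = torch.randn(4, 1)
#     >>> loss = distance_loss(pred, target, loss_function="smooth_l1")
#     >>> loss.shape
#     torch.Size([4, 1])
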
class TargetNetUpdater:
    """An abstract class for target network update in Double DQN/DDPG.

    Args:
        loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.

    """

    def __init__(
        self,
        loss_module: "LossModule",  # noqa: F821
    ):
        from torchrl.objectives.common import LossModule

        if not isinstance(loss_module, LossModule):
            raise ValueError("The loss_module must be a LossModule instance.")

        _has_update_associated = getattr(loss_module, "_has_update_associated", None)
        for k in loss_module._has_update_associated.keys():
            loss_module._has_update_associated[k] = True
        try:
            _target_names = []
            for name, _ in loss_module.named_children():
                # the TensorDictParams is a nn.Module instance
                if name.startswith("target_") and name.endswith("_params"):
                    _target_names.append(name)

            if len(_target_names) == 0:
                raise RuntimeError(
                    "Did not find any target parameters or buffers in the loss module."
                )

            _source_names = ["".join(name.split("target_")) for name in _target_names]

            for _source in _source_names:
                try:
                    getattr(loss_module, _source)
                except AttributeError as err:
                    raise RuntimeError(
                        f"Incongruent target and source parameter lists: "
                        f"{_source} is not an attribute of the loss_module"
                    ) from err

            self._target_names = _target_names
            self._source_names = _source_names
            self.loss_module = loss_module
            self.initialized = False
            self.init_()
            _has_update_associated = True
        finally:
            for k in loss_module._has_update_associated.keys():
                loss_module._has_update_associated[k] = _has_update_associated

    @property
    def _targets(self):
        targets = self.__dict__.get("_targets_val", None)
        if targets is None:
            targets = self.__dict__["_targets_val"] = TensorDict(
                {name: getattr(self.loss_module, name) for name in self._target_names},
                [],
            )
        return targets

    @_targets.setter
    def _targets(self, targets):
        self.__dict__["_targets_val"] = targets

    @property
    def _sources(self):
        sources = self.__dict__.get("_sources_val", None)
        if sources is None:
            sources = self.__dict__["_sources_val"] = TensorDict(
                {name: getattr(self.loss_module, name) for name in self._source_names},
                [],
            )
        return sources

    @_sources.setter
    def _sources(self, sources):
        self.__dict__["_sources_val"] = sources

    def init_(self) -> None:
        if self.initialized:
            warnings.warn("Updater already initialized.")
        found_distinct = False
        self._distinct_and_params = {}
        for key, source in self._sources.items(True, True):
            if not isinstance(key, tuple):
                key = (key,)
            key = ("target_" + key[0], *key[1:])
            target = self._targets[key]
            # for p_source, p_target in zip(source, target):
            if target.requires_grad:
                raise RuntimeError("the target parameter is part of a graph.")
            self._distinct_and_params[key] = (
                target.is_leaf
                and source.requires_grad
                and target.data_ptr() != source.data.data_ptr()
            )
            found_distinct = found_distinct or self._distinct_and_params[key]
            target.data.copy_(source.data)
        if not found_distinct:
            raise RuntimeError(
                f"The target and source data are identical for all params. "
                "Have you created proper target parameters? "
                "If the loss has a ``delay_value`` kwarg, make sure to set it "
                "to True if it is not done by default. "
                f"If no target parameter is needed, do not use a target updater such as {type(self)}."
            )

        # filter the target_ out
        def filter_target(key):
            if isinstance(key, tuple):
                return (filter_target(key[0]), *key[1:])
            return key[7:]

        self._sources = self._sources.select(
            *[
                filter_target(key)
                for (key, val) in self._distinct_and_params.items()
                if val
            ]
        ).lock_()
        self._targets = self._targets.select(
            *(key for (key, val) in self._distinct_and_params.items() if val)
        ).lock_()

        self.initialized = True

    def step(self) -> None:
        if not self.initialized:
            raise Exception(
                f"{self.__class__.__name__} must be "
                f"initialized (`{self.__class__.__name__}.init_()`) before calling step()"
            )
        for key, param in self._sources.items():
            target = self._targets.get("target_{}".format(key))
            if target.requires_grad:
                raise RuntimeError("the target parameter is part of a graph.")
            self._step(param, target)

    def _step(self, p_source: Tensor, p_target: Tensor) -> None:
        raise NotImplementedError

    def __repr__(self) -> str:
        string = (
            f"{self.__class__.__name__}(sources={self._sources}, targets="
            f"{self._targets})"
        )
        return string

class SoftUpdate(TargetNetUpdater):
    r"""A soft-update class for target network update in Double DQN/DDPG.

    This was proposed in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf

    One and only one decay factor (tau or eps) must be specified.

    Args:
        loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.
        eps (scalar): epsilon in the update equation:
            .. math::

                \theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon)

            Exclusive with ``tau``.
        tau (scalar): Polyak tau. It is equal to ``1-eps``, and exclusive with it.
    """

    def __init__(
        self,
        loss_module: Union[
            "DQNLoss",  # noqa: F821
            "DDPGLoss",  # noqa: F821
            "SACLoss",  # noqa: F821
            "REDQLoss",  # noqa: F821
            "TD3Loss",  # noqa: F821
        ],
        *,
        eps: float = None,
        tau: Optional[float] = None,
    ):
        if eps is None and tau is None:
            raise RuntimeError(
                "Neither eps nor tau was provided. This behavior is deprecated.",
            )
            eps = 0.999
        if (eps is None) ^ (tau is None):
            if eps is None:
                eps = 1 - tau
        else:
            raise ValueError("One and only one argument (tau or eps) can be specified.")
        if eps < 0.5:
            warnings.warn(
                "Found an eps value < 0.5, which is unexpected. "
                "You may want to use the `tau` keyword argument instead."
            )
        if not (eps <= 1.0 and eps >= 0.0):
            raise ValueError(
                f"Got eps = {eps} when it was supposed to be between 0 and 1."
            )
        super(SoftUpdate, self).__init__(loss_module)
        self.eps = eps

    def _step(
        self, p_source: Tensor | TensorDictBase, p_target: Tensor | TensorDictBase
    ) -> None:
        p_target.data.lerp_(p_source.data, 1 - self.eps)

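# Usage sketch (illustrative): pair a SoftUpdate with a loss built with target
# parameters (e.g. ``delay_value=True``). ``loss``, ``batch`` and ``optim`` are
# assumed to be defined elsewhere.
#
#     >>> updater = SoftUpdate(loss, eps=0.99)  # equivalently: SoftUpdate(loss, tau=0.01)
#     >>> for _ in range(num_updates):
#     ...     losses = loss(batch)
#     ...     sum(v for k, v in losses.items() if k.startswith("loss_")).backward()
#     ...     optim.step()
#     ...     optim.zero_grad()
#     ...     updater.step()  # theta_target <- eps * theta_target + (1 - eps) * theta
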
class HardUpdate(TargetNetUpdater):
    """A hard-update class for target network update in Double DQN/DDPG (by contrast with soft updates).

    This was proposed in the original Double DQN paper: "Deep Reinforcement Learning with Double Q-learning",
    https://arxiv.org/abs/1509.06461.

    Args:
        loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.

    Keyword Args:
        value_network_update_interval (scalar): how often the target network should be updated.
            default: 1000

    """

    def __init__(
        self,
        loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss", "TD3Loss"],  # noqa: F821
        *,
        value_network_update_interval: float = 1000,
    ):
        super(HardUpdate, self).__init__(loss_module)
        self.value_network_update_interval = value_network_update_interval
        self.counter = 0

    def _step(self, p_source: Tensor, p_target: Tensor) -> None:
        if self.counter == self.value_network_update_interval:
            p_target.data.copy_(p_source.data)

    def step(self) -> None:
        super().step()
        if self.counter == self.value_network_update_interval:
            self.counter = 0
        else:
            self.counter += 1

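# Usage sketch (illustrative): with a hard update, ``step()`` is still called once per
# optimization step; the actual copy only fires every ``value_network_update_interval``
# calls.
#
#     >>> updater = HardUpdate(loss, value_network_update_interval=1000)
#     >>> for _ in range(num_updates):
#     ...     ...  # optimize the loss as usual
#     ...     updater.step()
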
class hold_out_net(_context_manager):
    """Context manager to hold a network out of a computational graph."""

    def __init__(self, network: nn.Module) -> None:
        self.network = network
        for p in network.parameters():
            self.mode = p.requires_grad
            break
        else:
            self.mode = True

    def __enter__(self) -> None:
        if self.mode:
            if is_dynamo_compiling():
                self._params = TensorDict.from_module(self.network)
                self._params.data.to_module(self.network)
            else:
                self.network.requires_grad_(False)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.mode:
            if is_dynamo_compiling():
                self._params.to_module(self.network)
            else:
                self.network.requires_grad_()

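# Usage sketch (illustrative): freeze a critic while backpropagating an actor loss
# through it, so gradients still flow to the actor's output but none accumulate in
# the critic. ``qvalue_net`` (a TensorDictModule) and ``actor_out`` are assumptions.
#
#     >>> with hold_out_net(qvalue_net):
#     ...     actor_loss = -qvalue_net(actor_out).get("state_action_value").mean()
#     ...     actor_loss.backward()  # no gradients are written to qvalue_net parameters
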
class hold_out_params(_context_manager):
    """Context manager to hold a list of parameters out of a computational graph."""

    def __init__(self, params: Iterable[Tensor]) -> None:
        if isinstance(params, TensorDictBase):
            self.params = params.detach()
        else:
            self.params = tuple(p.detach() for p in params)

    def __enter__(self) -> None:
        return self.params

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass

@torch.no_grad()
def next_state_value(
    tensordict: TensorDictBase,
    operator: Optional[TensorDictModule] = None,
    next_val_key: str = "state_action_value",
    gamma: float = 0.99,
    pred_next_val: Optional[Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Computes the next state value (without gradient) to compute a target value.

    The target value is usually used to compute a distance loss (e.g. MSE):
        L = Sum[ (q_value - target_value)^2 ]
    The target value is computed as
        r + gamma ** n_steps_to_next * value_next_state
    If the reward is the immediate reward, n_steps_to_next=1. If N-steps rewards are used,
    n_steps_to_next is gathered from the input tensordict.

    Args:
        tensordict (TensorDictBase): Tensordict containing a reward and done key (and a
            n_steps_to_next key for n-steps rewards).
        operator (ProbabilisticTDModule, optional): the value function operator. Should
            write a 'next_val_key' key-value in the input tensordict when called. It does
            not need to be provided if pred_next_val is given.
        next_val_key (str, optional): key where the next value will be written.
            Default: 'state_action_value'
        gamma (:obj:`float`, optional): return discount rate.
            default: 0.99
        pred_next_val (Tensor, optional): the next state value can be provided if it is
            not computed with the operator.

    Returns:
        a Tensor of the size of the input tensordict containing the predicted value state.

    """
    if "steps_to_next_obs" in tensordict.keys():
        steps_to_next_obs = tensordict.get("steps_to_next_obs").squeeze(-1)
    else:
        steps_to_next_obs = 1

    rewards = tensordict.get(("next", "reward")).squeeze(-1)
    done = tensordict.get(("next", "done")).squeeze(-1)
    if done.all() or gamma == 0:
        return rewards

    if pred_next_val is None:
        next_td = step_mdp(tensordict)  # next_observation -> observation
        next_td = next_td.select(*operator.in_keys)
        operator(next_td, **kwargs)
        pred_next_val_detach = next_td.get(next_val_key).squeeze(-1)
    else:
        pred_next_val_detach = pred_next_val.squeeze(-1)
    done = done.to(torch.float)
    target_value = (1 - done) * pred_next_val_detach
    rewards = rewards.to(torch.float)
    target_value = rewards + (gamma**steps_to_next_obs) * target_value
    return target_value

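# Usage sketch (illustrative): build a TD(0)-style target and fit a Q-value to it.
# ``batch`` (a sampled tensordict) and ``qnet`` (a TensorDictModule writing
# "state_action_value") are assumptions.
#
#     >>> target = next_state_value(batch, operator=qnet, gamma=0.99)
#     >>> q_pred = batch.get("state_action_value").squeeze(-1)
#     >>> td_loss = distance_loss(q_pred, target, loss_function="l2").mean()
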
def _cache_values(func):
    """Caches the tensordict returned by a property."""
    name = func.__name__

    @functools.wraps(func)
    def new_func(self, netname=None):
        if is_dynamo_compiling():
            if netname is not None:
                return func(self, netname)
            else:
                return func(self)
        __dict__ = self.__dict__
        _cache = __dict__.setdefault("_cache", {})
        attr_name = name
        if netname is not None:
            attr_name += "_" + netname
        if attr_name in _cache:
            out = _cache[attr_name]
            return out
        if netname is not None:
            out = func(self, netname)
        else:
            out = func(self)
        # TODO: decide what to do with locked tds in functional calls
        # if is_tensor_collection(out):
        #     out.lock_()
        _cache[attr_name] = out
        return out

    return new_func


def _vmap_func(module, *args, func=None, **kwargs):
    try:

        def decorated_module(*module_args_params):
            params = module_args_params[-1]
            module_args = module_args_params[:-1]
            with params.to_module(module):
                if func is None:
                    return module(*module_args)
                else:
                    return getattr(module, func)(*module_args)

        return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101

    except RuntimeError as err:
        if re.match(
            r"vmap: called random operation while in randomness error mode", str(err)
        ):
            raise RuntimeError(
                "Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
            ) from err


def _reduce(tensor: torch.Tensor, reduction: str) -> Union[float, torch.Tensor]:
    """Reduces a tensor given the reduction method."""
    if reduction == "none":
        result = tensor
    elif reduction == "mean":
        result = tensor.mean()
    elif reduction == "sum":
        result = tensor.sum()
    else:
        raise NotImplementedError(f"Unknown reduction method {reduction}")
    return result


def _clip_value_loss(
    old_state_value: torch.Tensor,
    state_value: torch.Tensor,
    clip_value: torch.Tensor,
    target_return: torch.Tensor,
    loss_value: torch.Tensor,
    loss_critic_type: str,
):
    """Value clipping method for loss computation.

    This method computes a clipped state value from the old state value and the state value,
    and returns the most pessimistic value prediction between clipped and non-clipped options.
    It also computes the clip fraction.
    """
    pre_clipped = state_value - old_state_value
    clipped = pre_clipped.clamp(-clip_value, clip_value)
    with torch.no_grad():
        clip_fraction = (pre_clipped != clipped).to(state_value.dtype).mean()
    state_value_clipped = old_state_value + clipped
    loss_value_clipped = distance_loss(
        target_return,
        state_value_clipped,
        loss_function=loss_critic_type,
    )
    # Choose the most pessimistic value prediction between clipped and non-clipped
    loss_value = torch.max(loss_value, loss_value_clipped)
    return loss_value, clip_fraction


def _get_default_device(net):
    for p in net.parameters():
        return p.device
    else:
        return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

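# Usage sketch (illustrative) for ``_reduce``: loss modules compute element-wise losses
# and apply the user-selected reduction at the very end.
#
#     >>> td_error = torch.randn(256)
#     >>> _reduce(td_error, "mean")  # scalar tensor
#     >>> _reduce(td_error, "none")  # unchanged, shape [256]
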
def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
    """Groups multiple optimizers into a single one.

    All optimizers are expected to have the same type.
    """
    cls = None
    params = []
    for optimizer in optimizers:
        if optimizer is None:
            continue
        if cls is None:
            cls = type(optimizer)
        if cls is not type(optimizer):
            raise ValueError("Cannot group optimizers of different type.")
        params.extend(optimizer.param_groups)
    return cls(params)

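# Usage sketch (illustrative): merge per-network optimizers into a single object so that
# one ``step()``/``zero_grad()`` call covers every parameter group. ``actor`` and
# ``critic`` are assumed nn.Modules.
#
#     >>> actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)
#     >>> critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)
#     >>> optim = group_optimizers(actor_optim, critic_optim)  # one Adam, two param groups
#     >>> optim.step()
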
def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
    # Sum all features and return a tensor
    return data.sum(dim="feature", reduce=True)


def _maybe_get_or_select(td, key_or_keys, target_shape=None):
    if isinstance(key_or_keys, (str, tuple)):
        return td.get(key_or_keys)
    result = td.select(*key_or_keys)
    if target_shape is not None and result.shape != target_shape:
        result.batch_size = target_shape
    return result


def _maybe_add_or_extend_key(
    tensor_keys: List[NestedKey],
    key_or_list_of_keys: NestedKey | List[NestedKey],
    prefix: NestedKey = None,
):
    if prefix is not None:
        if isinstance(key_or_list_of_keys, NestedKey):
            tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
        else:
            tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
        return
    if isinstance(key_or_list_of_keys, NestedKey):
        tensor_keys.append(key_or_list_of_keys)
    else:
        tensor_keys.extend(key_or_list_of_keys)
