Source code for torchrl.trainers.helpers.losses

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater


def make_target_updater(
    cfg: DictConfig, loss_module: LossModule  # noqa: F821
) -> TargetNetUpdater | None:
    """Builds a target network weight update object."""
    if cfg.loss == "double":
        if not cfg.hard_update:
            target_net_updater = SoftUpdate(
                loss_module, eps=1 - 1 / cfg.value_network_update_interval
            )
        else:
            target_net_updater = HardUpdate(
                loss_module,
                value_network_update_interval=cfg.value_network_update_interval,
            )
    else:
        if cfg.hard_update:
            raise RuntimeError(
                "hard/soft-update are supposed to be used with double SAC loss. "
                "Consider using --loss=double or discarding the hard_update flag."
            )
        target_net_updater = None
    return target_net_updater
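
For reference, the soft-update branch above translates the update interval into an exponential-moving-average decay. A small illustrative sketch of that arithmetic (not part of the module):

# With the default value_network_update_interval=1000, SoftUpdate receives
# eps = 1 - 1/1000 = 0.999, i.e. after each optimization step the target
# parameters follow target <- eps * target + (1 - eps) * online.
interval = 1000
eps = 1 - 1 / interval  # 0.999
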

def make_dqn_loss(model, cfg) -> tuple[DQNLoss, TargetNetUpdater | None]:
    """Builds the DQN loss module."""
    loss_kwargs = {}
    if cfg.distributional:
        loss_class = DistributionalDQNLoss
    else:
        loss_kwargs.update({"loss_function": cfg.loss_function})
        loss_class = DQNLoss
    if cfg.loss not in ("single", "double"):
        raise NotImplementedError
    loss_kwargs.update({"delay_value": cfg.loss == "double"})
    loss_module = loss_class(model, **loss_kwargs)
    loss_module.make_value_estimator(gamma=cfg.gamma)
    target_net_updater = make_target_updater(cfg, loss_module)
    return loss_module, target_net_updater
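
A minimal usage sketch of make_dqn_loss follows. The QValueActor and the SimpleNamespace config are illustrative stand-ins for the actor built elsewhere by the trainer helpers and the Hydra DictConfig they normally receive; it assumes a TorchRL version where QValueActor accepts a plain nn.Module together with an action spec:

from types import SimpleNamespace

import torch.nn as nn

from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.trainers.helpers.losses import make_dqn_loss

# Hypothetical stand-ins: a 4-action Q-network over an 8-dim observation,
# and an attribute-access config mirroring the LossConfig fields below.
value_net = nn.Linear(8, 4)
model = QValueActor(
    value_net, in_keys=["observation"], spec=OneHotDiscreteTensorSpec(4)
)
cfg = SimpleNamespace(
    distributional=False,
    loss_function="smooth_l1",
    loss="double",
    gamma=0.99,
    hard_update=False,
    value_network_update_interval=1000,
)

loss_module, target_net_updater = make_dqn_loss(model, cfg)
# After each optimization step the trainer would call target_net_updater.step().
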

@dataclass
class LossConfig:
    """Generic Loss config struct."""

    loss: str = "double"
    # whether double or single SAC loss should be used. Default=double
    hard_update: bool = False
    # whether soft-update should be used with double SAC loss (default) or hard updates.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    value_network_update_interval: int = 1000
    # how often the target value network weights are updated (in number of updates).
    # If soft-updates are used, the value is translated into a moving average decay by using
    # the formula decay=1-1/cfg.value_network_update_interval. Default=1000
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    num_q_values: int = 2
    # As suggested in the original SAC paper and in https://arxiv.org/abs/1802.09477, we can
    # use two (or more!) q-value networks trained independently and take the lowest predicted
    # value as the state-action value estimate.
    # REDQ uses an arbitrary number of Q-value functions to speed up learning in MF contexts.
    target_entropy: Any = None
    # Target entropy for the policy distribution. Default is None (auto calculated as the `target_entropy = -action_dim`)


@dataclass
class A2CLossConfig:
    """A2C Loss config struct."""

    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    entropy_coef: float = 1e-3
    # Entropy factor for the A2C loss
    critic_coef: float = 1.0
    # Critic factor for the A2C loss
    critic_loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).


@dataclass
class PPOLossConfig:
    """PPO Loss config struct."""

    loss: str = "clip"
    # PPO loss class, either clip or kl or base/<empty>. Default=clip

    # PPOLoss base parameters:
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    lmbda: float = 0.95
    # lambda factor in GAE (using 'lambda' as an attribute is prohibited in python, hence the misspelling)
    entropy_bonus: bool = True
    # whether to add an entropy term to the PPO loss.
    entropy_coef: float = 1e-3
    # Entropy factor for the PPO loss
    samples_mc_entropy: int = 1
    # Number of samples to use for a Monte-Carlo estimate if the policy distribution has no closed-form entropy.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    critic_coef: float = 1.0
    # Critic loss multiplier when computing the total loss.

    # ClipPPOLoss parameters:
    clip_epsilon: float = 0.2
    # weight clipping threshold in the clipped PPO loss equation.

    # KLPENPPOLoss parameters:
    dtarg: float = 0.01
    # target KL divergence.
    beta: float = 1.0
    # initial KL divergence multiplier.
    increment: float = 2
    # how much beta should be incremented if KL > dtarg. Valid range: increment >= 1.0
    decrement: float = 0.5
    # how much beta should be decremented if KL < dtarg. Valid range: decrement <= 1.0
    samples_mc_kl: int = 1
    # Number of samples to use for a Monte-Carlo estimate of KL if necessary
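
These config structs are plain dataclasses and are typically composed into the Hydra configuration that the trainer scripts pass to the helpers above. A minimal, illustrative registration sketch, assuming Hydra is installed; the entry name "dqn_loss" is arbitrary:

from hydra.core.config_store import ConfigStore

# Register the struct so cfg.loss, cfg.hard_update, ... are resolvable at compose time.
cs = ConfigStore.instance()
cs.store(name="dqn_loss", node=LossConfig)

# Note that make_dqn_loss also reads fields defined in other config structs
# (e.g. cfg.distributional), so the trainer scripts merge several dataclasses
# into a single DictConfig before calling the helpers.
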
