Source code for torchrl.trainers.helpers.trainers

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from warnings import warn

import torch
from tensordict.nn import TensorDictModule, TensorDictModuleWrapper
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.collectors.collectors import DataCollectorBase
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.utils import ExplorationType
from torchrl.modules import reset_noise
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.record.loggers import Logger
from torchrl.trainers.trainers import (
    BatchSubSampler,
    ClearCudaCache,
    CountFramesLog,
    LogScalar,
    LogValidationReward,
    ReplayBufferTrainer,
    RewardNormalizer,
    SelectKeys,
    Trainer,
    UpdateWeights,
)

OPTIMIZERS = {
    "adam": optim.Adam,
    "sgd": optim.SGD,
    "adamax": optim.Adamax,
}
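# A minimal sketch of how this mapping is used by ``make_trainer`` below
# (assuming a ``cfg`` and a ``loss_module`` with parameters, as in the function body):
#
#   >>> optimizer_cls = OPTIMIZERS[cfg.optimizer]  # e.g. torch.optim.Adam
#   >>> optimizer = optimizer_cls(
#   ...     loss_module.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
#   ... )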


@dataclass
class TrainerConfig:
    """Trainer config struct."""

    optim_steps_per_batch: int = 500
    # Number of optimization steps in between two collections of data. See frames_per_batch below.
    optimizer: str = "adam"
    # Optimizer to be used.
    lr_scheduler: str = "cosine"
    # LR scheduler.
    selected_keys: list | None = None
    # a list of strings indicating the data that should be kept from the data collector. Since storing and
    # retrieving information from the replay buffer does not come for free, limiting the amount of data
    # passed to it can improve algorithm performance.
    batch_size: int = 256
    # batch size of the TensorDict retrieved from the replay buffer. Default=256.
    log_interval: int = 10000
    # logging interval, in terms of optimization steps. Default=10000.
    lr: float = 3e-4
    # Learning rate used for the optimizer. Default=3e-4.
    weight_decay: float = 0.0
    # Weight-decay to be used with the optimizer. Default=0.0.
    clip_norm: float = 1000.0
    # value at which the total gradient norm / single derivative should be clipped. Default=1000.0
    clip_grad_norm: bool = False
    # if True, gradients will be clipped based on their total L2 norm. Otherwise, individual gradient values will be clipped to the desired threshold.
    normalize_rewards_online: bool = False
    # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module.
    normalize_rewards_online_scale: float = 1.0
    # Final scale of the normalized rewards.
    normalize_rewards_online_decay: float = 0.9999
    # Decay of the reward moving average.
    sub_traj_len: int = -1
    # length that sampled sub-trajectories must have in online settings.

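# A minimal sketch of how this config can be populated outside of a
# hydra/omegaconf-driven script (plain dataclass instantiation):
#
#   >>> cfg = TrainerConfig(optimizer="sgd", lr=1e-3, batch_size=128)
#   >>> cfg.lr_scheduler = ""  # an empty string disables the cosine scheduler below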

def make_trainer(
    collector: DataCollectorBase,
    loss_module: LossModule,
    recorder: EnvBase | None = None,
    target_net_updater: TargetNetUpdater | None = None,
    policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None,
    replay_buffer: ReplayBuffer | None = None,
    logger: Logger | None = None,
    cfg: DictConfig = None,  # noqa: F821
) -> Trainer:
    """Creates a Trainer instance given its constituents.

    Args:
        collector (DataCollectorBase): A data collector to be used to collect data.
        loss_module (LossModule): A TorchRL loss module.
        recorder (EnvBase, optional): a recorder environment. If None, the trainer
            will train the policy without testing it.
        target_net_updater (TargetNetUpdater, optional): A target network update object.
        policy_exploration (TDModule or TensorDictModuleWrapper, optional): a policy to
            be used for recording and exploration updates (should be synced with the
            learnt policy).
        replay_buffer (ReplayBuffer, optional): a replay buffer to be used to collect data.
        logger (Logger, optional): a Logger to be used for logging.
        cfg (DictConfig, optional): a DictConfig containing the arguments of the script.
            If None, the default arguments are used.

    Returns:
        A trainer built with the input objects. The optimizer is built by this helper
        function using the cfg provided.

    Examples:
        >>> import torch
        >>> import tempfile
        >>> from torchrl.trainers.loggers import TensorboardLogger
        >>> from torchrl.trainers import Trainer
        >>> from torchrl.envs import EnvCreator
        >>> from torchrl.collectors.collectors import SyncDataCollector
        >>> from torchrl.data import TensorDictReplayBuffer
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
        >>> from torchrl.objectives.common import LossModule
        >>> from torchrl.objectives.utils import TargetNetUpdater
        >>> from torchrl.objectives import DDPGLoss
        >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
        >>> env_proof = env_maker()
        >>> obs_spec = env_proof.observation_spec
        >>> action_spec = env_proof.action_spec
        >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
        >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1)  # for the purpose of testing
        >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
        >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
        >>> collector = SyncDataCollector(env_maker, policy, total_frames=100)
        >>> loss_module = DDPGLoss(policy, value, gamma=0.99)
        >>> recorder = env_proof
        >>> target_net_updater = None
        >>> policy_exploration = EGreedyWrapper(policy)
        >>> replay_buffer = TensorDictReplayBuffer()
        >>> dir = tempfile.gettempdir()
        >>> logger = TensorboardLogger(exp_name=dir)
        >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
        ...     replay_buffer, logger)
        >>> print(trainer)

    """
    if cfg is None:
        warn(
            "Getting default cfg for the trainer. "
            "This should be only used for debugging."
        )
        cfg = TrainerConfig()
        cfg.frame_skip = 1
        cfg.total_frames = 1000
        cfg.record_frames = 10
        cfg.record_interval = 10

    optimizer_kwargs = {} if cfg.optimizer != "adam" else {"betas": (0.0, 0.9)}
    optimizer = OPTIMIZERS[cfg.optimizer](
        loss_module.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        **optimizer_kwargs,
    )
    device = next(loss_module.parameters()).device
    if cfg.lr_scheduler == "cosine":
        optim_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=int(
                cfg.total_frames / cfg.frames_per_batch * cfg.optim_steps_per_batch
            ),
        )
    elif cfg.lr_scheduler == "":
        optim_scheduler = None
    else:
        raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}")

    if VERBOSE:
        torchrl_logger.info(
            f"collector = {collector}; \n"
            f"loss_module = {loss_module}; \n"
            f"recorder = {recorder}; \n"
            f"target_net_updater = {target_net_updater}; \n"
            f"policy_exploration = {policy_exploration}; \n"
            f"replay_buffer = {replay_buffer}; \n"
            f"logger = {logger}; \n"
            f"cfg = {cfg}; \n"
        )

    if logger is not None:
        # log hyperparams
        logger.log_hparams(cfg)

    trainer = Trainer(
        collector=collector,
        frame_skip=cfg.frame_skip,
        total_frames=cfg.total_frames * cfg.frame_skip,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        optim_steps_per_batch=cfg.optim_steps_per_batch,
        clip_grad_norm=cfg.clip_grad_norm,
        clip_norm=cfg.clip_norm,
    )

    if torch.cuda.device_count() > 0:
        trainer.register_op("pre_optim_steps", ClearCudaCache(1))

    if hasattr(cfg, "noisy") and cfg.noisy:
        trainer.register_op("pre_optim_steps", lambda: loss_module.apply(reset_noise))

    if cfg.selected_keys:
        trainer.register_op("batch_process", SelectKeys(cfg.selected_keys))
    trainer.register_op("batch_process", lambda batch: batch.cpu())

    if replay_buffer is not None:
        # replay buffer is used 2 or 3 times: to register data, to sample
        # data and to update priorities
        rb_trainer = ReplayBufferTrainer(
            replay_buffer,
            cfg.batch_size,
            flatten_tensordicts=False,
            memmap=False,
            device=device,
        )

        trainer.register_op("batch_process", rb_trainer.extend)
        trainer.register_op("process_optim_batch", rb_trainer.sample)
        trainer.register_op("post_loss", rb_trainer.update_priority)
    else:
        # trainer.register_op("batch_process", mask_batch)
        trainer.register_op(
            "process_optim_batch",
            BatchSubSampler(batch_size=cfg.batch_size, sub_traj_len=cfg.sub_traj_len),
        )
        trainer.register_op("process_optim_batch", lambda batch: batch.to(device))

    if optim_scheduler is not None:
        trainer.register_op("post_optim", optim_scheduler.step)

    if target_net_updater is not None:
        trainer.register_op("post_optim", target_net_updater.step)

    if cfg.normalize_rewards_online:
        # if used the running statistics of the rewards are computed and the
        # rewards used for training will be normalized based on these.
        reward_normalizer = RewardNormalizer(
            scale=cfg.normalize_rewards_online_scale,
            decay=cfg.normalize_rewards_online_decay,
        )
        trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
        trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)

    if policy_exploration is not None and hasattr(policy_exploration, "step"):
        trainer.register_op(
            "post_steps", policy_exploration.step, frames=cfg.frames_per_batch
        )

    trainer.register_op(
        "post_steps_log", lambda *cfg: {"lr": optimizer.param_groups[0]["lr"]}
    )

    if recorder is not None:
        # create recorder object
        recorder_obj = LogValidationReward(
            record_frames=cfg.record_frames,
            frame_skip=cfg.frame_skip,
            policy_exploration=policy_exploration,
            environment=recorder,
            record_interval=cfg.record_interval,
            log_keys=cfg.recorder_log_keys,
        )
        # register recorder
        trainer.register_op(
            "post_steps_log",
            recorder_obj,
        )
        # call recorder - could be removed
        recorder_obj(None)
        # create explorative recorder - could be optional
        recorder_obj_explore = LogValidationReward(
            record_frames=cfg.record_frames,
            frame_skip=cfg.frame_skip,
            policy_exploration=policy_exploration,
            environment=recorder,
            record_interval=cfg.record_interval,
            exploration_type=ExplorationType.RANDOM,
            suffix="exploration",
            out_keys={("next", "reward"): "r_evaluation_exploration"},
        )
        # register recorder
        trainer.register_op(
            "post_steps_log",
            recorder_obj_explore,
        )
        # call recorder - could be removed
        recorder_obj_explore(None)

    trainer.register_op(
        "post_steps", UpdateWeights(collector, update_weights_interval=1)
    )

    trainer.register_op("pre_steps_log", LogScalar())
    trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))

    return trainer
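# A minimal usage sketch (assuming ``collector``, ``loss_module`` and ``replay_buffer``
# built as in the docstring example above): extra hooks can be attached to the
# returned trainer through the same ``register_op`` entry points used in this helper:
#
#   >>> trainer = make_trainer(collector, loss_module, replay_buffer=replay_buffer)
#   >>> trainer.register_op(
#   ...     "post_steps_log", lambda *args: {"replay_buffer_size": len(replay_buffer)}
#   ... )
#   >>> trainer.train()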
