torchrl.trainers package¶
The trainer package provides utilities to write reusable training scripts. The core idea is to use a trainer that implements a nested loop, where the outer loop runs the data collection steps and the inner loop runs the optimization steps. We believe this fits multiple RL training schemes, such as on-policy, off-policy, model-based and model-free solutions, offline RL and others. More specialized cases, such as meta-RL algorithms, may have training schemes that differ substantially.
The trainer.train() method can be sketched as follows:
>>> for batch in collector:
...     batch = self._process_batch_hook(batch)  # "batch_process"
...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
...     self._pre_optim_hook()  # "pre_optim_steps"
...     for j in range(self.optim_steps_per_batch):
...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
...         losses = self.loss_module(sub_batch)
...         self._post_loss_hook(sub_batch)  # "post_loss"
...         self.optimizer.step()
...         self.optimizer.zero_grad()
...         self._post_optim_hook()  # "post_optim"
...         self._post_optim_log(sub_batch)  # "post_optim_log"
...     self._post_steps_hook()  # "post_steps"
...     self._post_steps_log_hook(batch)  # "post_steps_log"
There are 10 hooks that can be used in a trainer loop: "batch_process", "pre_optim_steps", "process_optim_batch", "post_loss", "post_steps", "post_optim", "pre_steps_log", "post_steps_log", "post_optim_log" and "optimizer". They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: data processing hooks ("batch_process" and "process_optim_batch"), logging hooks ("pre_steps_log", "post_optim_log" and "post_steps_log") and operation hooks ("pre_optim_steps", "post_loss", "post_optim" and "post_steps").
Data processing hooks update a tensordict of data. The hook's __call__ method should accept a TensorDict object as input and update it according to some strategy. Examples of such hooks include replay buffer extension (ReplayBufferTrainer.extend), data normalization (including updates of the normalization constants) and data subsampling (BatchSubSampler).
Logging hooks take a batch of data presented as a TensorDict and write to the logger some information retrieved from that data. Examples include the LogValidationReward hook and the reward logger (LogScalar). These hooks should return a dictionary (or None) containing the data to log. The key "log_pbar" is reserved for boolean values indicating whether the logged value should be displayed on the progress bar printed in the training log.
Operation hooks execute specific operations over the models, data collectors, target network updates and such. For instance, syncing the weights of the collectors using UpdateWeights or updating the priority of the replay buffer using ReplayBufferTrainer.update_priority are examples of operation hooks. They are data-independent (they do not require a TensorDict input); they are just supposed to be executed once at every iteration (or every N iterations).
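The three hook categories can be illustrated with a minimal pure-Python sketch. It does not use TorchRL: ToyTrainer, its helper methods and the dictionary batches are stand-ins invented for this illustration, and the loss computation, optimizer step and "post_loss" hook are left out for brevity.

```python
from collections import defaultdict


class ToyTrainer:
    """Hypothetical stand-in for a trainer; not the TorchRL implementation."""

    def __init__(self, batches, optim_steps_per_batch=2):
        self.batches = batches
        self.optim_steps_per_batch = optim_steps_per_batch
        self.hooks = defaultdict(list)  # hook destination -> list of callables
        self.logs = []  # everything returned by the logging hooks

    def register_op(self, dest, op):
        self.hooks[dest].append(op)

    def _process(self, dest, batch):
        # Data processing hooks: each one receives the batch and returns it updated.
        for op in self.hooks[dest]:
            batch = op(batch)
        return batch

    def _log(self, dest, batch):
        # Logging hooks: each one returns a dict to log, or None.
        for op in self.hooks[dest]:
            out = op(batch)
            if out is not None:
                self.logs.append(out)

    def _run(self, dest):
        # Operation hooks: data-independent, called without arguments.
        for op in self.hooks[dest]:
            op()

    def train(self):
        for batch in self.batches:
            batch = self._process("batch_process", batch)
            self._log("pre_steps_log", batch)
            self._run("pre_optim_steps")
            for _ in range(self.optim_steps_per_batch):
                sub_batch = self._process("process_optim_batch", batch)
                # (loss computation and optimizer step would go here)
                self._run("post_optim")
                self._log("post_optim_log", sub_batch)
            self._run("post_steps")
            self._log("post_steps_log", batch)


optim_calls = []
trainer = ToyTrainer(batches=[{"reward": 1.0}, {"reward": 3.0}])
# data processing hook: halve the reward of every collected batch
trainer.register_op("batch_process", lambda b: {**b, "reward": b["reward"] / 2})
# logging hook: report the processed reward once per collected batch
trainer.register_op("post_steps_log", lambda b: {"reward": b["reward"], "log_pbar": True})
# operation hook: count how many optimization steps were run
trainer.register_op("post_optim", lambda: optim_calls.append(1))
trainer.train()
print(trainer.logs)
print(len(optim_calls))
```

In this toy version the only difference between the categories is the call signature: data processing hooks return the batch, logging hooks return a dictionary or None, and operation hooks take no input.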
The hooks provided by TorchRL usually inherit from a common abstract class, TrainerHookBase, and all implement three base methods: state_dict and load_state_dict for checkpointing, and a register method that registers the hook at its default location in the trainer. This method takes a trainer and a module name as input. For instance, the following logging hook returns a value to log once every 10 calls to "post_optim_log":
>>> from torchrl.trainers import TrainerHookBase
>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name="logging_hook"):
...         # record the hook under `name`, then attach it to "post_optim_log"
...         trainer.register_module(name, self)
...         trainer.register_op("post_optim_log", self)
...
...     def state_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         # log on calls 0, 10, 20, ...; the counter advances once per call
...         if self.counter % 10 == 0:
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out
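The counter-based schedule and the checkpointing round trip of such a hook can be checked without TorchRL. The sketch below uses a plain class, EveryNLogger, which is a hypothetical stand-in that does not inherit from TrainerHookBase and works on dictionaries rather than TensorDict objects.

```python
class EveryNLogger:
    """Hypothetical hook-like object; it does not inherit from TrainerHookBase."""

    def __init__(self, every=10):
        self.every = every
        self.counter = 0

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]

    def __call__(self, batch):
        out = None
        if self.counter % self.every == 0:
            out = {"some_value": batch["some_value"], "log_pbar": False}
        self.counter += 1  # exactly one increment per call
        return out


hook = EveryNLogger(every=10)
results = [hook({"some_value": float(i)}) for i in range(25)]
logged_at = [i for i, r in enumerate(results) if r is not None]
print(logged_at)  # indices of the calls that produced a log entry

# Restoring the state into a fresh hook resumes the schedule where it stopped.
restored = EveryNLogger(every=10)
restored.load_state_dict(hook.state_dict())
print(restored.counter)
```

Keeping the counter in the state dictionary is what lets a resumed run log on the same schedule as an uninterrupted one.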
Checkpointing¶
The trainer class and hooks support checkpointing, which can be achieved using either the torchsnapshot backend or the regular torch backend. This can be controlled via the environment variable CKPT_BACKEND:
$ CKPT_BACKEND=torchsnapshot python script.py
CKPT_BACKEND defaults to torch. The advantage of torchsnapshot over pytorch is that it is a more flexible API, which supports distributed checkpointing and also allows users to load tensors from a file stored on disk to a tensor with a physical storage (which pytorch currently does not support). This makes it possible, for instance, to load tensors from and to a replay buffer that would otherwise not fit in memory.
When building a trainer, one can provide a path where the checkpoints are to be written. With the torchsnapshot backend, a directory path is expected, whereas the torch backend expects a file path (typically a .pt file).
>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)
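Because the two backends expect different kinds of paths, a script may want to derive the checkpoint path from CKPT_BACKEND. The following is a minimal sketch under that assumption: checkpoint_path, the "trainer_snapshot" directory name and the "trainer.pt" file name are hypothetical and not part of TorchRL.

```python
import os
from pathlib import Path


def checkpoint_path(root, backend=None):
    """Hypothetical helper: pick a checkpoint path matching the backend."""
    if backend is None:
        # same default as described above: the regular torch backend
        backend = os.environ.get("CKPT_BACKEND", "torch")
    root = Path(root)
    if backend == "torchsnapshot":
        return root / "trainer_snapshot"  # a directory is expected
    if backend == "torch":
        return root / "trainer.pt"  # a single file is expected
    raise ValueError(f"Unknown checkpoint backend: {backend!r}")


print(checkpoint_path("checkpoints", backend="torch"))
print(checkpoint_path("checkpoints", backend="torchsnapshot"))
```

The returned path could then be passed as save_trainer_file when building the trainer.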
The Trainer.train() method can be used to execute the above loop with all of its hooks, although using the Trainer class for its checkpointing capability only is also a perfectly valid use.
Trainer and hooks¶
Data subsampler for online RL sota-implementations.
Clears cuda cache at a given interval.
A frame counter hook.
Reward logger hook.
Add an optimizer for one or more loss components.
Recorder hook for
Replay buffer hook provider.
Reward normalizer hook.
Selects keys in a TensorDict batch.
A generic Trainer class.
An abstract hooking class for torchrl Trainer class.
A collector weights update hook class.
Builders¶
Returns a data collector for off-policy sota-implementations.
Makes a collector in on-policy settings.
Builds the DQN loss module.
Builds a replay buffer using the config built from ReplayArgsConfig.
Builds a target network weight update object.
Creates a Trainer instance given its constituents.
Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
Runs asynchronous collectors, each running synchronous environments.
Runs synchronous collectors, each running synchronous environments.
Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
Utils¶
Corrects the arguments for the input frame_skip by dividing all the arguments that reflect a count of frames by the frame_skip.
Gathers stats (loc and scale) from an environment using random rollouts.
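The frame_skip correction described above amounts to simple arithmetic. The sketch below assumes integer division and uses a hypothetical helper, correct_frame_counts, with illustrative argument names; it is not the TorchRL utility itself.

```python
def correct_frame_counts(frame_counts, frame_skip):
    """Hypothetical helper: divide every frame-count argument by frame_skip."""
    if frame_skip < 1:
        raise ValueError("frame_skip must be a positive integer")
    # Integer division is an assumption made for this sketch.
    return {name: count // frame_skip for name, count in frame_counts.items()}


# Illustrative frame counts expressed in raw environment frames.
raw = {"total_frames": 1_000_000, "frames_per_batch": 1_000}
corrected = correct_frame_counts(raw, frame_skip=4)
print(corrected)
```

With a frame_skip of 4, each collected step stands for 4 environment frames, so a budget of 1,000,000 frames corresponds to 250,000 collected steps.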
Loggers¶
A template for loggers.
A minimal-dependency CSV logger.
Wrapper for the mlflow logger.
Wrapper for the TensorBoard logger.
Wrapper for the wandb logger.
Get a logger instance of the provided logger_type.
Generates an ID (str) for the described experiment using UUID and current date.
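An experiment ID built from a UUID and the current date can be sketched with the standard library alone. The helper name make_exp_id, the 8-character UUID prefix and the timestamp layout are assumptions made for illustration; the format produced by TorchRL may differ.

```python
import uuid
from datetime import datetime


def make_exp_id(exp_name):
    """Hypothetical helper: experiment name + short UUID + timestamp."""
    short_uuid = str(uuid.uuid4())[:8]  # random part, keeps IDs unique
    stamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")  # sortable date part
    return f"{exp_name}_{short_uuid}_{stamp}"


exp_id = make_exp_id("dqn_cartpole")
print(exp_id)
```

The random part keeps two runs launched within the same second from sharing an ID, while the date part makes the IDs easy to sort and locate on disk.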
Recording utils¶
Recording utils are detailed here.
Video Recorder transform.
TensorDict recorder.
A transform to call render on the parent environment and register the pixel observation in the tensordict.