Trainer

class torchrl.trainers.Trainer(*args, **kwargs)

A generic Trainer class.

A trainer is responsible for collecting data and training the model. To keep the class as versatile as possible, Trainer does not construct any of its specific operations itself: they must all be hooked in at specific points of the training loop.

To build a Trainer, one needs an iterable data source (a collector), a loss module and an optimizer.
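The division of labour described above (a loop that owns no problem-specific operations and only calls what has been plugged into it) can be sketched in plain Python. This is a hypothetical illustration, not torchrl's implementation: the `MiniTrainer` class, its `register_op` method and the hook names `"batch_process"` and `"post_optim"` are assumptions made for this sketch, and the loss module and optimizer are replaced by simple callables.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Iterable, List


class MiniTrainer:
    """Hypothetical hook-driven training loop; NOT torchrl's Trainer."""

    def __init__(
        self,
        collector: Iterable[Any],
        loss_fn: Callable[[Any], float],
        optimizer_step: Callable[[float], None],
        optim_steps_per_batch: int = 1,
    ) -> None:
        self.collector = collector            # iterable data source yielding batches
        self.loss_fn = loss_fn                # maps a batch to a scalar loss
        self.optimizer_step = optimizer_step  # stand-in for an optimizer update
        self.optim_steps_per_batch = optim_steps_per_batch
        self._hooks: DefaultDict[str, List[Callable[[Any], Any]]] = defaultdict(list)

    def register_op(self, dest: str, fn: Callable[[Any], Any]) -> None:
        """Attach fn at the named point `dest` of the loop (hypothetical API)."""
        self._hooks[dest].append(fn)

    def _run_hooks(self, dest: str, value: Any) -> Any:
        # Thread `value` through every hook registered at `dest`, in order.
        for fn in self._hooks[dest]:
            value = fn(value)
        return value

    def train(self) -> List[float]:
        losses: List[float] = []
        for batch in self.collector:                     # epoch loop: one collected batch
            batch = self._run_hooks("batch_process", batch)
            for _ in range(self.optim_steps_per_batch):  # training sub-loop
                loss = self.loss_fn(batch)
                self.optimizer_step(loss)
                losses.append(self._run_hooks("post_optim", loss))
        return losses


# Usage: toy batches, a mean "loss" and a no-op optimizer step.
seen: List[float] = []


def log_loss(loss: float) -> float:
    seen.append(loss)  # a logging hook: record the value and pass it through unchanged
    return loss


trainer = MiniTrainer(
    collector=[[1.0, 2.0], [3.0, 4.0]],
    loss_fn=lambda batch: sum(batch) / len(batch),
    optimizer_step=lambda loss: None,  # a real optimizer would update parameters here
    optim_steps_per_batch=2,
)
trainer.register_op("batch_process", lambda b: [2 * x for x in b])  # e.g. a rescaling step
trainer.register_op("post_optim", log_loss)
losses = trainer.train()
print(losses)  # [3.0, 3.0, 7.0, 7.0]
```

Removing the two `register_op` calls leaves a loop that still runs but does nothing beyond computing losses and stepping, which is the point: preprocessing, logging and similar behaviour are supplied from outside rather than built into the class.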

Parameters:
  • collector (Sequence[TensorDictBase]) – An iterable returning batches of data as TensorDicts of shape [batch x time steps].

  • total_frames (int) – Total number of frames to be collected during training.

  • loss_module (LossModule) – A module that reads TensorDict batches (possibly sampled from a replay buffer) and returns a loss TensorDict where every key points to a different loss component.

  • optimizer (optim.Optimizer) – An optimizer that trains the parameters of the model.

  • logger (Logger, optional) – A Logger that will handle the logging.

  • optim_steps_per_batch (int) – Number of optimization steps per collection of data. A trainer works as follows: a main loop (the epoch loop) collects batches of data, and a sub-loop (the training loop) performs model updates between two collections of data.

  • clip_grad_norm (bool, optional) – If True, the gradients will be clipped based on their total norm, computed over all model parameters. If False, every partial derivative will be clamped to (-clip_norm, clip_norm). Default is True.

  • clip_norm (Number, optional) – Value used for clipping gradients. Default is None (no clip norm).

  • progress_bar (bool, optional) – If True, a progress bar will be displayed using tqdm. If tqdm is not installed, this option has no effect. Default is True.

  • seed (int, optional) – Seed to be used for the collector, PyTorch and NumPy. Default is None.

  • save_trainer_interval (int, optional) – How often the trainer should be saved to disk, measured in collected frames. Default is 10000.

  • log_interval (int, optional) – How often the values should be logged, measured in collected frames. Default is 10000.

  • save_trainer_file (path, optional) – Path where the trainer should be saved. Default is None (no saving).
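Because total_frames, log_interval and save_trainer_interval are expressed in collected frames while optim_steps_per_batch counts optimizer updates, it can help to work out a schedule by hand. The helper below is a hypothetical back-of-the-envelope sketch, not torchrl code. It assumes that each [batch x time steps] batch contributes batch * time steps frames and that an event fires whenever the frame count crosses a multiple of its interval; the library's exact triggering rule may differ.

```python
from typing import Dict, List


def crossed_interval(prev_frames: int, frames: int, interval: int) -> bool:
    """True if a multiple of `interval` lies in (prev_frames, frames]."""
    return frames // interval > prev_frames // interval


def simulate_schedule(
    batch_size: int,
    time_steps: int,
    total_frames: int,
    optim_steps_per_batch: int,
    log_interval: int,
    save_trainer_interval: int,
) -> Dict[str, object]:
    # Assumption: a [batch x time steps] batch counts as batch_size * time_steps frames.
    frames_per_batch = batch_size * time_steps
    collected = 0
    collections = 0
    optim_steps = 0
    log_at: List[int] = []
    save_at: List[int] = []
    while collected < total_frames:  # epoch loop (may overshoot if not divisible)
        prev = collected
        collected += frames_per_batch
        collections += 1
        optim_steps += optim_steps_per_batch  # training sub-loop
        if crossed_interval(prev, collected, log_interval):
            log_at.append(collected)
        if crossed_interval(prev, collected, save_trainer_interval):
            save_at.append(collected)
    return {
        "collections": collections,
        "optim_steps": optim_steps,
        "log_at": log_at,
        "save_at": save_at,
    }


print(simulate_schedule(4, 50, 1000, 2, 400, 1000))
# {'collections': 5, 'optim_steps': 10, 'log_at': [400, 800], 'save_at': [1000]}
```

Under these assumptions, doubling optim_steps_per_batch doubles the number of model updates without moving any logging or saving point, since those are tied to frames rather than to updates.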
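The two clip_grad_norm modes correspond to two standard gradient-clipping techniques: rescaling by the global L2 norm versus clamping each component by value. The sketch below shows the arithmetic on a plain list of gradient values so it runs without PyTorch; the function names are hypothetical. In PyTorch the corresponding in-place utilities are `torch.nn.utils.clip_grad_norm_` and `torch.nn.utils.clip_grad_value_`, though this page does not state exactly how the Trainer invokes them.

```python
import math
from typing import List


def clip_by_total_norm(grads: List[float], clip_norm: float) -> List[float]:
    """clip_grad_norm=True: rescale all gradients jointly if their global L2 norm exceeds clip_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip_norm:
        return list(grads)
    scale = clip_norm / total_norm
    return [g * scale for g in grads]  # direction preserved, norm becomes clip_norm


def clip_by_value(grads: List[float], clip_norm: float) -> List[float]:
    """clip_grad_norm=False: clamp each partial derivative to [-clip_norm, clip_norm]."""
    return [max(-clip_norm, min(clip_norm, g)) for g in grads]


grads = [3.0, 4.0]  # global L2 norm is 5.0
print(clip_by_total_norm(grads, 1.0))  # approximately [0.6, 0.8]
print(clip_by_value(grads, 1.0))       # [1.0, 1.0]
```

Norm clipping keeps the 3:4 ratio between the components, so the update direction is unchanged; value clipping maps both components to 1.0 and therefore changes the direction of the update.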

load_from_file(file: Union[str, Path], **kwargs) → Trainer

Loads a file and loads its state dict into the trainer.

Keyword arguments are passed to the load() function.
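The documented behaviour (read a file, load the resulting state dict into the object, forward keyword arguments to the loading function, return the trainer) can be sketched with a hypothetical `Checkpointable` class. The sketch uses `pickle` as a stand-in serialization backend so that it runs without PyTorch; the class, its `save_to_file` helper and its single `collected_frames` field are illustrative assumptions, not torchrl code. For PyTorch checkpoints the loading function would presumably be `torch.load`, whose `map_location` keyword is a common thing to forward. Returning `self` is what allows the call to be chained, consistent with the `Trainer` return type in the signature.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


class Checkpointable:
    """Hypothetical object with a state dict, standing in for a Trainer."""

    def __init__(self) -> None:
        self.collected_frames = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"collected_frames": self.collected_frames}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.collected_frames = state["collected_frames"]

    def save_to_file(self, file: Union[str, Path]) -> None:
        with open(file, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def load_from_file(self, file: Union[str, Path], **kwargs: Any) -> "Checkpointable":
        # Keyword arguments are forwarded to the loading function (pickle.load here).
        with open(file, "rb") as f:
            state = pickle.load(f, **kwargs)
        self.load_state_dict(state)
        return self  # returning self allows chaining


# Usage: save one object, restore it into a fresh one in a single chained call.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "trainer.pkl"
    source = Checkpointable()
    source.collected_frames = 12345
    source.save_to_file(path)
    restored = Checkpointable().load_from_file(path, encoding="ASCII")

print(restored.collected_frames)  # 12345
```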
