Trainer
class torchrl.trainers.Trainer(*args, **kwargs)
A generic Trainer class.
A trainer is responsible for collecting data and training the model. To keep the class as versatile as possible, Trainer does not construct any of its specific operations itself: they must all be hooked in at specific points of the training loop.
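The hook-based design can be illustrated with a toy trainer, written in plain Python, that owns no data-processing or update logic of its own and only calls whatever functions were registered at named points of its loop. This is a minimal sketch, not TorchRL's implementation: the class ToyHookedTrainer, its register_hook method, and the hook-point names "batch_process", "optimizer" and "post_steps" are hypothetical, and TorchRL's actual registration API and hook names may differ.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List


class ToyHookedTrainer:
    """Hypothetical trainer: every concrete operation comes from a registered hook."""

    # Hypothetical hook-point names, chosen for this sketch only.
    HOOK_POINTS = ("batch_process", "optimizer", "post_steps")

    def __init__(self, collector: Iterable, optim_steps_per_batch: int = 1) -> None:
        self.collector = collector
        self.optim_steps_per_batch = optim_steps_per_batch
        self._hooks: Dict[str, List[Callable]] = defaultdict(list)

    def register_hook(self, point: str, fn: Callable) -> None:
        # Reject unknown hook points so typos fail loudly instead of silently.
        if point not in self.HOOK_POINTS:
            raise ValueError(f"unknown hook point: {point!r}")
        self._hooks[point].append(fn)

    def _transform(self, point: str, batch):
        # Transforming hooks: each one receives the previous hook's output.
        for fn in self._hooks[point]:
            batch = fn(batch)
        return batch

    def _notify(self, point: str, batch) -> None:
        # Side-effect hooks: return values are ignored.
        for fn in self._hooks[point]:
            fn(batch)

    def train(self) -> None:
        for batch in self.collector:
            batch = self._transform("batch_process", batch)
            for _ in range(self.optim_steps_per_batch):
                self._notify("optimizer", batch)
            self._notify("post_steps", batch)


# Usage: the trainer itself knows nothing about scaling or updates.
seen = []
trainer = ToyHookedTrainer(collector=[[1, 2], [3, 4]], optim_steps_per_batch=2)
trainer.register_hook("batch_process", lambda b: [x * 10 for x in b])
trainer.register_hook("optimizer", lambda b: seen.append(sum(b)))
trainer.train()
print(seen)  # [30, 30, 70, 70]
```

Keeping the loop free of concrete operations means that swapping a preprocessing step or an update rule only changes which functions are registered, not the trainer class.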
To build a Trainer, one needs an iterable data source (a collector), a loss module, and an optimizer.
Parameters:
collector (Sequence[TensorDictBase]) – An iterable returning batches of data as TensorDicts of shape [batch x time steps].
total_frames (int) – Total number of frames to be collected during training.
loss_module (LossModule) – A module that reads TensorDict batches (possibly sampled from a replay buffer) and returns a loss TensorDict in which every key points to a different loss component.
optimizer (optim.Optimizer) – An optimizer that trains the parameters of the model.
logger (Logger, optional) – A Logger that handles logging.
optim_steps_per_batch (int) – Number of optimization steps per collection of data. A trainer works as follows: a main loop (the epoch loop) collects batches of data, and a sub-loop (the training loop) performs model updates between two collections of data.
clip_grad_norm (bool, optional) – If True, the gradients will be clipped based on their total norm across all model parameters. If False, every partial derivative will be clamped to [-clip_norm, clip_norm]. Default is True.
clip_norm (Number, optional) – Value used for clipping gradients. Default is None (no clipping).
progress_bar (bool, optional) – If True, a progress bar will be displayed using tqdm. If tqdm is not installed, this option won’t have any effect. Default is True.
seed (int, optional) – Seed to be used for the collector, PyTorch and NumPy. Default is None.
save_trainer_interval (int, optional) – How often the trainer should be saved to disk, in frame count. Default is 10000.
log_interval (int, optional) – How often the values should be logged, in frame count. Default is 10000.
save_trainer_file (path, optional) – Path where the trainer is saved. Default is None (no saving).
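The interplay of collector, total_frames, optim_steps_per_batch and the loss components can be sketched as a plain-Python loop. This is an illustrative assumption about the control flow, not TorchRL's code: batches are lists of numbers whose length stands in for the number of frames, loss_fn returns a dict of named values standing in for the TensorDict produced by loss_module, and only entries whose keys start with "loss" are assumed to be loss components that get summed before the update. The function run_training and all of its arguments are hypothetical.

```python
from typing import Callable, Dict, Iterable, List


def run_training(
    collector: Iterable[List[float]],
    total_frames: int,
    optim_steps_per_batch: int,
    loss_fn: Callable[[List[float]], Dict[str, float]],
    optimizer_step: Callable[[float], None],
) -> int:
    """Hypothetical epoch loop / training loop; returns the number of frames consumed."""
    collected_frames = 0
    for batch in collector:  # epoch loop: one iteration per collected batch
        collected_frames += len(batch)  # in this sketch, one element == one frame
        for _ in range(optim_steps_per_batch):  # training loop between two collections
            losses = loss_fn(batch)
            # Assumption: only keys starting with "loss" are loss components.
            total_loss = sum(v for k, v in losses.items() if k.startswith("loss"))
            optimizer_step(total_loss)
        if collected_frames >= total_frames:
            break
    return collected_frames


steps = []
frames = run_training(
    collector=([1.0] * 4 for _ in range(10)),  # generator of up to ten 4-frame batches
    total_frames=10,
    optim_steps_per_batch=3,
    loss_fn=lambda b: {"loss_actor": sum(b), "loss_value": 0.5, "entropy": 9.0},
    optimizer_step=steps.append,
)
print(frames, len(steps), steps[0])  # 12 9 4.5
```

Three 4-frame batches are needed to reach 10 frames, and each triggers three optimization steps; the "entropy" entry is reported by loss_fn but excluded from the summed loss.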
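The two clipping modes selected by clip_grad_norm can be compared on a flat list of gradient values. This plain-Python sketch roughly mirrors what torch.nn.utils.clip_grad_norm_ (rescale all gradients so their joint L2 norm is at most the limit) and torch.nn.utils.clip_grad_value_ (clamp each partial derivative independently) do; the helper clip_gradients is hypothetical and operates on Python floats rather than tensors.

```python
import math
from typing import List


def clip_gradients(grads: List[float], clip_norm: float, clip_grad_norm: bool = True) -> List[float]:
    """Hypothetical helper illustrating norm-based vs value-based gradient clipping."""
    if clip_grad_norm:
        # Norm clipping: rescale every entry by the same factor so the
        # global L2 norm does not exceed clip_norm (direction is preserved).
        total_norm = math.sqrt(sum(g * g for g in grads))
        if total_norm > clip_norm:
            scale = clip_norm / total_norm
            return [g * scale for g in grads]
        return list(grads)
    # Value clipping: clamp each partial derivative to [-clip_norm, clip_norm].
    return [max(-clip_norm, min(clip_norm, g)) for g in grads]


grads = [3.0, 4.0]  # L2 norm 5.0
print(clip_gradients(grads, clip_norm=1.0, clip_grad_norm=True))   # approximately [0.6, 0.8]
print(clip_gradients(grads, clip_norm=1.0, clip_grad_norm=False))  # [1.0, 1.0]
```

Norm clipping keeps the gradient pointing in the same direction and only shortens it, whereas value clipping can change the direction, as the second call shows.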
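Since save_trainer_interval and log_interval are counted in frames rather than iterations, a single batch may carry the running frame count past a multiple of the interval without landing on it exactly. One way to handle this, sketched below under that assumption, is to fire once whenever the count enters a new multiple of the interval; TorchRL's own bookkeeping may differ, and the class FrameIntervalTrigger is hypothetical.

```python
class FrameIntervalTrigger:
    """Hypothetical helper: fires once each time the frame count crosses a multiple of `interval`."""

    def __init__(self, interval: int) -> None:
        if interval <= 0:
            raise ValueError("interval must be positive")
        self.interval = interval
        self._last_bucket = 0  # number of full intervals already handled

    def update(self, collected_frames: int) -> bool:
        bucket = collected_frames // self.interval
        if bucket > self._last_bucket:
            self._last_bucket = bucket
            return True
        return False


log_trigger = FrameIntervalTrigger(interval=10_000)
fired_at = []
frames = 0
for batch_frames in [4_000] * 8:  # eight batches of 4,000 frames each
    frames += batch_frames
    if log_trigger.update(frames):
        fired_at.append(frames)
print(fired_at)  # [12000, 20000, 32000]
```

With 4,000-frame batches the trigger fires at 12,000, 20,000 and 32,000 frames, i.e. on the first batch that reaches each new multiple of 10,000.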