ReplayBufferTrainer¶
- class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: Optional[int] = None, memmap: bool = False, device: Union[device, str, int] = 'cpu', flatten_tensordicts: bool = False, max_dims: Optional[Sequence[int]] = None)[source]¶
Replay buffer hook provider.
- Parameters:
replay_buffer (TensorDictReplayBuffer) – replay buffer to be used.
batch_size (int, optional) – batch size when sampling data from the latest collection or from the replay buffer. If none is provided, the replay buffer batch-size will be used (preferred option for unchanged batch-sizes).
memmap (bool, optional) – if True, a memmap tensordict is created. Default is False.
device (device, optional) – device where the samples must be placed. Default is cpu.
flatten_tensordicts (bool, optional) – if True, the tensordicts will be flattened (or equivalently masked with the valid mask obtained from the collector) before being passed to the replay buffer. Otherwise, no transform other than padding will be applied (see the max_dims arg below). Defaults to False.
max_dims (sequence of int, optional) – if flatten_tensordicts is set to False, this must be a list with the same length as the batch_size of the provided tensordicts, representing the maximum size of each dimension. If provided, this list of sizes is used to pad the tensordicts and make their shapes match before they are passed to the replay buffer. If a dimension has no maximum value, a -1 value should be provided (see the sketch below).
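As a minimal sketch of the padding path, assuming the collector yields tensordicts with batch size [B, T] where only the time dimension varies (the replay_buffer and N names and the 200-step cap are illustrative):
>>> # Hypothetical values: leave the first (batch) dimension untouched (-1)
>>> # and pad the variable time dimension up to 200 steps.
>>> rb_trainer = ReplayBufferTrainer(
...     replay_buffer=replay_buffer,
...     batch_size=N,
...     flatten_tensordicts=False,
...     max_dims=[-1, 200],
... )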
Examples
>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)
>>> trainer.register_op("batch_process", rb_trainer.extend)
>>> trainer.register_op("process_optim_batch", rb_trainer.sample)
>>> trainer.register_op("post_loss", rb_trainer.update_priority)
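The snippet above assumes that replay_buffer, N and trainer already exist. A minimal setup sketch for the first two names, assuming a TensorDictReplayBuffer backed by a LazyTensorStorage (the capacity and batch size below are illustrative; trainer is a configured torchrl.trainers.Trainer whose construction is omitted):
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
>>> from torchrl.trainers import ReplayBufferTrainer
>>> # Illustrative storage capacity and sampling batch size.
>>> replay_buffer = TensorDictReplayBuffer(storage=LazyTensorStorage(10_000))
>>> N = 256
>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)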
- register(trainer: Trainer, name: str = 'replay_buffer')[source]¶
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
Note
To register the hook at a location other than the default, use register_op().
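As a sketch, the default registration can replace the manual register_op() calls shown in the example above (rb_trainer and trainer are the objects from that example):
>>> # Attaches extend, sample and update_priority at their default trainer hooks.
>>> rb_trainer.register(trainer, name="replay_buffer")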