UpdateWeights
- class torchrl.trainers.UpdateWeights(collector: DataCollectorBase, update_weights_interval: int)
A hook that updates the weights of a collector's policy.
This hook must be used whenever the collector's policy weights live on a different device from the policy weights being trained by the Trainer. In that case, the weights must be synced across devices at regular intervals. If the devices match, the hook is a no-op.
- Parameters:
collector (DataCollectorBase) – The data collector whose policy weights must be synced.
update_weights_interval (int) – Interval, in number of collected batches, at which the sync takes place.
Examples
>>> update_weights = UpdateWeights(trainer.collector, T)  # sync every T collected batches
>>> trainer.register_op("post_steps", update_weights)
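The interval behaviour described above can be illustrated with a minimal, framework-free sketch. It assumes that the hook is called once per collected batch and that it triggers a weight sync on the collector every `update_weights_interval` calls. `IntervalWeightSync` and `FakeCollector` are hypothetical names used only for illustration, not torchrl code; the method name `update_policy_weights_` mirrors the one exposed by torchrl collectors, but treat the wiring here as an assumption.

```python
from typing import Protocol


class SupportsWeightSync(Protocol):
    """Anything exposing a weight-sync method (hypothetical protocol for this sketch)."""

    def update_policy_weights_(self) -> None: ...


class IntervalWeightSync:
    """Hypothetical re-implementation of the interval logic, not torchrl's code."""

    def __init__(self, collector: SupportsWeightSync, update_weights_interval: int) -> None:
        if update_weights_interval < 1:
            raise ValueError("update_weights_interval must be >= 1")
        self.collector = collector
        self.update_weights_interval = update_weights_interval
        self.counter = 0  # number of collected batches seen so far

    def __call__(self) -> None:
        # Assumed to be called once per collected batch (e.g. at "post_steps").
        self.counter += 1
        if self.counter % self.update_weights_interval == 0:
            self.collector.update_policy_weights_()


class FakeCollector:
    """Stand-in collector that only records how often it was synced."""

    def __init__(self) -> None:
        self.syncs = 0

    def update_policy_weights_(self) -> None:
        self.syncs += 1


collector = FakeCollector()
hook = IntervalWeightSync(collector, update_weights_interval=3)
for _ in range(10):  # simulate 10 collected batches
    hook()
print(collector.syncs)  # 3 (after batches 3, 6 and 9)
```

With an interval of 3, ten collected batches lead to three syncs; on a real collector whose policy shares the trainer's device, each sync would simply have nothing to copy.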
- register(trainer: Trainer, name: str = 'update_weights')
Registers the hook in the trainer at a default location.
- Parameters:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name under which the hook is registered. Defaults to 'update_weights'.
Note
To register the hook at a location other than the default, use register_op().
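The difference between `register()` and calling `register_op()` directly can be sketched with a toy trainer. This is a minimal sketch under stated assumptions: `MiniTrainer` and `SyncHook` are hypothetical stand-ins, and it assumes that `register()` records the hook under `name` and attaches it to the `"post_steps"` location used in the class example. It does not claim to reproduce torchrl's `Trainer` internals.

```python
from collections import defaultdict
from typing import Callable, Dict, List


class MiniTrainer:
    """Hypothetical trainer exposing only the two registration calls used here."""

    def __init__(self) -> None:
        self.ops: Dict[str, List[Callable[[], None]]] = defaultdict(list)
        self.modules: Dict[str, object] = {}

    def register_op(self, dest: str, op: Callable[[], None]) -> None:
        # Attach a callable to a named location in the training loop.
        self.ops[dest].append(op)

    def register_module(self, name: str, module: object) -> None:
        # Keep a named reference to the hook (e.g. for checkpointing).
        if name in self.modules:
            raise RuntimeError(f"{name} is already registered")
        self.modules[name] = module


class SyncHook:
    """Hypothetical hook; the sync logic itself is omitted."""

    def __call__(self) -> None:
        pass

    def register(self, trainer: MiniTrainer, name: str = "update_weights") -> None:
        # Assumed default location: "post_steps", as in the class example.
        trainer.register_module(name, self)
        trainer.register_op("post_steps", self)


trainer = MiniTrainer()
SyncHook().register(trainer)  # default location
custom = SyncHook()
trainer.register_op("post_optim", custom)  # explicit, non-default location
```

Calling `register()` is a shortcut for the default placement, while `register_op()` lets the caller pick the location explicitly.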