
UpdateWeights

class torchrl.trainers.UpdateWeights(collector: DataCollectorBase, update_weights_interval: int)

A collector weights update hook class.

This hook must be used whenever the collector's policy weights live on a different device from the policy weights being trained by the Trainer. In that case, the weights must be synced across devices at regular intervals. If the devices match, the sync is a no-op.

Parameters:
  • collector (DataCollectorBase) – The data collector whose policy weights must be synced.

  • update_weights_interval (int) – Interval, measured in number of collected batches, at which the sync takes place.
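To illustrate what the interval means, here is a minimal plain-Python sketch, not TorchRL's implementation. It assumes the hook is called once per collected batch and counts its calls, triggering a sync on every N-th call. `StubCollector` and `IntervalWeightSync` are hypothetical names; the stub's `update_policy_weights_()` mirrors the method name used by TorchRL collectors but only counts syncs instead of copying weights.

```python
class StubCollector:
    """Hypothetical stand-in for a DataCollectorBase; it only counts weight syncs."""

    def __init__(self) -> None:
        self.sync_count = 0

    def update_policy_weights_(self) -> None:
        # A real collector would copy the trained weights into its own policy here.
        self.sync_count += 1


class IntervalWeightSync:
    """Hypothetical hook that syncs every `update_weights_interval` calls."""

    def __init__(self, collector: StubCollector, update_weights_interval: int) -> None:
        if update_weights_interval < 1:
            raise ValueError("update_weights_interval must be >= 1")
        self.collector = collector
        self.update_weights_interval = update_weights_interval
        self.counter = 0

    def __call__(self) -> None:
        # Assumed to be called once per collected batch; sync only on every N-th call.
        self.counter += 1
        if self.counter % self.update_weights_interval == 0:
            self.collector.update_policy_weights_()


collector = StubCollector()
hook = IntervalWeightSync(collector, update_weights_interval=3)
for _ in range(10):  # simulate 10 collected batches
    hook()
print(collector.sync_count)  # 3 (syncs after batches 3, 6 and 9)
```

With an interval of 1 the weights are synced after every batch; larger intervals trade weight freshness in the collector for fewer cross-device copies.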

Examples

>>> # T: number of collected batches between two weight syncs
>>> update_weights = UpdateWeights(trainer.collector, T)
>>> trainer.register_op("post_steps", update_weights)

register(trainer: Trainer, name: str = 'update_weights')

Registers the hook in the trainer at its default location, the "post_steps" stage (the same stage used in the example above).

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Note

To register the hook at a location other than the default, use register_op().
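The difference between register() and register_op() can be sketched with a toy trainer. This is a hypothetical plain-Python stand-in, not TorchRL's Trainer: `MiniTrainer`, `CountingHook`, `run_stage`, and the "post_optim" stage name are illustrative assumptions. The sketch assumes that register() stores the hook under `name` and attaches it to "post_steps", while register_op() lets the caller pick any stage explicitly.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, Dict, List


class MiniTrainer:
    """Hypothetical trainer stub that stores hooks per named stage."""

    def __init__(self) -> None:
        self.ops: DefaultDict[str, List[Callable[[], None]]] = defaultdict(list)
        self.modules: Dict[str, object] = {}

    def register_op(self, dest: str, op: Callable[[], None]) -> None:
        # Attach a callable to an explicitly named stage.
        self.ops[dest].append(op)

    def register_module(self, name: str, module: object) -> None:
        # Keep a named reference to the hook (assumed use of `name`).
        self.modules[name] = module

    def run_stage(self, dest: str) -> None:
        # Call every hook attached to the given stage, in registration order.
        for op in self.ops[dest]:
            op()


class CountingHook:
    """Hypothetical hook whose register() uses a default stage."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self) -> None:
        self.calls += 1

    def register(self, trainer: MiniTrainer, name: str = "update_weights") -> None:
        # Default location assumed to be "post_steps", as in the example above.
        trainer.register_module(name, self)
        trainer.register_op("post_steps", self)


trainer = MiniTrainer()
default_hook = CountingHook()
default_hook.register(trainer)  # default location

custom_hook = CountingHook()
trainer.register_op("post_optim", custom_hook)  # explicit, non-default location

trainer.run_stage("post_steps")
print(default_hook.calls, custom_hook.calls)  # 1 0
```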
