
LocalWeightUpdaterBase

class torchrl.collectors.LocalWeightUpdaterBase

A base class for updating local policy weights from a server.

This class defines the interface a local inference worker uses to download policy weights from a server and apply them to its policy. The update process is decentralized: each inference worker is responsible for fetching the weights from the server itself, rather than having them pushed to it.

To extend this class, implement the following abstract methods:

  • _get_server_weights: Define how to retrieve the weights from the server.

  • _get_local_weights: Define how to access the current local weights.

  • _maybe_map_weights: Optionally transform the server weights to match the local weights.
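The three hooks can be illustrated with a minimal, self-contained sketch. It does not import TorchRL: `PullWeightUpdaterSketch` and `RenamingUpdater` are hypothetical names, plain dictionaries of floats stand in for tensors, and the order in which `update_weights` chains the hooks (fetch from the server, read the local weights, map, copy in place) is an assumption about how such an updater is typically wired, not the library's implementation.

```python
import abc
from typing import Dict

Weights = Dict[str, float]  # stand-in for a real parameter container


class PullWeightUpdaterSketch(abc.ABC):
    """Hypothetical stand-in for LocalWeightUpdaterBase (not the TorchRL class)."""

    @abc.abstractmethod
    def _get_server_weights(self) -> Weights:
        """Fetch the latest weights from the server."""

    @abc.abstractmethod
    def _get_local_weights(self) -> Weights:
        """Return the weights currently used by the local policy."""

    @abc.abstractmethod
    def _maybe_map_weights(self, server_weights: Weights, local_weights: Weights) -> Weights:
        """Transform server weights so that they match the local layout."""

    def update_weights(self) -> None:
        # Assumed order: pull from the server, read local, map, then copy in place.
        server_weights = self._get_server_weights()
        local_weights = self._get_local_weights()
        mapped = self._maybe_map_weights(server_weights, local_weights)
        local_weights.update(mapped)  # in place, so the policy sees the new values


class RenamingUpdater(PullWeightUpdaterSketch):
    """Hypothetical case: the server prefixes keys with 'module.', the policy does not."""

    def __init__(self, server: Weights, local: Weights) -> None:
        self.server = server
        self.local = local

    def _get_server_weights(self) -> Weights:
        return dict(self.server)  # a copy, as if downloaded

    def _get_local_weights(self) -> Weights:
        return self.local

    def _maybe_map_weights(self, server_weights: Weights, local_weights: Weights) -> Weights:
        prefix = "module."
        renamed = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in server_weights.items()
        }
        # Keep only the keys the local policy actually has (a design choice of this sketch).
        return {k: v for k, v in renamed.items() if k in local_weights}


server = {"module.weight": 0.5, "module.bias": -1.0}
local = {"weight": 0.0, "bias": 0.0}
RenamingUpdater(server, local).update_weights()
print(local)  # {'weight': 0.5, 'bias': -1.0}
```

Because the base class is abstract, a subclass that forgets one of the hooks fails at instantiation time with a `TypeError` rather than at the first update.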

Variables:
  • policy (Policy, optional) – The policy whose weights are to be updated.

  • get_weights_from_policy (Callable, optional) – A function to extract weights from the policy.

  • get_weights_from_server (Callable, optional) – A function to fetch weights from the server.

  • weight_map_fn (Callable, optional) – A function to map server weights to local weights.

  • cache_policy_weights (bool) – Whether to cache the policy weights locally.
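The attributes listed under Variables suggest that an updater can also be configured with callables instead of by overriding hooks. The sketch below is hypothetical: `CallableWeightUpdater`, `TinyPolicy` and the `policy_reads` counter are illustrative names, not TorchRL APIs, and the way each attribute is used (in particular, reading `cache_policy_weights` as "extract the local weights once and reuse that reference") is an assumption inferred from the attribute names.

```python
from typing import Callable, Dict, Optional

Weights = Dict[str, float]  # stand-in for a real parameter container


class CallableWeightUpdater:
    """Hypothetical updater driven by callables rather than subclass hooks."""

    def __init__(
        self,
        policy: object,
        get_weights_from_policy: Callable[[object], Weights],
        get_weights_from_server: Callable[[], Weights],
        weight_map_fn: Optional[Callable[[Weights], Weights]] = None,
        cache_policy_weights: bool = True,
    ) -> None:
        self.policy = policy
        self.get_weights_from_policy = get_weights_from_policy
        self.get_weights_from_server = get_weights_from_server
        self.weight_map_fn = weight_map_fn
        self.cache_policy_weights = cache_policy_weights
        self._cached: Optional[Weights] = None
        self.policy_reads = 0  # counts get_weights_from_policy calls, for illustration

    def _local_weights(self) -> Weights:
        # With caching on, extract the local weights only once and reuse the reference.
        if self.cache_policy_weights and self._cached is not None:
            return self._cached
        self.policy_reads += 1
        weights = self.get_weights_from_policy(self.policy)
        if self.cache_policy_weights:
            self._cached = weights
        return weights

    def update_weights(self) -> None:
        server_weights = self.get_weights_from_server()
        if self.weight_map_fn is not None:
            server_weights = self.weight_map_fn(server_weights)
        self._local_weights().update(server_weights)  # in-place write


class TinyPolicy:
    def __init__(self) -> None:
        self.params = {"w": 0.0}


policy = TinyPolicy()
versions = iter([{"W": 1.0}, {"W": 2.0}])  # two successive server snapshots
updater = CallableWeightUpdater(
    policy,
    get_weights_from_policy=lambda p: p.params,  # returns a live reference
    get_weights_from_server=lambda: next(versions),
    weight_map_fn=lambda ws: {k.lower(): v for k, v in ws.items()},
)
updater.update_weights()
updater.update_weights()
print(policy.params, updater.policy_reads)  # {'w': 2.0} 1
```

Caching only helps when `get_weights_from_policy` returns a live view of the parameters: if it returned a copy, the cached copy would be updated while the policy kept its old values.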

update_weights()

Updates the local weights with the server weights.

See also

RemoteWeightsUpdaterBase and update_policy_weights_().

register_collector(collector: DataCollectorBase)

Registers a collector with the updater.

Once registered, the updater will not accept another collector.

Parameters:

collector (DataCollectorBase) – The collector to register.
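The "only one collector" rule can be sketched as follows. This is a hypothetical illustration, not the TorchRL code: `RegistersOneCollector` and `DummyCollector` are made-up names, and both the choice to raise `RuntimeError` on a second registration and the use of a weak reference to the collector are assumptions.

```python
import weakref
from typing import Optional


class RegistersOneCollector:
    """Hypothetical sketch of the 'one collector per updater' rule."""

    def __init__(self) -> None:
        self._collector_ref: Optional[weakref.ref] = None

    @property
    def collector(self):
        # Dereference the weak reference; None if never set or already garbage-collected.
        return self._collector_ref() if self._collector_ref is not None else None

    def register_collector(self, collector) -> None:
        if self.collector is not None:
            raise RuntimeError("A collector is already registered with this updater.")
        # Assumption: a weak reference avoids a collector <-> updater reference cycle.
        self._collector_ref = weakref.ref(collector)


class DummyCollector:
    pass


updater = RegistersOneCollector()
first = DummyCollector()
updater.register_collector(first)
try:
    updater.register_collector(DummyCollector())
except RuntimeError as err:
    print("rejected:", err)  # rejected: A collector is already registered with this updater.
```

One consequence of the weak-reference choice: once the first collector is garbage-collected, the slot frees up and a new collector can be registered. Holding a strong reference instead would make the rule permanent for the updater's lifetime.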
