LocalWeightUpdaterBase
- class torchrl.collectors.LocalWeightUpdaterBase
A base class for updating local policy weights from a server.
This class provides an interface for downloading and updating the weights of a policy on a local inference worker. The update process is decentralized, meaning the inference worker is responsible for fetching the weights from the server.
To extend this class, implement the following abstract methods:
_get_server_weights: Define how to retrieve the weights from the server.
_get_local_weights: Define how to access the current local weights.
_maybe_map_weights: Optionally transform the server weights to match the local weights.
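The sketch below shows how the three hooks can fit together. It is self-contained and does not import torchrl. `MiniLocalUpdater` is a hypothetical stand-in for the base class, and weights are plain dicts of float lists rather than tensors. The order in which `update_weights` calls the hooks (fetch server weights, read local weights, map, copy in place) is an assumption made for illustration, not the library's documented behavior.

```python
from abc import ABC, abstractmethod
from typing import Dict, List

# Hypothetical weight container: parameter name -> flat list of values.
Weights = Dict[str, List[float]]


class MiniLocalUpdater(ABC):
    """Hypothetical stand-in for LocalWeightUpdaterBase (not the torchrl class)."""

    @abstractmethod
    def _get_server_weights(self) -> Weights: ...

    @abstractmethod
    def _get_local_weights(self) -> Weights: ...

    @abstractmethod
    def _maybe_map_weights(self, server_weights: Weights, local_weights: Weights) -> Weights: ...

    def update_weights(self) -> None:
        # Pull model: the local worker fetches the weights itself.
        server = self._get_server_weights()
        local = self._get_local_weights()
        mapped = self._maybe_map_weights(server, local)
        # Copy values into the existing local buffers instead of rebinding them.
        for name, values in mapped.items():
            local[name][:] = values


class DictServerUpdater(MiniLocalUpdater):
    def __init__(self, server: Weights, local: Weights) -> None:
        self._server = server  # hypothetical stand-in for a parameter server
        self._local = local

    def _get_server_weights(self) -> Weights:
        # Return copies so local buffers never alias the server's storage.
        return {k: list(v) for k, v in self._server.items()}

    def _get_local_weights(self) -> Weights:
        return self._local

    def _maybe_map_weights(self, server_weights: Weights, local_weights: Weights) -> Weights:
        # Identity mapping: server and local keys already match.
        return server_weights


local = {"w": [0.0, 0.0]}
updater = DictServerUpdater(server={"w": [1.0, 2.0]}, local=local)
updater.update_weights()
print(local)  # {'w': [1.0, 2.0]}
```

Copying into the existing buffers, rather than replacing them, keeps any other holder of those buffers (such as the policy itself) in sync after an update.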
- Variables:
policy (Policy, optional) – The policy whose weights are to be updated.
get_weights_from_policy (Callable, optional) – A function to extract weights from the policy.
get_weights_from_server (Callable, optional) – A function to fetch weights from the server.
weight_map_fn (Callable, optional) – A function to map server weights to local weights.
cache_policy_weights (bool) – Whether to cache the policy weights locally.
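A `weight_map_fn` is useful when server and local parameter names differ. For example, a state dict saved from a model wrapped in PyTorch's `DistributedDataParallel` prefixes every key with `module.`. The function below is a hypothetical mapping that strips such a prefix; its `(server_weights, local_weights)` signature and the plain-dict weight format are assumptions for illustration, not the signature torchrl expects.

```python
from typing import Dict, List

# Hypothetical weight container: parameter name -> flat list of values.
Weights = Dict[str, List[float]]


def strip_prefix_map(server_weights: Weights, local_weights: Weights, prefix: str = "module.") -> Weights:
    """Hypothetical weight_map_fn: rename server keys so they match local keys."""
    mapped: Weights = {}
    for name, values in server_weights.items():
        # Drop the prefix when present; otherwise keep the name unchanged.
        local_name = name[len(prefix):] if name.startswith(prefix) else name
        if local_name not in local_weights:
            raise KeyError(f"server weight {name!r} has no local counterpart {local_name!r}")
        mapped[local_name] = values
    return mapped


server = {"module.layer.weight": [0.5, -0.5], "module.layer.bias": [0.1]}
local = {"layer.weight": [0.0, 0.0], "layer.bias": [0.0]}
print(strip_prefix_map(server, local))
# {'layer.weight': [0.5, -0.5], 'layer.bias': [0.1]}
```

Raising on an unmatched key, rather than silently skipping it, surfaces naming mismatches between server and worker early.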
- update_weights()
Updates the local weights with the server weights.
See also: RemoteWeightsUpdaterBase and update_policy_weights_().
- register_collector(collector: DataCollectorBase)
Registers a collector with the updater.
Once registered, the updater will not accept another collector.
- Parameters:
collector (DataCollectorBase) – The collector to register.
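The "register once" rule can be sketched as follows. `Collector` and `OneShotRegistry` are hypothetical placeholders, not torchrl classes. Whether torchrl raises `RuntimeError` on a second registration, and whether it stores the collector through a weak reference, are assumptions made for this illustration.

```python
import weakref
from typing import Optional


class Collector:
    """Hypothetical placeholder for a DataCollectorBase instance."""


class OneShotRegistry:
    """Hypothetical sketch of the 'accept only one collector' rule (not torchrl's code)."""

    def __init__(self) -> None:
        self._collector_ref: Optional[weakref.ReferenceType] = None

    def register_collector(self, collector: Collector) -> None:
        # Reject any registration after the first one.
        if self._collector_ref is not None:
            raise RuntimeError("A collector is already registered.")
        # A weak reference avoids a collector <-> updater reference cycle.
        self._collector_ref = weakref.ref(collector)

    @property
    def collector(self) -> Optional[Collector]:
        # Dereference the weak reference; None if never registered or collected.
        return None if self._collector_ref is None else self._collector_ref()


c = Collector()
reg = OneShotRegistry()
reg.register_collector(c)
try:
    reg.register_collector(Collector())
except RuntimeError as err:
    print(err)  # A collector is already registered.
```
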