MultiProcessedRemoteWeightUpdate
- class torchrl.collectors.MultiProcessedRemoteWeightUpdate(get_server_weights: Callable[[], TensorDictBase] | None, policy_weights: dict[torch.device, TensorDictBase])
A remote weight updater for synchronizing policy weights across multiple processes or devices.
The MultiProcessedRemoteWeightUpdate class updates the policy weights held by multiple inference workers in a multiprocessed environment, distributing them from a central server to every device or process that runs the policy. It is typically used in multiprocessed data collectors, where each process or device needs an up-to-date copy of the policy weights.
- Parameters:
get_server_weights (Callable[[], TensorDictBase] | None) – A callable that retrieves the latest policy weights from the server or another centralized source.
policy_weights (dict[torch.device, TensorDictBase]) – A dictionary mapping each device to the policy weights currently used on that device. These entries are updated with the server weights.
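The two constructor arguments can be pictured with plain Python stand-ins. In this sketch, ordinary dicts play the role of TensorDictBase and strings such as "cuda:0" play the role of torch.device keys; these substitutions and the helper copy_into are illustrative assumptions, not torchrl API. It also shows why it matters that the stored entries are updated rather than replaced, assuming (as "which will be updated" suggests) that workers keep references to their entries:

```python
# Plain-Python stand-ins: dicts of floats replace TensorDictBase,
# and strings such as "cuda:0" replace torch.device keys.
server_state = {"linear.weight": 0.0, "linear.bias": 0.0}


def get_server_weights():
    """Return a snapshot of the latest weights from the central source."""
    return dict(server_state)


# One container per device, mirroring the shape of policy_weights.
policy_weights = {
    "cpu": {"linear.weight": 0.0, "linear.bias": 0.0},
    "cuda:0": {"linear.weight": 0.0, "linear.bias": 0.0},
}

# Workers hold references to their containers, not copies.
worker_view = policy_weights["cuda:0"]
stale_view = policy_weights["cpu"]


def copy_into(target, source):
    """Hypothetical helper: overwrite target's values in place."""
    for key, value in source.items():
        target[key] = value


# The trainer produces new weights on the server side.
server_state.update({"linear.weight": 1.5, "linear.bias": -0.5})

# In-place copy: the reference held by the "cuda:0" worker sees the new values.
copy_into(policy_weights["cuda:0"], get_server_weights())
print(worker_view["linear.weight"])  # 1.5

# Rebinding the "cpu" entry instead leaves that worker's reference stale.
policy_weights["cpu"] = get_server_weights()
print(stale_view["linear.weight"])  # 0.0
```

The real class operates on TensorDicts and may share them across processes, so its internals can differ from this toy; the sketch only illustrates the aliasing concern.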
- all_worker_ids()
Returns a list of all worker identifiers (devices or processes).
- _sync_weights_with_worker()
Updates the weights of a specific worker with the server weights.
- _get_server_weights()
Retrieves the latest weights from the server.
- _maybe_map_weights()
Optionally transforms the server weights before they are distributed. In this class it is a no-op and returns the weights unchanged.
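One plausible way the four hooks above compose into a single update pass is sketched below as a pure-Python mimic. The class name ToyMultiProcessedUpdater, the driver method update_all, the order in which hooks are called, and the behaviour when get_server_weights is None are all assumptions for illustration; dicts stand in for TensorDictBase and strings for torch.device.

```python
from typing import Callable, Optional

Weights = dict  # stand-in for TensorDictBase


class ToyMultiProcessedUpdater:
    """Hypothetical pure-Python mimic of the documented hooks."""

    def __init__(
        self,
        get_server_weights: Optional[Callable[[], Weights]],
        policy_weights: dict,
    ):
        self._getter = get_server_weights
        self._policy_weights = policy_weights

    def all_worker_ids(self) -> list:
        # Worker ids are the keys of policy_weights (devices).
        return list(self._policy_weights)

    def _get_server_weights(self) -> Optional[Weights]:
        # Assumption: without a getter there is nothing to fetch.
        if self._getter is None:
            return None
        return self._getter()

    def _maybe_map_weights(self, server_weights: Weights) -> Weights:
        # No-op, matching the documented behaviour.
        return server_weights

    def _sync_weights_with_worker(self, worker_id, server_weights: Weights) -> None:
        # Overwrite the worker's weights in place.
        self._policy_weights[worker_id].update(server_weights)

    def update_all(self) -> None:
        """Hypothetical driver: fetch once, map once, push to every worker."""
        weights = self._get_server_weights()
        if weights is None:
            return
        weights = self._maybe_map_weights(weights)
        for worker_id in self.all_worker_ids():
            self._sync_weights_with_worker(worker_id, weights)


server = {"w": 3.0}
workers = {"cpu": {"w": 0.0}, "cuda:0": {"w": 0.0}}
updater = ToyMultiProcessedUpdater(lambda: dict(server), workers)
updater.update_all()
print(workers)  # {'cpu': {'w': 3.0}, 'cuda:0': {'w': 3.0}}
```

Fetching and mapping once, then pushing the same result to every worker, keeps the server-side cost independent of the number of workers; only the final copy step runs per worker.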
Note
This class assumes that the server weights can be directly applied to the workers without any additional processing. If your use case requires more complex weight mapping or synchronization logic, consider extending RemoteWeightUpdaterBase with a custom implementation.
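As an example of the "more complex weight mapping" the note mentions, the sketch below overrides a mapping hook to strip a key prefix, such as the "module." prefix that PyTorch's DistributedDataParallel wrapper adds to parameter names. It does not subclass the real RemoteWeightUpdaterBase: ToyUpdater and PrefixStrippingUpdater are hypothetical pure-Python stand-ins, and dicts replace TensorDicts.

```python
class ToyUpdater:
    """Hypothetical minimal base: pushes server weights to each worker dict."""

    def __init__(self, get_server_weights, policy_weights):
        self._getter = get_server_weights
        self._policy_weights = policy_weights

    def _maybe_map_weights(self, server_weights):
        # Default: no mapping, like MultiProcessedRemoteWeightUpdate.
        return server_weights

    def update_all(self):
        weights = self._maybe_map_weights(self._getter())
        for target in self._policy_weights.values():
            target.update(weights)


class PrefixStrippingUpdater(ToyUpdater):
    """Custom mapping: drop a wrapper prefix from server-side keys."""

    def __init__(self, get_server_weights, policy_weights, prefix="module."):
        super().__init__(get_server_weights, policy_weights)
        self._prefix = prefix

    def _maybe_map_weights(self, server_weights):
        # Keys without the prefix pass through unchanged.
        return {
            (key[len(self._prefix):] if key.startswith(self._prefix) else key): value
            for key, value in server_weights.items()
        }


server = {"module.linear.weight": 2.0}
workers = {"cpu": {"linear.weight": 0.0}}
updater = PrefixStrippingUpdater(lambda: dict(server), workers)
updater.update_all()
print(workers["cpu"])  # {'linear.weight': 2.0}
```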
- register_collector(collector: DataCollectorBase)
Registers a collector with the updater.
Once a collector is registered, the updater does not accept another one.
- Parameters:
collector (DataCollectorBase) – The collector to register.
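The register-once rule can be mimicked in a few lines. This is a hypothetical sketch: the class ToyRegisterOnce and the choice to raise RuntimeError on a second registration are assumptions, since the documentation above only states that another collector is not accepted, not how the refusal is reported.

```python
class ToyRegisterOnce:
    """Hypothetical mimic of the register-once rule for collectors."""

    def __init__(self):
        self._collector = None

    @property
    def collector(self):
        return self._collector

    def register_collector(self, collector):
        # Assumption: a second registration raises RuntimeError;
        # the real method's failure mode may differ.
        if self._collector is not None:
            raise RuntimeError("a collector is already registered")
        self._collector = collector


updater = ToyRegisterOnce()
updater.register_collector("collector-A")

rejected = False
try:
    updater.register_collector("collector-B")
except RuntimeError:
    rejected = True

print(updater.collector, rejected)  # collector-A True
```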