RayRemoteWeightUpdater
- class torchrl.collectors.RayRemoteWeightUpdater(policy_weights: TensorDictBase, remote_collectors: list, max_interval: int = 0)
A remote weight updater for synchronizing policy weights across remote workers using Ray.
The RayRemoteWeightUpdater class updates the weights of a policy across remote inference workers managed by Ray. It stores the latest weights in Ray’s object store and sends them to each remote collector. This class is typically used with distributed data collectors, where each remote worker requires an up-to-date copy of the policy weights.
- Parameters:
policy_weights (TensorDictBase) – The current weights of the policy that need to be distributed to remote workers.
remote_collectors (list) – A list of remote collectors that will receive the updated policy weights.
max_interval (int, optional) – The maximum number of batches between weight updates for each worker. Defaults to 0, meaning weights are updated every batch.
- all_worker_ids()
Returns a list of all worker identifiers (indices of remote collectors).
- _get_server_weights()
Retrieves the latest weights from the server and stores them in Ray’s object store.
- _maybe_map_weights()
Optionally maps server weights before distribution (no-op in this implementation).
- _sync_weights_with_worker()
Synchronizes the server weights with a specific remote worker using Ray.
- _skip_update()
Determines whether to skip the weight update for a specific worker based on the interval.
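How the methods above fit together can be sketched without Ray or TorchRL. The following is a minimal, self-contained illustration and not the library's implementation: `FakeCollector`, `SketchUpdater`, and `push` are hypothetical names, the method arguments are assumed, and the counting rule (a worker is skipped until more than `max_interval` batches have passed since its last update) is an assumption chosen to match the parameter description, where 0 means an update on every batch.

```python
class FakeCollector:
    """Hypothetical stand-in for a remote collector; records the weights it receives."""

    def __init__(self):
        self.received = []

    def update_policy_weights_(self, weights):
        self.received.append(dict(weights))


class SketchUpdater:
    """Illustrative interval-based updater (assumed logic, not TorchRL's code)."""

    def __init__(self, policy_weights, remote_collectors, max_interval=0):
        self.policy_weights = policy_weights
        self.remote_collectors = remote_collectors
        self.max_interval = max(0, max_interval)
        # One counter per worker: batches seen since that worker's last update.
        self._batches_since_update = [0] * len(remote_collectors)

    def all_worker_ids(self):
        return list(range(len(self.remote_collectors)))

    def _get_server_weights(self):
        # The documented class stores the weights in Ray's object store;
        # this sketch simply returns them.
        return self.policy_weights

    def _maybe_map_weights(self, server_weights):
        # No-op mapping, as in the documented class.
        return server_weights

    def _skip_update(self, worker_id):
        # Assumed rule: skip while the counter has not exceeded max_interval.
        self._batches_since_update[worker_id] += 1
        return self._batches_since_update[worker_id] <= self.max_interval

    def _sync_weights_with_worker(self, worker_id, server_weights):
        self.remote_collectors[worker_id].update_policy_weights_(server_weights)
        self._batches_since_update[worker_id] = 0

    def push(self, worker_ids=None):
        # Hypothetical driver that calls the hooks in order for each worker.
        if worker_ids is None:
            worker_ids = self.all_worker_ids()
        weights = self._maybe_map_weights(self._get_server_weights())
        for worker_id in worker_ids:
            if self._skip_update(worker_id):
                continue
            self._sync_weights_with_worker(worker_id, weights)


collectors = [FakeCollector(), FakeCollector()]
updater = SketchUpdater({"w": 1.0}, collectors, max_interval=2)
for _ in range(6):  # six batches
    updater.push()
updates_per_worker = [len(c.received) for c in collectors]
```

Under this assumed rule, `max_interval=2` lets two batches pass between updates, so each worker is synchronized on every third call to `push`.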
Note
This class assumes that the server weights can be directly applied to the remote workers without any additional processing. If your use case requires more complex weight mapping or synchronization logic, consider extending RemoteWeightUpdaterBase with a custom implementation.
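As a rough illustration of what a custom weight mapping might look like, the sketch below overrides a mapping hook in plain Python. It does not import TorchRL: `NoOpMappingUpdater`, `PrefixStrippingUpdater`, and `weights_for_workers` are hypothetical names, the weights are plain dictionaries rather than a TensorDictBase, and the "module." prefix is an invented example of a naming mismatch between server and worker.

```python
class NoOpMappingUpdater:
    """Hypothetical base: weights go to workers unchanged."""

    def __init__(self, policy_weights):
        self.policy_weights = policy_weights

    def _get_server_weights(self):
        return self.policy_weights

    def _maybe_map_weights(self, server_weights):
        # Default behaviour: no mapping.
        return server_weights

    def weights_for_workers(self):
        # Hypothetical helper: fetch, then map, before distribution.
        return self._maybe_map_weights(self._get_server_weights())


class PrefixStrippingUpdater(NoOpMappingUpdater):
    """Hypothetical subclass: renames keys so they match the worker's policy."""

    PREFIX = "module."

    def _maybe_map_weights(self, server_weights):
        mapped = {}
        for key, value in server_weights.items():
            if key.startswith(self.PREFIX):
                key = key[len(self.PREFIX):]
            mapped[key] = value
        return mapped


server = {"module.linear.weight": [0.1, 0.2], "module.linear.bias": [0.0]}
mapped = PrefixStrippingUpdater(server).weights_for_workers()
```

Keeping the mapping in its own hook means the fetch and synchronization steps stay untouched when only the key layout differs.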
- register_collector(collector: DataCollectorBase)
Register a collector in the updater.
Once registered, the updater will not accept another collector.
- Parameters:
collector (DataCollectorBase) – The collector to register.
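The register-once behaviour described above can be pictured with a small stand-alone sketch. `SingleCollectorRegistry` is a hypothetical class, and raising `RuntimeError` on a second registration is an assumption; the documentation only states that another collector is not accepted.

```python
class SingleCollectorRegistry:
    """Hypothetical illustration of a register-once policy."""

    def __init__(self):
        self._collector = None

    def register_collector(self, collector):
        # Assumed behaviour: reject any registration after the first one.
        if self._collector is not None:
            raise RuntimeError("A collector has already been registered.")
        self._collector = collector


registry = SingleCollectorRegistry()
registry.register_collector("collector-A")

try:
    registry.register_collector("collector-B")
    second_accepted = True
except RuntimeError:
    second_accepted = False
```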