VanillaLocalWeightUpdater
- class torchrl.collectors.VanillaLocalWeightUpdater(weight_getter: Callable[[], TensorDictBase], policy_weights: TensorDictBase)
A simple implementation of LocalWeightUpdaterBase for updating local policy weights.
The VanillaLocalWeightUpdater class provides a basic mechanism for updating a local policy's weights: it fetches the latest weights from a specified source and applies them directly. It is intended for cases where the weight update is straightforward and requires no mapping or transformation.
This class is used by default in the SyncDataCollector when no custom local weight updater is provided.
- Parameters:
weight_getter (Callable[[], TensorDictBase]) – A callable that returns the latest policy weights from the server or another source.
policy_weights (TensorDictBase) – The current weights of the local policy that need to be updated.
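The weight_getter parameter only has to be a zero-argument callable that returns the current weights. The sketch below illustrates that contract with plain Python dicts standing in for TensorDictBase. make_weight_getter and server_store are hypothetical names, and returning a copy of the server's weights is a choice made for this illustration, not documented torchrl behavior.

```python
from typing import Callable, Dict

# Plain dicts of floats stand in for TensorDictBase; this is not torchrl code.
Weights = Dict[str, float]


def make_weight_getter(server_store: Weights) -> Callable[[], Weights]:
    """Return a zero-argument callable that snapshots the server's current weights."""

    def weight_getter() -> Weights:
        # Copy so callers never alias the server's live storage.
        return dict(server_store)

    return weight_getter


server_store: Weights = {"linear.weight": 0.5, "linear.bias": 0.0}
weight_getter = make_weight_getter(server_store)

# The trainer updates the server-side weights...
server_store["linear.weight"] = 0.75
# ...and the next call to weight_getter sees the new values.
latest = weight_getter()
print(latest)  # {'linear.weight': 0.75, 'linear.bias': 0.0}
```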
- _get_server_weights()
Retrieves the latest weights from the source by calling weight_getter.
- _get_local_weights()
Returns the current local policy weights.
- _map_weights()
Maps server weights directly to local weights, without transformation.
- _maybe_map_weights()
Optionally maps server weights to local weights (a no-op in this implementation).
- _update_local_weights()
Updates the local policy weights with the mapped weights.
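To make the role of each private method concrete, here is a hypothetical pure-Python mirror of the update flow. ToyVanillaUpdater, its method signatures, the update_weights driver and the order in which it calls the steps are all assumptions for illustration; dicts stand in for TensorDictBase. The in-place update reflects the assumption that the policy keeps a reference to the same weights object.

```python
from typing import Callable, Dict

# Plain dicts stand in for TensorDictBase in this sketch.
Weights = Dict[str, float]


class ToyVanillaUpdater:
    """Hypothetical pure-Python mirror of the vanilla update flow (not torchrl code)."""

    def __init__(self, weight_getter: Callable[[], Weights], policy_weights: Weights) -> None:
        self.weight_getter = weight_getter
        self.policy_weights = policy_weights

    def _get_server_weights(self) -> Weights:
        # Fetch the latest weights from the source.
        return self.weight_getter()

    def _get_local_weights(self) -> Weights:
        # The weights object the local policy reads from.
        return self.policy_weights

    def _maybe_map_weights(self, server_weights: Weights) -> Weights:
        # Optional hook: a no-op here, returning the server weights unchanged.
        return server_weights

    def _map_weights(self, server_weights: Weights, local_weights: Weights) -> Weights:
        # Identity mapping: server keys and values are used as-is.
        return server_weights

    def _update_local_weights(self, local_weights: Weights, mapped_weights: Weights) -> None:
        # Update in place so any holder of local_weights sees the new values.
        local_weights.update(mapped_weights)

    def update_weights(self) -> None:
        # Hypothetical driver; the name and step order are assumptions.
        server_weights = self._maybe_map_weights(self._get_server_weights())
        local_weights = self._get_local_weights()
        mapped = self._map_weights(server_weights, local_weights)
        self._update_local_weights(local_weights, mapped)


server = {"w": 1.0, "b": 0.0}
local = {"w": 0.0, "b": 0.0}
policy_view = local  # e.g. the policy keeps a reference to the same object
updater = ToyVanillaUpdater(lambda: dict(server), local)

server["w"] = 2.0
updater.update_weights()
print(policy_view)  # {'w': 2.0, 'b': 0.0}
```

Updating in place rather than rebinding self.policy_weights is what lets a policy that already holds the weights object pick up the change without being re-wired.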
Note
This class assumes that the server weights can be directly applied to the local policy without any additional processing. If your use case requires more complex weight mapping, consider extending LocalWeightUpdaterBase with a custom implementation.
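As an illustration of the kind of mapping the vanilla updater skips, the sketch below handles one common mismatch: server-side keys carrying a "module." prefix, as in state dicts saved from a model wrapped in torch.nn.parallel.DistributedDataParallel. map_prefixed_weights is a hypothetical helper written over plain dicts. In a subclass of LocalWeightUpdaterBase, logic like this would plausibly live in the mapping step, though the exact override signature is an assumption here.

```python
from typing import Dict

# Plain dicts stand in for TensorDictBase in this sketch.
Weights = Dict[str, float]


def map_prefixed_weights(
    server_weights: Weights, local_weights: Weights, prefix: str = "module."
) -> Weights:
    """Hypothetical mapping: strip a key prefix and keep only keys the local policy owns."""
    mapped: Weights = {}
    for key, value in server_weights.items():
        local_key = key[len(prefix):] if key.startswith(prefix) else key
        if local_key in local_weights:
            mapped[local_key] = value
    # Fail loudly rather than silently leaving some local weights stale.
    missing = sorted(set(local_weights) - set(mapped))
    if missing:
        raise KeyError(f"no server weights for local keys: {missing}")
    return mapped


server = {"module.linear.weight": 0.3, "module.linear.bias": 0.1, "module.critic.weight": 9.0}
local = {"linear.weight": 0.0, "linear.bias": 0.0}
local.update(map_prefixed_weights(server, local))
print(local)  # {'linear.weight': 0.3, 'linear.bias': 0.1}
```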
- register_collector(collector: DataCollectorBase)
Registers a collector with the updater.
Once registered, the updater will not accept another collector.
- Parameters:
collector (DataCollectorBase) – The collector to register.
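The one-collector rule can be pictured as a guard on a single slot. SingleCollectorRegistry is hypothetical, and raising RuntimeError on a second registration is an assumption made for the sketch; the documentation states only that another collector will not be accepted.

```python
from typing import Optional


class SingleCollectorRegistry:
    """Hypothetical sketch of the 'register once' rule (not torchrl code)."""

    def __init__(self) -> None:
        self._collector: Optional[object] = None

    def register_collector(self, collector: object) -> None:
        if self._collector is not None:
            # Assumed behavior: reject any further registration with an error.
            raise RuntimeError("a collector is already registered with this updater")
        self._collector = collector

    @property
    def collector(self) -> Optional[object]:
        return self._collector


registry = SingleCollectorRegistry()
first = object()
registry.register_collector(first)

try:
    registry.register_collector(object())
    rejected = False
except RuntimeError:
    rejected = True

print(rejected)  # True
```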