AdditiveGaussianWrapper¶
- class torchrl.modules.AdditiveGaussianWrapper(*args, **kwargs)[source]¶
Additive Gaussian exploration wrapper.
- Parameters:
policy (TensorDictModule) – a policy.
- Keyword Arguments:
sigma_init (scalar, optional) – initial value of the noise standard deviation sigma. default: 1.0
sigma_end (scalar, optional) – final value of the noise standard deviation sigma. default: 0.1
annealing_num_steps (int, optional) – number of steps it will take for sigma to reach the sigma_end value.
mean (float, optional) – mean of each output element’s normal distribution.
std (float, optional) – standard deviation of each output element’s normal distribution.
action_key (NestedKey, optional) – if the policy module has more than one output key, its output spec will be of type Composite. One needs to know where to find the action spec. Default is “action”.
spec (TensorSpec, optional) – if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy.
safe (boolean, optional) – if False, the TensorSpec can be None. If it is set to False but the spec is passed, the projection will still happen. Default is True.
device (torch.device, optional) – the device where the buffers have to be stored.
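The interplay of sigma_init, sigma_end and annealing_num_steps can be pictured as a linear schedule. The following is a minimal sketch of such a schedule, not torchrl's actual implementation; the function name is hypothetical:

```python
def annealed_sigma(step, sigma_init=1.0, sigma_end=0.1, annealing_num_steps=1000):
    """Sketch of a linear annealing schedule: sigma moves from sigma_init
    to sigma_end over annealing_num_steps, then stays at sigma_end."""
    frac = min(step / annealing_num_steps, 1.0)
    return sigma_init + frac * (sigma_end - sigma_init)
```

With the defaults above, sigma starts at 1.0, reaches 0.1 after annealing_num_steps calls, and is clamped at 0.1 thereafter.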
Note
Once a policy has been wrapped in AdditiveGaussianWrapper, it is crucial to incorporate a call to step() in the training loop to update the exploration factor. Since this omission is not easy to detect, no warning or exception will be raised if it is omitted!
- forward(tensordict: TensorDictBase) → TensorDictBase [source]¶
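To illustrate why the step() call matters, here is a self-contained stand-in for the wrapper, written without torchrl; the class, its attributes, and the lambda policy are all hypothetical simplifications of the real API:

```python
import random

class AdditiveGaussianExploration:
    """Toy stand-in for AdditiveGaussianWrapper: adds N(mean, sigma) noise
    to the policy's action and linearly anneals sigma on each step() call.
    This is NOT torchrl's implementation, only an illustration."""

    def __init__(self, policy, sigma_init=1.0, sigma_end=0.1,
                 annealing_num_steps=1000, mean=0.0):
        self.policy = policy
        self.sigma = sigma_init
        self.sigma_end = sigma_end
        self.delta = (sigma_init - sigma_end) / annealing_num_steps
        self.mean = mean

    def __call__(self, obs):
        # Exploration: perturb the deterministic action with Gaussian noise.
        return self.policy(obs) + random.gauss(self.mean, self.sigma)

    def step(self, frames=1):
        # Anneal sigma; if this is never called, the noise level never decays.
        for _ in range(frames):
            self.sigma = max(self.sigma_end, self.sigma - self.delta)

policy = lambda obs: 0.0  # hypothetical deterministic policy
explore = AdditiveGaussianExploration(policy, annealing_num_steps=10)
for _ in range(10):
    _ = explore(0.0)
    explore.step()  # the call the note above says must not be omitted
```

After 10 iterations sigma has decayed from 1.0 to its floor of 0.1; dropping the explore.step() line would leave sigma at 1.0 forever, which is exactly the silent failure the note warns about.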
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.