TanhNormal
- class torchrl.modules.TanhNormal(loc: Tensor, scale: Tensor, upscale: Union[Tensor, Number] = 5.0, min: Union[Tensor, Number] = -1.0, max: Union[Tensor, Number] = 1.0, event_dims: int = 1, tanh_loc: bool = False)
Implements a TanhNormal distribution with location scaling.
Location scaling prevents the location from being “too far” from 0 when a TanhTransform is applied, but ultimately leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). In practice, with location scaling the location is computed according to
\[loc = tanh(loc / upscale) * upscale.\]
- Parameters:
loc (torch.Tensor) – normal distribution location parameter
scale (torch.Tensor) – normal distribution sigma parameter (standard deviation, i.e. the square root of the variance)
upscale (torch.Tensor or number) – scaling factor used in the formula
\[loc = tanh(loc / upscale) * upscale.\]
Default is 5.0;
min (torch.Tensor or number, optional) – minimum value of the distribution. Default is -1.0;
max (torch.Tensor or number, optional) – maximum value of the distribution. Default is 1.0;
event_dims (int, optional) – number of dimensions describing the action. Default is 1;
tanh_loc (bool, optional) – if True, the above formula is used for the location scaling, otherwise the raw value is kept. Default is False;
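As a quick numeric check of the location-scaling formula, the sketch below applies it to a few scalar values with the default upscale of 5.0. The helper name scale_loc is hypothetical and not part of torchrl, and plain Python floats stand in for tensors here.

```python
import math


def scale_loc(loc: float, upscale: float = 5.0) -> float:
    """Hypothetical helper: loc = tanh(loc / upscale) * upscale for a scalar."""
    return math.tanh(loc / upscale) * upscale


# Small locations pass through almost unchanged, while large ones
# are squashed so that their magnitude never exceeds `upscale`.
for raw in (0.1, 1.0, 5.0, 50.0, -50.0):
    print(f"raw loc = {raw:7.2f} -> scaled loc = {scale_loc(raw):.6f}")
```

The scaled location always stays within [-upscale, upscale], which is the bounding behaviour that tanh_loc=True is described as enabling.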
- property mode
Returns the mode of the distribution.
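To show how the parameters above fit together, here is a minimal sketch of a tanh-squashed normal built only from torch.distributions. It is an illustration under stated assumptions, not torchrl's implementation: make_tanh_normal is a hypothetical helper, low and high stand in for min and max, and the last tensor dimension is treated as the single event dimension (event_dims = 1).

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, TanhTransform

torch.manual_seed(0)


def make_tanh_normal(
    loc: torch.Tensor,
    scale: torch.Tensor,
    upscale: float = 5.0,
    low: float = -1.0,   # plays the role of `min`
    high: float = 1.0,   # plays the role of `max`
    tanh_loc: bool = False,
) -> TransformedDistribution:
    """Hypothetical helper: a normal squashed by tanh and rescaled to [low, high]."""
    if tanh_loc:
        # Location scaling from the formula in the class description.
        loc = torch.tanh(loc / upscale) * upscale
    # Treat the last dimension as the event dimension.
    base = Independent(Normal(loc, scale), 1)
    transforms = [
        # cache_size=1 lets log_prob reuse the pre-tanh value of a fresh sample.
        TanhTransform(cache_size=1),
        # Map (-1, 1) onto (low, high).
        AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2, cache_size=1),
    ]
    return TransformedDistribution(base, transforms)


loc = torch.tensor([[0.0, 2.0, -30.0]])
scale = torch.ones_like(loc)
dist = make_tanh_normal(loc, scale, low=-2.0, high=2.0, tanh_loc=True)

sample = dist.rsample()           # reparameterized sample, shape (1, 3)
log_prob = dist.log_prob(sample)  # one value per batch element, shape (1,)
print(sample, log_prob)
```

Every sample lies within [low, high], and the log-probability is summed over the event dimension, so it has one entry per batch element.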