
TanhNormal

class torchrl.modules.TanhNormal(loc: torch.Tensor, scale: torch.Tensor, upscale: Union[torch.Tensor, Number] = 5.0, low: Union[torch.Tensor, Number] = -1.0, high: Union[torch.Tensor, Number] = 1.0, event_dims: int | None = None, tanh_loc: bool = False, **kwargs)

Implements a TanhNormal distribution with location scaling.

Location scaling prevents the location from drifting "too far" from 0 when a TanhTransform is applied. A location far from 0 ultimately leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). In practice, with location scaling, the location is computed according to

\[\mathrm{loc} = \tanh(\mathrm{loc} / \mathrm{upscale}) \cdot \mathrm{upscale}.\]
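The formula can be checked numerically. The following plain-Python sketch (standard library only; the helper name `scale_loc` is hypothetical and not part of torchrl) applies it to a few scalar locations and shows that small values pass through almost unchanged, while large ones are squashed into the open interval (-upscale, upscale).

```python
import math


def scale_loc(loc: float, upscale: float = 5.0) -> float:
    """Hypothetical helper: location scaling loc -> tanh(loc / upscale) * upscale."""
    return math.tanh(loc / upscale) * upscale


for raw in (0.1, 1.0, 5.0, 20.0, -20.0):
    # Small values are nearly unchanged; large ones saturate inside (-upscale, upscale).
    print(f"{raw:6.1f} -> {scale_loc(raw):.4f}")
```

Because tanh(z) is close to z for small z, the transform acts almost like the identity near 0 and only bends locations whose magnitude is comparable to upscale.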
Parameters:
  • loc (torch.Tensor) – normal distribution location parameter

  • scale (torch.Tensor) – normal distribution sigma parameter (square root of the variance)

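  • upscale (torch.Tensor or number, optional) –

    scaling factor in the formula below, applied to loc only when tanh_loc is True. Default is 5.0:

    \[\mathrm{loc} = \tanh(\mathrm{loc} / \mathrm{upscale}) \cdot \mathrm{upscale}.\]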

  • low (torch.Tensor or number, optional) – minimum value of the distribution. Default is -1.0.

  • high (torch.Tensor or number, optional) – maximum value of the distribution. Default is 1.0.

  • event_dims (int, optional) – number of dimensions describing the action. Default is 1 (the signature's None resolves to this default). Setting event_dims to 0 results in a log-probability with the same shape as the input; 1 reduces (sums over) the last dimension, 2 the last two, and so on.

  • tanh_loc (bool, optional) – if True, the location-scaling formula above is applied to loc; otherwise the raw value is kept. Default is False.
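To make the parameters concrete, here is a minimal plain-Python sketch of sampling and log-probability for a diagonal TanhNormal. It assumes, which is not stated above, that the tanh output in (-1, 1) is mapped affinely onto (low, high), and it treats loc as already final (tanh_loc=False). The helpers `tanh_normal_sample` and `tanh_normal_log_prob` are hypothetical. They only handle a flat vector with event_dims of 0 or 1 and are not the torchrl implementation.

```python
import math
import random


def tanh_normal_sample(loc, scale, low=-1.0, high=1.0, rng=random):
    """Hypothetical helper: one draw per coordinate of tanh(N(loc, scale)), mapped to (low, high)."""
    center, half_width = (high + low) / 2, (high - low) / 2
    out = []
    for mu, sigma in zip(loc, scale):
        x = rng.gauss(mu, sigma)  # pre-squash Normal sample
        out.append(center + half_width * math.tanh(x))  # squash to (-1, 1), then map to (low, high)
    return out


def tanh_normal_log_prob(value, loc, scale, low=-1.0, high=1.0, event_dims=1):
    """Hypothetical helper: change-of-variables log-density of a flat vector.

    event_dims=0 returns one log-probability per coordinate; event_dims=1 sums
    them over the single (last) dimension. Other values are not handled here.
    """
    if event_dims not in (0, 1):
        raise ValueError("this 1-D sketch only supports event_dims in (0, 1)")
    center, half_width = (high + low) / 2, (high - low) / 2
    terms = []
    for v, mu, sigma in zip(value, loc, scale):
        y = (v - center) / half_width  # undo the affine map: y in (-1, 1)
        x = math.atanh(y)  # undo tanh
        normal_lp = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
        # log |d value / d x| = log(half_width) + log(1 - tanh(x)^2)
        log_det = math.log(half_width) + math.log1p(-y * y)
        terms.append(normal_lp - log_det)
    return sum(terms) if event_dims == 1 else terms


rng = random.Random(0)
sample = tanh_normal_sample([0.2, -0.4], [0.5, 0.5], low=0.0, high=2.0, rng=rng)
print(sample, tanh_normal_log_prob(sample, [0.2, -0.4], [0.5, 0.5], low=0.0, high=2.0))
```

With event_dims=0 the result keeps one entry per coordinate; with event_dims=1 those entries are summed, mirroring the reduction described for event_dims.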

get_mode()

Computes an estimate of the mode using the Adam optimizer.
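As a rough illustration of mode estimation with Adam, the sketch below maximises the log-density of a one-dimensional TanhNormal on (-1, 1) using a hand-written Adam update. The function name `tanh_normal_mode`, the learning rate, the step count and the starting point are assumptions chosen for illustration; torchrl's get_mode may differ in all of them. It works in the pre-squash variable x (value = tanh(x)); since tanh is monotone, the maximiser x* gives the mode tanh(x*). When scale is large the density can be bimodal, so gradient ascent only finds a local mode.

```python
import math


def tanh_normal_mode(loc, scale, lr=0.01, steps=2000, betas=(0.9, 0.999), eps=1e-8):
    """Hypothetical helper: Adam ascent on log N(x; loc, scale) - log(1 - tanh(x)^2)."""
    b1, b2 = betas
    x, m, v = loc, 0.0, 0.0  # start at the Normal location
    for t in range(1, steps + 1):
        # Gradient of the log-density of value = tanh(x), written as a function of x:
        # d/dx [-(x - loc)^2 / (2 scale^2)] + d/dx [2 log cosh x]
        g = -(x - loc) / scale**2 + 2.0 * math.tanh(x)
        m = b1 * m + (1 - b1) * g  # first-moment estimate
        v = b2 * v + (1 - b2) * g * g  # second-moment estimate
        m_hat = m / (1 - b1**t)  # bias corrections
        v_hat = v / (1 - b2**t)
        x += lr * m_hat / (math.sqrt(v_hat) + eps)  # ascent step (maximising)
    return math.tanh(x)


print(tanh_normal_mode(0.5, 0.3))
```

Note that the estimated mode is not simply tanh(loc): the Jacobian term pushes it outward, towards the nearer bound.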

property mean

Returns the mean of the distribution.
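Since tanh is nonlinear, the mean of the squashed variable generally differs from tanh(loc). A simple way to approximate it is by sampling. The plain-Python sketch below uses a hypothetical helper, `tanh_normal_mean_mc`, in one dimension, with the same assumed affine map onto (low, high), and averages n draws. It is illustrative only and does not describe how the mean property is implemented.

```python
import math
import random


def tanh_normal_mean_mc(loc, scale, low=-1.0, high=1.0, n=20000, seed=0):
    """Hypothetical helper: Monte Carlo estimate of E[value] for a 1-D TanhNormal."""
    rng = random.Random(seed)
    center, half_width = (high + low) / 2, (high - low) / 2
    total = 0.0
    for _ in range(n):
        y = math.tanh(rng.gauss(loc, scale))  # squashed sample in (-1, 1)
        total += center + half_width * y  # mapped onto (low, high)
    return total / n


print(tanh_normal_mean_mc(2.0, 1.0), math.tanh(2.0))
```

The standard error shrinks like 1/sqrt(n), so n trades accuracy for cost.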

property mode

Returns the mode of the distribution.
