TruncatedNormal

class torchrl.modules.TruncatedNormal(loc: Tensor, scale: Tensor, upscale: Union[Tensor, float] = 5.0, low: Union[Tensor, float] = -1.0, high: Union[Tensor, float] = 1.0, tanh_loc: bool = False, **kwargs)

Implements a Truncated Normal distribution with location scaling.

Location scaling prevents the location from drifting “too far” from 0, which would otherwise lead to numerically unstable samples and poor gradient computation (e.g. gradient explosion). In practice, the location is computed according to

\[loc = tanh(loc / upscale) * upscale.\]

This behaviour is controlled by the tanh_loc parameter (see below): the scaling is applied only when tanh_loc is True, and the raw location is used otherwise.
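As a minimal sketch of the formula above, written in plain Python rather than taken from TorchRL (the helper name scale_location is hypothetical), the following illustrates that small locations pass through almost unchanged while large ones saturate at upscale:

```python
import math


def scale_location(loc: float, upscale: float = 5.0) -> float:
    """Squash loc into [-upscale, upscale] via tanh(loc / upscale) * upscale."""
    return math.tanh(loc / upscale) * upscale


# A location close to 0 is nearly unaffected by the scaling.
small = scale_location(0.1)

# A location far from 0 is pulled back towards the bound set by upscale.
large = scale_location(100.0)

print(small, large)
```

Because tanh is close to the identity near 0, the scaling mostly matters for locations whose magnitude is comparable to or larger than upscale.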

Parameters:
  • loc (torch.Tensor) – normal distribution location parameter

  • scale (torch.Tensor) – normal distribution sigma parameter (square root of the variance)

  • upscale (torch.Tensor or number, optional) –

    the scaling factor in the formula:

    \[loc = tanh(loc / upscale) * upscale.\]

    Default is 5.0

  • low (torch.Tensor or number, optional) – minimum value of the distribution. Default is -1.0

  • high (torch.Tensor or number, optional) – maximum value of the distribution. Default is 1.0

  • tanh_loc (bool, optional) – if True, the above formula is used for the location scaling, otherwise the raw value is kept. Default is False

log_prob(value, **kwargs)

Returns the log of the probability density/mass function evaluated at value.

Parameters:
  • value (Tensor) – the value at which the log-density is evaluated
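To make the quantity concrete, here is a scalar sketch of the log-density of a normal distribution truncated to [low, high]. It uses only the standard library, the function names are hypothetical, and it is not TorchRL's implementation, which works on tensors and may differ in details such as batching or reduction over dimensions:

```python
import math


def _std_normal_cdf(z: float) -> float:
    """Cumulative distribution function of the standard normal."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def truncated_normal_log_prob(
    value: float,
    loc: float,
    scale: float,
    low: float = -1.0,
    high: float = 1.0,
) -> float:
    """Log-density of N(loc, scale^2) restricted to [low, high]."""
    # Values outside the support have zero density.
    if not (low <= value <= high):
        return -math.inf
    z = (value - loc) / scale
    # Log-density of the untruncated normal at value.
    log_pdf = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) - math.log(scale)
    # Probability mass the untruncated normal assigns to [low, high].
    mass = _std_normal_cdf((high - loc) / scale) - _std_normal_cdf((low - loc) / scale)
    # Renormalise so the truncated density integrates to 1.
    return log_pdf - math.log(mass)


inside = truncated_normal_log_prob(0.0, loc=0.0, scale=1.0)
outside = truncated_normal_log_prob(2.0, loc=0.0, scale=1.0)
print(inside, outside)
```

Inside the support, the truncated log-density is always higher than the untruncated one, because dividing by a mass below 1 raises the density.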

property mode

Returns the mode of the distribution.
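For a normal density truncated to [low, high], the mode is the location clamped to that interval, since the density decreases monotonically on either side of loc. A scalar sketch of that fact follows (the function name is hypothetical, and this is an illustration rather than TorchRL's source):

```python
def truncated_normal_mode(loc: float, low: float = -1.0, high: float = 1.0) -> float:
    """Mode of a normal distribution truncated to [low, high]: loc clamped to the bounds."""
    return min(max(loc, low), high)


# Location inside the bounds: the mode is the location itself.
interior = truncated_normal_mode(0.3)

# Location beyond a bound: the mode sits on that bound.
upper = truncated_normal_mode(4.0)
lower = truncated_normal_mode(-4.0)

print(interior, upper, lower)
```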
