
OneHotCategorical

class torchrl.modules.OneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)

One-hot categorical distribution.

This class behaves exactly like torch.distributions.Categorical, except that it reads and produces one-hot encodings of the discrete values instead of integer category indices.
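To make the encoding concrete, here is a minimal sketch that uses only core PyTorch (torch.distributions.Categorical and torch.nn.functional.one_hot), not torchrl itself. It converts integer category indices into one-hot vectors of the kind this class reads and produces, then recovers the indices with argmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4)  # unnormalized log-probabilities for 4 categories

# A plain Categorical returns integer indices in [0, 4).
index_samples = torch.distributions.Categorical(logits=logits).sample((3,))

# One-hot encoding: one row per sample, with a single 1 at the sampled index.
one_hot_samples = F.one_hot(index_samples, num_classes=logits.shape[-1]).float()

# argmax over the last dimension recovers the original indices.
recovered = one_hot_samples.argmax(dim=-1)
print(index_samples, one_hot_samples, recovered, sep="\n")
```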

Parameters:
  • logits (torch.Tensor) – event log probabilities (unnormalized)

  • probs (torch.Tensor) – event probabilities

  • grad_method (ReparamGradientStrategy, optional) – strategy used to gather reparameterized samples. ReparamGradientStrategy.PassThrough computes the sample gradients by using the softmax-valued probabilities as a proxy for the sample gradients. ReparamGradientStrategy.RelaxedOneHot uses torch.distributions.RelaxedOneHotCategorical to sample from the distribution. Defaults to ReparamGradientStrategy.PassThrough.
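The two gradient strategies can be illustrated with core torch.distributions. The sketch below is not torchrl's own code. The helpers pass_through_sample and relaxed_one_hot_sample are hypothetical names. The relaxed variant assumes a temperature of 1.0 and rounds the relaxed sample to a hard one-hot vector with a straight-through trick. In both cases the forward value is a hard one-hot tensor while gradients reach the logits.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, RelaxedOneHotCategorical


def pass_through_sample(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: hard one-hot sample with softmax gradients."""
    probs = torch.softmax(logits, dim=-1)
    index = Categorical(logits=logits).sample()  # not differentiable
    hard = F.one_hot(index, num_classes=logits.shape[-1]).to(probs.dtype)
    # Forward value equals `hard`; backward uses d(probs)/d(logits).
    return hard + probs - probs.detach()


def relaxed_one_hot_sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: relaxed (Gumbel-softmax) sample, rounded to one-hot."""
    relaxed = RelaxedOneHotCategorical(torch.tensor(temperature), logits=logits)
    soft = relaxed.rsample()  # differentiable point on the probability simplex
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=logits.shape[-1]).to(soft.dtype)
    # Forward value equals `hard`; backward uses d(soft)/d(logits).
    return hard + soft - soft.detach()


torch.manual_seed(0)
logits = torch.randn(2, 4, requires_grad=True)  # batch of 2, 4 categories
cost = torch.arange(4.0)  # arbitrary per-category cost, only for the demo

for sampler in (pass_through_sample, relaxed_one_hot_sample):
    logits.grad = None
    out = sampler(logits)
    (out * cost).sum().backward()
    print(sampler.__name__, out.detach(), logits.grad, sep="\n")
```

The pass-through variant is the cheaper of the two since it needs a single categorical draw, whereas the relaxed variant adds Gumbel noise and its gradients depend on the chosen temperature.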

Examples

>>> import torch
>>> from torchrl.modules import OneHotCategorical
>>> torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
log_prob(value: Tensor) → Tensor

Returns the log of the probability density/mass function evaluated at value.

Parameters:

value (Tensor) – a one-hot encoded tensor whose last dimension indexes the categories.
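Because the input is one-hot, its log-probability is the log-probability of the single category it encodes. The following sketch uses core torch.distributions.Categorical rather than torchrl, and assumes a one-hot value with one entry per category in its last dimension.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(4)
value = torch.tensor([0.0, 0.0, 1.0, 0.0])  # one-hot encoding of category 2

# Decode the one-hot vector to an index, then score it with a plain Categorical.
log_p = Categorical(logits=logits).log_prob(value.argmax(dim=-1))

# Equivalent closed form: the log-softmax entry of the encoded category.
expected = torch.log_softmax(logits, dim=-1)[2]
print(log_p, expected)
```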

property mode: Tensor

Returns the mode of the distribution.
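For a one-hot distribution, the mode is the most probable category expressed as a one-hot vector. Since softmax is monotonic, it can be read directly from the logits. The minimal sketch below uses core PyTorch and assumes ties are broken by argmax, which selects a single index. The library's own tie handling may differ.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3]])  # batch of 2 distributions, 3 categories

# Most probable category per batch row, encoded as a one-hot vector.
mode = F.one_hot(logits.argmax(dim=-1), num_classes=logits.shape[-1])
print(mode)
# tensor([[0, 1, 0],
#         [1, 0, 0]])
```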

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) → Tensor

Generates a sample_shape-shaped reparameterized sample, or a sample_shape-shaped batch of reparameterized samples if the distribution parameters are batched. Gradients are propagated according to grad_method.

sample(sample_shape: Optional[Union[Size, Sequence]] = None) → Tensor

Generates a sample_shape-shaped sample, or a sample_shape-shaped batch of samples if the distribution parameters are batched.
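With batched parameters, the one-hot output follows the usual torch.distributions layout of sample_shape + batch_shape + (number of categories,). The sketch below reproduces that layout with core PyTorch rather than torchrl, assuming logits for a batch of 2 distributions over 5 categories.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(2, 5)  # batch_shape = (2,), 5 categories
dist = Categorical(logits=logits)

indices = dist.sample((3,))                  # shape (3, 2): sample_shape + batch_shape
one_hot = F.one_hot(indices, num_classes=5)  # shape (3, 2, 5): one-hot axis appended
print(indices.shape, one_hot.shape)
# torch.Size([3, 2]) torch.Size([3, 2, 5])
```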
