OneHotCategorical
- class torchrl.modules.OneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)
One-hot categorical distribution.
This class behaves exactly like torch.distributions.Categorical, except that it reads and produces one-hot encodings of the discrete tensors.
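To make the difference in encoding concrete, the sketch below uses plain PyTorch (not torchrl) to move between the integer class indices that torch.distributions.Categorical works with and the one-hot vectors this class reads and produces. The helpers `to_one_hot` and `from_one_hot` are hypothetical names introduced only for this illustration.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def to_one_hot(index: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: integer class indices -> float one-hot vectors."""
    return F.one_hot(index, num_classes=num_classes).float()


def from_one_hot(one_hot: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot vectors -> integer class indices (argmax of last dim)."""
    return one_hot.argmax(dim=-1)


torch.manual_seed(0)
logits = torch.randn(4)          # unnormalized log-probabilities over 4 classes
cat = Categorical(logits=logits)

idx = cat.sample((3,))           # Categorical produces integer indices, shape (3,)
one_hot = to_one_hot(idx, 4)     # equivalent one-hot encoding, shape (3, 4)

print(idx)
print(one_hot)
```

The round trip is lossless: taking the argmax of each one-hot row recovers the original index.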
- Parameters:
logits (torch.Tensor) – event log probabilities (unnormalized)
probs (torch.Tensor) – event probabilities
grad_method (ReparamGradientStrategy, optional) – strategy used to gather reparameterized samples.
ReparamGradientStrategy.PassThrough will compute the sample gradients by using the softmax-valued log-probability as a proxy for the sample gradients.
ReparamGradientStrategy.RelaxedOneHot will use torch.distributions.RelaxedOneHotCategorical to sample from the distribution.
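The two strategies can be illustrated with plain PyTorch. A common way to realize a pass-through (straight-through) gradient is to add `probs - probs.detach()` to a hard one-hot sample: the forward value stays exactly one-hot, while the backward pass uses the gradient of the softmax probabilities. A relaxed sample instead draws from torch.distributions.RelaxedOneHotCategorical (the Gumbel-softmax relaxation), which is differentiable but only approximately one-hot. This is a minimal sketch of the generic techniques. The function names `straight_through_sample` and `relaxed_sample` are hypothetical, and the code is not claimed to be torchrl's internal implementation.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, RelaxedOneHotCategorical


def straight_through_sample(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: hard one-hot sample with gradients routed through softmax probs."""
    probs = torch.softmax(logits, dim=-1)
    idx = Categorical(logits=logits).sample()          # non-differentiable draw
    hard = F.one_hot(idx, num_classes=logits.shape[-1]).to(probs.dtype)
    # (probs - probs.detach()) is exactly zero in value but carries the softmax gradient.
    return hard + (probs - probs.detach())


def relaxed_sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: differentiable sample from the Gumbel-softmax relaxation."""
    dist = RelaxedOneHotCategorical(torch.tensor(temperature), logits=logits)
    return dist.rsample()


torch.manual_seed(0)
logits = torch.randn(4, requires_grad=True)
weights = torch.arange(4.0)  # arbitrary per-class weights for a toy objective

st = straight_through_sample(logits)
(st_grad,) = torch.autograd.grad((st * weights).sum(), logits)

soft = relaxed_sample(logits, temperature=0.5)
(soft_grad,) = torch.autograd.grad((soft * weights).sum(), logits)

print(st)        # exactly one-hot in the forward pass
print(st_grad)   # non-zero: gradients still reach the logits
print(soft)      # a point on the probability simplex, generally not exactly one-hot
```

The trade-off shown here is the usual one: the straight-through sample keeps discrete forward values but uses a biased gradient proxy, whereas the relaxed sample has a proper reparameterized gradient but its forward values are only close to one-hot.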
Examples
>>> import torch
>>> from torchrl.modules import OneHotCategorical
>>> torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
- log_prob(value: Tensor) → Tensor
Returns the log of the probability density/mass function evaluated at value.
- Parameters:
value (Tensor) –
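For a one-hot value, the log-mass is simply the log-probability of the class that the one-hot vector selects. The sketch below checks this with plain PyTorch in three ways: via Categorical on the argmax index, by picking the selected entry of the log-softmax directly, and via torch.distributions.OneHotCategorical (PyTorch's own one-hot distribution, imported under the alias `TorchOneHot` to avoid confusion with the torchrl class). It illustrates the quantity being computed and is not claimed to be how torchrl evaluates it.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.distributions import OneHotCategorical as TorchOneHot

torch.manual_seed(0)
logits = torch.randn(4)
# Three one-hot values selecting classes 0, 3 and 0.
value = F.one_hot(torch.tensor([0, 3, 0]), num_classes=4).float()

# Index-based evaluation: convert each one-hot row back to its class index.
lp_index = Categorical(logits=logits).log_prob(value.argmax(dim=-1))

# Direct evaluation: the one-hot row masks out every log-softmax entry but the selected one.
lp_direct = (value * torch.log_softmax(logits, dim=-1)).sum(dim=-1)

# PyTorch's built-in one-hot distribution should agree with both.
lp_torch = TorchOneHot(logits=logits).log_prob(value)

print(lp_index)
print(lp_direct)
print(lp_torch)
```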