Source code for torch.distributions.one_hot_categorical

import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution


class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to sum to 1.

    See also: :class:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape,
                                                validate_args=validate_args)

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        return self._categorical.probs

    @property
    def variance(self):
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape
    def sample(self, sample_shape=torch.Size()):
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)
    def entropy(self):
        return self._categorical.entropy()
    def enumerate_support(self):
        n = self.event_shape[0]
        values = self._new((n, n))
        torch.eye(n, out=values.data)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
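

# A minimal usage sketch, not part of the original module: it assumes the class
# above is importable and simply exercises the public API shown in the source.
# sample() returns a one-hot vector, log_prob() recovers the category index via
# argmax of the one-hot value, and enumerate_support() yields every one-hot
# outcome (an n x n identity matrix when there is no batch dimension).
if __name__ == "__main__":
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    m = OneHotCategorical(probs)

    x = m.sample()                 # e.g. tensor([0., 0., 1., 0.])
    print(x)
    print(m.log_prob(x))           # log-probability of the sampled category
    print(m.enumerate_support())   # all 4 one-hot vectors, shape (4, 4)
    print(m.mean)                  # equals probs
    print(m.variance)              # probs * (1 - probs)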