import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution


class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    See also: :class:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Delegate parameter handling and normalization to Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        # Each event is a one-hot vector over the last parameter dimension.
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
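

# A minimal sketch, using plain torch ops rather than this class, of the
# normalization and sampling behaviour the docstring above describes: `probs`
# is rescaled to sum to 1 along the last dimension, `logits` are shifted by
# their log-sum-exp, and a sample is a one-hot vector of length probs.size(-1).
# The helper `_sketch_one_hot_sample` is hypothetical and not part of
# torch.distributions. Sampling with torch.multinomial and encoding with
# torch.nn.functional.one_hot is an assumption made for illustration, not
# necessarily how OneHotCategorical samples internally.
import torch
import torch.nn.functional as F


def _sketch_one_hot_sample(probs=None, logits=None):
    """Hypothetical helper: normalize probs/logits, then draw one one-hot sample."""
    if (probs is None) == (logits is None):
        raise ValueError("pass exactly one of `probs` or `logits`")
    if probs is not None:
        # Non-negative, finite, non-zero sum: rescale to sum to 1.
        probs = probs / probs.sum(-1, keepdim=True)
        # Log of the normalized probabilities (zero entries become -inf here).
        logits = torch.log(probs)
    else:
        # Unnormalized log probabilities: subtract log-sum-exp to normalize.
        logits = logits - logits.logsumexp(-1, keepdim=True)
        probs = logits.exp()
    num_events = probs.size(-1)
    # Draw one category index per batch row, then one-hot encode it.
    flat = probs.reshape(-1, num_events)
    indices = torch.multinomial(flat, num_samples=1).squeeze(-1)
    sample = F.one_hot(indices, num_events).to(probs.dtype).reshape(probs.shape)
    return probs, logits, sample

# Usage: _sketch_one_hot_sample(probs=torch.tensor([1., 1., 2.])) returns the
# normalized probabilities [0.25, 0.25, 0.5], their logs, and a length-3
# one-hot sample such as [0., 0., 1.].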


class OneHotCategoricalStraightThrough(OneHotCategorical):
    r"""
    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the
    straight-through gradient estimator from [1].

    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional
    Computation (Bengio et al., 2013)
    """
    has_rsample = True

    def rsample(self, sample_shape=torch.Size()):
        samples = self.sample(sample_shape)
        probs = self._categorical.probs  # cached via @lazy_property
        # Straight-through: the forward value equals `samples` because
        # `probs - probs.detach()` is zero, while gradients flow back into
        # `probs` as if `samples` were `probs` (identity Jacobian).
        return samples + (probs - probs.detach())
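

# A minimal sketch, in plain torch, of the straight-through trick that
# `rsample` above relies on: the forward value is the hard one-hot sample,
# since `probs - probs.detach()` is exactly zero, but the backward pass treats
# the sample as `probs`, so d(sample)/d(probs) is the identity. The helper
# `_sketch_straight_through`, the softmax parameterization and the linear toy
# objective are hypothetical illustrations, not part of torch.distributions.
import torch
import torch.nn.functional as F


def _sketch_straight_through(logits, weights):
    """Hypothetical helper for 1-D logits: return (hard sample, d<sample, weights>/d logits)."""
    logits = logits.detach().requires_grad_(True)
    probs = logits.softmax(-1)
    # Hard one-hot sample; the sampling step itself is not differentiated.
    index = torch.multinomial(probs.detach(), num_samples=1)
    hard = F.one_hot(index, probs.size(-1)).squeeze(0).to(probs.dtype)
    # Straight-through: value of `hard`, gradient of `probs`.
    sample = hard + (probs - probs.detach())
    loss = (sample * weights).sum()
    loss.backward()
    return sample.detach(), logits.grad

# Design note: the gradient reaching `logits` is the softmax Jacobian applied
# to `weights`, i.e. p * (weights - p . weights), independent of which
# category was sampled; this bias is the price of the estimator.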