class Categorical(Distribution):
    r"""
    Creates a categorical distribution parameterized by either :attr:`probs` or
    :attr:`logits` (but not both).

    .. note::
        It is equivalent to the distribution that :func:`torch.multinomial`
        samples from.

    Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.

    If `probs` is 1-dimensional with length-`K`, each element is the relative probability
    of sampling the class at that index.

    If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
    relative probability vectors.

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    See also: :func:`torch.multinomial`

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor(3)

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            if probs.dim() < 1:
                raise ValueError("`probs` parameter must be at least one-dimensional.")
            self.probs = probs / probs.sum(-1, keepdim=True)
        else:
            if logits.dim() < 1:
                raise ValueError("`logits` parameter must be at least one-dimensional.")
            # Normalize
            self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        self._param = self.probs if probs is not None else self.logits
        self._num_events = self._param.size()[-1]
        batch_shape = (
            self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
        )
        super().__init__(batch_shape, validate_args=validate_args)
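The two normalization rules applied in `__init__` can be illustrated without PyTorch for a single 1-dimensional parameter vector. This is only a sketch: the helper names `normalize_probs` and `normalize_logits` are hypothetical and are not part of `torch.distributions`, and the max-shift inside the log-sum-exp is a common stability trick assumed here rather than something taken from the source.

```python
import math


def normalize_probs(weights):
    """Divide non-negative weights by their sum, like probs / probs.sum(-1)."""
    total = sum(weights)
    return [w / total for w in weights]


def normalize_logits(logits):
    """Subtract log-sum-exp, like logits - logits.logsumexp(-1)."""
    m = max(logits)  # assumed stability shift; does not change the result
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


# Relative weights 1 : 1 : 2 expressed both ways.
probs = normalize_probs([1.0, 1.0, 2.0])  # [0.25, 0.25, 0.5]
log_probs = normalize_logits([0.0, 0.0, math.log(2.0)])

# Exponentiating the normalized logits recovers the same probabilities
# (up to floating-point rounding).
recovered = [math.exp(x) for x in log_probs]
```

Both parameterizations describe the same distribution here, which is why the constructor accepts either one but rejects supplying both.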