import torch
from torch.distributions import constraints
from torch.distributions.binomial import Binomial
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution


class Multinomial(Distribution):
    r"""
    Creates a Multinomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
    :attr:`probs` indexes over categories. All other dimensions index over batches.

    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
    called (see example below).

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero
              sum, and it will be normalized to sum to 1 along the last dimension.
              :attr:`probs` will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log
              probabilities and can therefore be any real number. It will likewise be
              normalized so that the resulting probabilities sum to 1 along the last
              dimension. :attr:`logits` will return this normalized value.

    - :meth:`sample` requires a single shared `total_count` for all
      parameters and samples.
    - :meth:`log_prob` allows different `total_count` for each parameter and
      sample.

    Example::

        >>> # xdoctest: +SKIP("FIXME: found invalid values")
        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 21.,  24.,  30.,  25.])

        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
        tensor([-4.1338])

    Args:
        total_count (int): number of trials
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    total_count: int

    @property
    def mean(self):
        return self.probs * self.total_count

    @property
    def variance(self):
        return self.total_count * self.probs * (1 - self.probs)

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if not isinstance(total_count, int):
            raise NotImplementedError("inhomogeneous total_count is not supported")
        self.total_count = total_count
        self._categorical = Categorical(probs=probs, logits=logits)
        self._binomial = Binomial(total_count=total_count, probs=self.probs)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
    def sample(self, sample_shape=torch.Size()):
        sample_shape = torch.Size(sample_shape)
        samples = self._categorical.sample(
            torch.Size((self.total_count,)) + sample_shape
        )
        # samples.shape is (total_count, sample_shape, batch_shape); permute it to
        # (sample_shape, batch_shape, total_count)
        shifted_idx = list(range(samples.dim()))
        shifted_idx.append(shifted_idx.pop(0))
        samples = samples.permute(*shifted_idx)
        counts = samples.new(self._extended_shape(sample_shape)).zero_()
        counts.scatter_add_(-1, samples, torch.ones_like(samples))
        return counts.type_as(self.probs)
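

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the library source): sample() above turns
# `total_count` categorical index draws into per-category counts with
# scatter_add_. The hypothetical helper below replays that counting trick on a
# plain 1-D index tensor so the permute/scatter step is easier to follow; the
# function name and shapes are illustrative assumptions, not a PyTorch API.
def _counts_from_indices_sketch(indices: torch.Tensor, num_categories: int) -> torch.Tensor:
    # indices: integer tensor of shape (total_count,) holding category draws,
    # e.g. torch.tensor([0, 2, 2, 1, 2]).
    counts = torch.zeros(num_categories)
    # Add 1.0 at position indices[i] for every draw i, i.e. count occurrences
    # of each category along the last dimension.
    counts.scatter_add_(-1, indices, torch.ones_like(indices, dtype=counts.dtype))
    return counts  # e.g. tensor([1., 1., 3.]) for the example indices above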
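

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not library source): exercises the behaviour
# documented in the class docstring -- probs are normalized along the last
# dimension, samples are per-category counts that sum to total_count, and
# log_prob evaluates the log-probability of those counts. The concrete numbers
# passed in below are arbitrary assumptions for demonstration.
if __name__ == "__main__":
    m = Multinomial(total_count=100, probs=torch.tensor([1.0, 1.0, 1.0, 1.0]))
    print(m.probs)        # normalized: tensor([0.2500, 0.2500, 0.2500, 0.2500])

    x = m.sample()        # counts over the 4 categories, e.g. tensor([21., 24., 30., 25.])
    print(x.sum())        # always sums to total_count: tensor(100.)

    print(m.log_prob(x))  # log-probability of the sampled counts

    # Per the docstring, total_count may be omitted when only log_prob is used;
    # with argument validation enabled this may require validate_args=False,
    # since the default total_count of 1 no longer matches the counts in x.
    d = Multinomial(probs=torch.tensor([1.0, 1.0, 1.0, 1.0]), validate_args=False)
    print(d.log_prob(x))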