Source code for torch.distributions.relaxed_categorical
import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import broadcast_all, clamp_probs


__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]


class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor or float): relaxation temperature. A plain Python
            number is accepted; :meth:`log_prob` promotes it to a tensor with
            the dtype/device of :attr:`logits`.
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al., 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.real_vector  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # probs/logits handling (including their normalization) is delegated to
        # an internal Categorical; this class only adds the temperature.
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        # The last parameter dimension (one entry per category) is the event.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy whose batch dimensions are expanded to ``batch_shape``.

        Only the wrapped Categorical is expanded; ``temperature`` is shared
        with ``self`` as-is.
        """
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # Delegates tensor construction to the wrapped Categorical
        # (presumably so new tensors match its parameter dtype/device).
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Shape of the wrapped Categorical's parameter tensor."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Event probabilities of the wrapped Categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample (Gumbel-softmax in log space).

        Gumbel(0, 1) noise is added to the logits, the sum is divided by the
        temperature, and the result is normalized with ``logsumexp`` over the
        last dimension. Exponentiating the output therefore gives a point on
        the simplex.
        """
        shape = self._extended_shape(sample_shape)
        # clamp_probs keeps the uniforms away from 0 and 1 so both logs below
        # stay finite.
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        r"""Log-density of ``value`` (a log-simplex point) under the ExpConcrete law.

        Computes

        .. math::
            \log\Gamma(K) + (K - 1)\log\tau
            + \sum_k (\ell_k - \tau y_k) - K \operatorname{logsumexp}_k(\ell_k - \tau y_k)

        where :math:`\ell` are the logits, :math:`\tau` the temperature and
        :math:`K` the number of categories.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = self.temperature
        if not isinstance(temperature, torch.Tensor):
            # rsample() already works with a plain Python number; promote it
            # here so torch.full_like / .log() below work for that case too.
            temperature = torch.as_tensor(
                temperature, dtype=logits.dtype, device=logits.device
            )
        log_scale = torch.full_like(temperature, float(K)).lgamma() - temperature.log().mul(
            -(K - 1)
        )
        score = logits - value.mul(temperature)
        # sum_k (score_k - lse) == sum_k score_k - K * lse
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
class RelaxedOneHotCategorical(TransformedDistribution):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples are on simplex, and are reparametrizable.

    It is implemented as an :class:`ExpRelaxedCategorical` base distribution
    pushed through an :class:`ExpTransform`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = ExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy whose batch dimensions are expanded to ``batch_shape``.

        Required because this class defines a custom ``__init__``: the
        inherited TransformedDistribution.expand refuses to build an instance
        of a subclass on its own.
        """
        new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)

    # The properties below back the names declared in ``arg_constraints``;
    # argument validation reads them via getattr, so they must exist.
    @property
    def temperature(self):
        """Relaxation temperature of the base distribution."""
        return self.base_dist.temperature

    @property
    def logits(self):
        """Logits of the base distribution."""
        return self.base_dist.logits

    @property
    def probs(self):
        """Event probabilities of the base distribution."""
        return self.base_dist.probs
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.