import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import (
    broadcast_all,
    clamp_probs,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)


class ContinuousBernoulli(ExponentialFamily):
    r"""
    Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    The distribution is supported in [0, 1] and parameterized by 'probs' (in
    (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
    does not correspond to a probability and 'logits' does not correspond to
    log-odds, but the same names are used due to the similarity with the
    Bernoulli. See [1] for more details.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = ContinuousBernoulli(torch.tensor([0.3]))
        >>> m.sample()
        tensor([ 0.2538])

    Args:
        probs (Number, Tensor): (0,1) valued parameters
        logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'

    [1] The continuous Bernoulli: fixing a pervasive error in variational
    autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
    https://arxiv.org/abs/1907.06845
    """

    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.unit_interval
    _mean_carrier_measure = 0
    has_rsample = True

    def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
            # validate 'probs' here if necessary, as it is later clamped away from
            # 0 and 1 for numerical stability; otherwise the clamped 'probs' would
            # always pass the check
            if validate_args is not None:
                if not self.arg_constraints["probs"].check(self.probs).all():
                    raise ValueError("The parameter probs has invalid values")
            self.probs = clamp_probs(self.probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        self._lims = lims
        super().__init__(batch_shape, validate_args=validate_args)
    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)

    def _outside_unstable_region(self):
        # True where 'probs' lies outside the numerically unstable window
        # around 0.5 defined by self._lims
        return torch.max(
            torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
        )

    def _cut_probs(self):
        # inside the unstable window, substitute a safe value; results there
        # are later overridden by a Taylor expansion around probs = 0.5
        return torch.where(
            self._outside_unstable_region(),
            self.probs,
            self._lims[0] * torch.ones_like(self.probs),
        )

    def _cont_bern_log_norm(self):
        """computes the log normalizing constant as a function of the 'probs' parameter"""
        cut_probs = self._cut_probs()
        cut_probs_below_half = torch.where(
            torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
        )
        cut_probs_above_half = torch.where(
            torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
        )
        log_norm = torch.log(
            torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
        ) - torch.where(
            torch.le(cut_probs, 0.5),
            torch.log1p(-2.0 * cut_probs_below_half),
            torch.log(2.0 * cut_probs_above_half - 1.0),
        )
        # Taylor expansion of the log normalizer around probs = 0.5
        x = torch.pow(self.probs - 0.5, 2)
        taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
        return torch.where(self._outside_unstable_region(), log_norm, taylor)

    @property
    def mean(self):
        cut_probs = self._cut_probs()
        mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
            torch.log1p(-cut_probs) - torch.log(cut_probs)
        )
        # Taylor expansion of the mean around probs = 0.5
        x = self.probs - 0.5
        taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
        return torch.where(self._outside_unstable_region(), mus, taylor)

    @property
    def stddev(self):
        return torch.sqrt(self.variance)

    @property
    def variance(self):
        cut_probs = self._cut_probs()
        vars = cut_probs * (cut_probs - 1.0) / torch.pow(
            1.0 - 2.0 * cut_probs, 2
        ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
        # Taylor expansion of the variance around probs = 0.5
        x = torch.pow(self.probs - 0.5, 2)
        taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
        return torch.where(self._outside_unstable_region(), vars, taylor)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return clamp_probs(logits_to_probs(self.logits, is_binary=True))

    @property
    def param_shape(self):
        return self._param.size()
    @property
    def _natural_params(self):
        return (self.logits,)

    def _log_normalizer(self, x):
        """computes the log normalizing constant as a function of the natural parameter"""
        # A(eta) = log((exp(eta) - 1) / eta), with a Taylor series near eta = 0;
        # eta = 0 corresponds to probs = 0.5, so the unstable window is shifted by 0.5
        out_unst_reg = torch.max(
            torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
        )
        cut_nat_params = torch.where(
            out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
        )
        log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(
            torch.abs(cut_nat_params)
        )
        taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
        return torch.where(out_unst_reg, log_norm, taylor)
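The constructor above accepts exactly one of `probs` or `logits`. A minimal usage sketch (not part of the module source; the value -0.8473 is just logit(0.3) rounded):

import torch
from torch.distributions import ContinuousBernoulli

# Equivalent constructions: logit(0.3) = log(0.3 / 0.7) is approximately -0.8473.
m1 = ContinuousBernoulli(probs=torch.tensor([0.3]))
m2 = ContinuousBernoulli(logits=torch.tensor([-0.8473]))
print(m1.logits, m2.probs)  # roughly -0.8473 and roughly 0.3, respectively

# Supplying both parameters (or neither) raises a ValueError.
# ContinuousBernoulli(probs=torch.tensor([0.3]), logits=torch.tensor([0.0]))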
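The `mean` and `variance` properties above switch to a Taylor expansion around `probs = 0.5`, where the closed forms are numerically unstable. A hedged sanity check, using only the public API (`rsample` is available since `has_rsample = True`), is to compare the closed forms against Monte Carlo estimates:

import torch
from torch.distributions import ContinuousBernoulli

torch.manual_seed(0)
# 0.4999 falls inside the unstable window (0.499, 0.501] and exercises the Taylor branch.
m = ContinuousBernoulli(probs=torch.tensor([0.3, 0.4999, 0.9]))
samples = m.rsample(torch.Size([200_000]))
print(m.mean, samples.mean(dim=0))      # should agree to roughly three decimals
print(m.variance, samples.var(dim=0))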
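`_log_normalizer` implements the exponential-family log partition function A(eta) = log((exp(eta) - 1) / eta), i.e. the density takes the form p(x) = exp(eta * x - A(eta)) on [0, 1] with natural parameter eta = logits. A small numerical cross-check (a sketch that calls the private method directly, so it may break across versions):

import torch
from torch.distributions import ContinuousBernoulli

# eta values straddling both branches: +/-0.0005 lie in the Taylor region near 0.
eta = torch.tensor([-2.0, -0.0005, 0.0005, 3.0])
m = ContinuousBernoulli(logits=eta)

# Trapezoidal estimate of A(eta) = log(integral of exp(eta * x) over x in [0, 1]).
xs = torch.linspace(0.0, 1.0, 100_001)
a_numeric = torch.log(torch.trapezoid(torch.exp(eta[:, None] * xs), xs, dim=-1))
print(a_numeric)
print(m._log_normalizer(eta))  # the two should agree closely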