import math

import torch
import torch.jit
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, lazy_property

__all__ = ["VonMises"]


def _eval_poly(y, coef):
    # Evaluate a polynomial in ``y`` whose coefficients are given in
    # increasing order of power, using Horner's method.
    coef = list(coef)
    result = coef.pop()
    while coef:
        result = coef.pop() + y * result
    return result


_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]

_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]


def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.
    """
    assert order == 0 or order == 1

    # compute small solution
    y = x / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x.abs() * small
    small = small.log()

    # compute large solution
    y = 3.75 / x
    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    result = torch.where(x < 3.75, small, large)
    return result


@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    # Best-Fisher (1979) rejection sampler: loop until every element of
    # ``x`` holds an accepted proposal.
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not done.all():
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by ``loc`` and wrap the result back into [-pi, pi).
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
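
# Illustrative sketch, not part of the original module: the two polynomial
# fits above approximate I_0 and I_1 with a small-/large-argument split at
# x = 3.75. A quick way to sanity-check ``_log_modified_bessel_fn`` is to
# compare it against ``torch.special.i0`` / ``torch.special.i1`` (assuming a
# PyTorch build that provides them); the helper name below is hypothetical
# and exists only for demonstration.
def _demo_check_bessel_approx():
    x = torch.linspace(0.1, 20.0, 200)
    for order, exact_fn in ((0, torch.special.i0), (1, torch.special.i1)):
        approx = _log_modified_bessel_fn(x, order=order)
        exact = exact_fn(x).log()
        print(f"max |log I{order} error|:", (approx - exact).abs().max().item())
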
class VonMises(Distribution):
    """
    A circular von Mises distribution.

    This implementation uses polar coordinates. The ``loc`` and ``value`` args
    can be any real number (to facilitate unconstrained optimization), but are
    interpreted as angles modulo 2 pi.

    Example::
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        event_shape = torch.Size()

        # Parameters for sampling
        tau = 1 + (1 + 4 * self.concentration**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)
        self._proposal_r = (1 + rho**2) / (2 * rho)

        super().__init__(batch_shape, event_shape, validate_args)
    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        The sampling algorithm for the von Mises distribution is based on the
        following paper: Best, D. J., and Nicholas I. Fisher.
        "Efficient simulation of the von Mises distribution." Applied
        Statistics (1979): 152-157.
        """
        shape = self._extended_shape(sample_shape)
        x = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)
        return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)
    @property
    def mean(self):
        """
        The provided mean is the circular one.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        The provided variance is the circular one, i.e.
        ``1 - I_1(concentration) / I_0(concentration)``.
        """
        return (
            1
            - (
                _log_modified_bessel_fn(self.concentration, order=1)
                - _log_modified_bessel_fn(self.concentration, order=0)
            ).exp()
        )
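

# Usage sketch, illustrative rather than part of the original module: draw
# samples and compare empirical circular statistics against the ``mean`` and
# ``variance`` properties. The circular mean is ``atan2(E[sin x], E[cos x])``
# and the circular variance is ``1 - R``, where ``R`` is the mean resultant
# length; the parameter values and seed below are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)
    d = VonMises(torch.tensor([1.0]), torch.tensor([4.0]))
    samples = d.sample((10000,))
    cos_mean = samples.cos().mean(0)
    sin_mean = samples.sin().mean(0)
    circ_mean = torch.atan2(sin_mean, cos_mean)
    circ_var = 1 - (cos_mean**2 + sin_mean**2).sqrt()
    print("circular mean:     empirical", circ_mean, "analytic", d.mean)
    print("circular variance: empirical", circ_var, "analytic", d.variance)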