# mypy: allow-untyped-defsimportmathimporttorchimporttorch.jitfromtorch.distributionsimportconstraintsfromtorch.distributions.distributionimportDistributionfromtorch.distributions.utilsimportbroadcast_all,lazy_property__all__=["VonMises"]def_eval_poly(y,coef):coef=list(coef)result=coef.pop()whilecoef:result=coef.pop()+y*resultreturnresult_I0_COEF_SMALL=[1.0,3.5156229,3.0899424,1.2067492,0.2659732,0.360768e-1,0.45813e-2,]_I0_COEF_LARGE=[0.39894228,0.1328592e-1,0.225319e-2,-0.157565e-2,0.916281e-2,-0.2057706e-1,0.2635537e-1,-0.1647633e-1,0.392377e-2,]_I1_COEF_SMALL=[0.5,0.87890594,0.51498869,0.15084934,0.2658733e-1,0.301532e-2,0.32411e-3,]_I1_COEF_LARGE=[0.39894228,-0.3988024e-1,-0.362018e-2,0.163801e-2,-0.1031555e-1,0.2282967e-1,-0.2895312e-1,0.1787654e-1,-0.420059e-2,]_COEF_SMALL=[_I0_COEF_SMALL,_I1_COEF_SMALL]_COEF_LARGE=[_I0_COEF_LARGE,_I1_COEF_LARGE]def_log_modified_bessel_fn(x,order=0):""" Returns ``log(I_order(x))`` for ``x > 0``, where `order` is either 0 or 1. """assertorder==0ororder==1# compute small solutiony=x/3.75y=y*ysmall=_eval_poly(y,_COEF_SMALL[order])iforder==1:small=x.abs()*smallsmall=small.log()# compute large solutiony=3.75/xlarge=x-0.5*x.log()+_eval_poly(y,_COEF_LARGE[order]).log()result=torch.where(x<3.75,small,large)returnresult@torch.jit.script_if_tracingdef_rejection_sample(loc,concentration,proposal_r,x):done=torch.zeros(x.shape,dtype=torch.bool,device=loc.device)whilenotdone.all():u=torch.rand((3,)+x.shape,dtype=loc.dtype,device=loc.device)u1,u2,u3=u.unbind()z=torch.cos(math.pi*u1)f=(1+proposal_r*z)/(proposal_r+z)c=concentration*(proposal_r-f)accept=((c*(2-c)-u2)>0)|((c/u2).log()+1-c>=0)ifaccept.any():x=torch.where(accept,(u3-0.5).sign()*f.acos(),x)done=done|acceptreturn(x+math.pi+loc)%(2*math.pi)-math.pi
class VonMises(Distribution):
    """
    A circular von Mises distribution.

    Angles are handled in polar form: ``loc`` and ``value`` may be any real
    number (which keeps unconstrained optimization simple), and they are read
    as angles modulo 2 pi.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        # Broadcast both parameters to a common shape; that shape is the
        # batch shape, and each draw is a scalar angle (empty event shape).
        broadcast_loc, broadcast_concentration = broadcast_all(loc, concentration)
        self.loc = broadcast_loc
        self.concentration = broadcast_concentration
        super().__init__(self.loc.shape, torch.Size(), validate_args)

    # Double-precision copies of the parameters, used only by ``sample``.
    @lazy_property
    def _loc(self):
        return self.loc.to(torch.double)

    @lazy_property
    def _concentration(self):
        return self.concentration.to(torch.double)

    @lazy_property
    def _proposal_r(self):
        # Envelope parameter ``r`` of the Best & Fisher rejection sampler.
        k = self._concentration
        tau = 1 + (1 + 4 * k**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * k)
        closed_form = (1 + rho**2) / (2 * rho)
        # Second order Taylor expansion around 0, used when kappa is tiny
        # because the closed form above suffers from cancellation there.
        taylor = 1 / k + k
        return torch.where(k < 1e-5, taylor, closed_form)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draw samples using the rejection algorithm of

            D.J. Best and N.I. Fisher, "Efficient simulation of the von Mises
            distribution." Applied Statistics (1979): 152-157.

        The sampler always runs in double precision internally. In single
        precision, _rejection_sample() can hang for small concentrations
        (starting around 1e-4; see issue #88443). The result is cast back to
        the dtype of ``loc``.
        """
        out_shape = self._extended_shape(sample_shape)
        buffer = torch.empty(out_shape, dtype=self._loc.dtype, device=self.loc.device)
        draws = _rejection_sample(
            self._loc, self._concentration, self._proposal_r, buffer
        )
        return draws.to(self.loc.dtype)

    @property
    def mean(self):
        """
        Circular mean of the distribution, which is ``loc``.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        Circular variance, ``1 - I_1(kappa) / I_0(kappa)``, computed in log space.
        """
        log_i1 = _log_modified_bessel_fn(self.concentration, order=1)
        log_i0 = _log_modified_bessel_fn(self.concentration, order=0)
        return 1 - (log_i1 - log_i0).exp()
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.