import math
import warnings
from numbers import Number
from typing import Optional, Union

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.multivariate_normal import _precision_to_scale_tril
from torch.distributions.utils import lazy_property

__all__ = ["Wishart"]

# Natural logarithm of 2, computed once at import time.
_log_2 = math.log(2)


def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    """Multivariate digamma function of dimension ``p``.

    Computes ``sum_{i=0}^{p-1} digamma(x - i / 2)`` elementwise over ``x``;
    the summation index is a new trailing axis that is reduced away, so the
    result has the same shape as ``x``.

    Raises:
        AssertionError: if any element of ``x`` is not greater than ``(p - 1) / 2``
            (outside the function's domain).
    """
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    half_steps = torch.arange(p, dtype=x.dtype, device=x.device) / 2
    shifted = x.unsqueeze(-1) - half_steps.expand(x.shape + (-1,))
    return shifted.digamma().sum(dim=-1)


def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
    """Clamp ``x`` from below at the machine epsilon of its dtype.

    Meant for non-negative inputs, so that values never collapse to exactly zero.
    """
    # We assume positive input for this function
    smallest = torch.finfo(x.dtype).eps
    return x.clamp(min=smallest)
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    # Class-level defaults. The df bound actually enforced depends on the event
    # size, so __init__ installs a per-instance copy with the tighter bound.
    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        """Build the distribution from ``df`` and exactly one matrix parameterization.

        Raises:
            AssertionError: if not exactly one of the matrix arguments is given.
            ValueError: if the matrix argument has fewer than two dims, or any
                ``df`` is not greater than ``ndim - 1``.
        """
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # Store a per-instance copy instead of mutating the class-level dict:
        # writing into it would leak this instance's df bound into every other
        # Wishart instance (and into later constructions with a different ndim).
        self.arg_constraints = {
            **self.arg_constraints,
            "df": constraints.greater_than(event_shape[-1] - 1),
        }
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # Negative indices of the batch dimensions, used to reduce per-batch
        # singularity flags down to one flag per sample index.
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling:
        # the i-th diagonal entry uses df - i degrees of freedom.
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        """Return a view of this distribution with a new (broadcast) batch shape."""
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]
        # Carry over the event-size-dependent constraints installed by __init__.
        new.arg_constraints = self.arg_constraints

        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        # Solving L L^T X = I yields the inverse of the covariance.
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        # (df - p - 1) * Sigma; NaN where df <= p + 1.
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        # Var(X_ij) = df * (V_ij^2 + V_ii * V_jj)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``."""
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition: sqrt(chi2) on the
        # diagonal (clamped away from zero), standard normals strictly below it.
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (p * (p - 1) // 2,),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def _singular_mask(self, sample):
        """Flag samples that are not positive definite.

        The flags are reduced over the batch dims (True if any batch element is
        singular), so the result has the leading, non-batch shape of ``sample``.
        """
        singular = ~self.support.check(sample)
        if self._batch_shape:
            singular = singular.amax(self._batch_dims)
        return singular

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporally and should be removed in the future
        # True marks samples that FAILED the positive-definite check.
        is_singular = self._singular_mask(sample)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                # is_singular has the leading sample shape only; pad trailing
                # singleton dims so torch.where aligns it with the sample dims
                # instead of right-aligning it against batch/event dims.
                mask = is_singular.reshape(
                    is_singular.shape + (1,) * (sample.dim() - is_singular.dim())
                )
                sample = torch.where(mask, sample_new, sample)
                is_singular = self._singular_mask(sample)
        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    # Redraw only the flagged samples.
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = self._singular_mask(sample_new)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        """Log-density of ``value`` under this Wishart distribution."""
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            # -nu/2 * (p*log 2 + log|Sigma|), using log|Sigma| = 2 * sum(log diag L)
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            # tr(Sigma^{-1} value) / 2
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        """Differential entropy, one value per batch element."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            # (p+1)/2 * (p*log 2 + log|Sigma|)
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        # With x = -Sigma^{-1}/2 and y = (nu - p - 1)/2, y + (p+1)/2 recovers nu/2.
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.