"""This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).Original copyright notice:# Copyright: Contributors to the Pyro project.# SPDX-License-Identifier: Apache-2.0"""importmathimporttorchfromtorch.distributionsimportBeta,constraintsfromtorch.distributions.distributionimportDistributionfromtorch.distributions.utilsimportbroadcast_all__all__=["LKJCholesky"]
class LKJCholesky(Distribution):
    r"""
    LKJ distribution for lower Cholesky factor of correlation matrices.
    The distribution is controlled by ``concentration`` parameter :math:`\eta`
    to make the probability of the correlation matrix :math:`M` generated from
    a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
    when ``concentration == 1``, we have a uniform distribution over Cholesky
    factors of correlation matrices::

        L ~ LKJCholesky(dim, concentration)
        X = L @ L' ~ LKJCorr(dim, concentration)

    Note that this distribution samples the Cholesky factor of correlation
    matrices and not the correlation matrices themselves and thereby differs
    slightly from the derivations in [1] for the `LKJCorr` distribution. For
    sampling, this uses the Onion method from [1] Section 3.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> l = LKJCholesky(3, 0.5)
        >>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
        tensor([[ 1.0000,  0.0000,  0.0000],
                [ 0.3516,  0.9361,  0.0000],
                [-0.1899,  0.4748,  0.8593]])

    Args:
        dim (int): dimension of the matrices
        concentration (float or Tensor): concentration/shape parameter of the
            distribution (often referred to as eta)

    **References**

    [1] `Generating random correlation matrices based on vines and extended
    onion method` (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
    Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
    """

    arg_constraints = {"concentration": constraints.positive}
    support = constraints.corr_cholesky

    def __init__(self, dim, concentration=1.0, validate_args=None):
        if dim < 2:
            raise ValueError(
                f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
            )
        self.dim = dim
        (self.concentration,) = broadcast_all(concentration)
        batch_shape = self.concentration.size()
        event_shape = torch.Size((dim, dim))
        # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
        marginal_conc = self.concentration + 0.5 * (self.dim - 2)
        offset = torch.arange(
            self.dim - 1,
            dtype=self.concentration.dtype,
            device=self.concentration.device,
        )
        offset = torch.cat([offset.new_zeros((1,)), offset])
        beta_conc1 = offset + 0.5
        beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
        self._beta = Beta(beta_conc1, beta_conc0)
        super().__init__(batch_shape, event_shape, validate_args)
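
    # Illustrative note (not part of the upstream NumPyro/PyTorch source): with
    # dim=4 and concentration=1.0, the construction above gives
    #   offset     = [0.0, 0.0, 1.0, 2.0]
    #   beta_conc1 = [0.5, 0.5, 1.5, 2.5]
    #   beta_conc0 = [2.0, 2.0, 1.5, 1.0]
    # i.e. one Beta marginal per row of the Cholesky factor. The first draw is
    # effectively unused, because `sample` zeroes out the first row of the
    # hypersphere directions before multiplying by sqrt(y).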

    def sample(self, sample_shape=torch.Size()):
        # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
        # - This vectorizes the for loop and also works for heterogeneous eta.
        # - Same algorithm generalizes to n=1.
        # - The procedure is simplified since we are sampling the Cholesky factor of
        #   the correlation matrix instead of the correlation matrix itself. As such,
        #   we only need to generate `w`.
        y = self._beta.sample(sample_shape).unsqueeze(-1)
        u_normal = torch.randn(
            self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
        ).tril(-1)
        u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
        # Replace NaNs in first row
        u_hypersphere[..., 0, :].fill_(0.0)
        w = torch.sqrt(y) * u_hypersphere
        # Fill diagonal elements; clamp for numerical stability
        eps = torch.finfo(w.dtype).tiny
        diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
        w += torch.diag_embed(diag_elems)
        return w

    def log_prob(self, value):
        # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
        # The probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # Additionally, the Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i
        if self._validate_args:
            self._validate_sample(value)
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        order = torch.arange(2, self.dim + 1, device=self.concentration.device)
        order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Compute normalization constant (page 1999 of [1])
        dm1 = self.dim - 1
        alpha = self.concentration + 0.5 * dm1
        denominator = torch.lgamma(alpha) * dm1
        numerator = torch.mvlgamma(alpha - 0.5, dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * dm1 * math.log(math.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized_log_pdf - normalize_term
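

# A minimal usage sketch (illustrative, not part of the upstream source): draw a
# Cholesky factor L and check that L @ L.T is a valid correlation matrix (unit
# diagonal, positive semi-definite). The dimension and concentration are arbitrary.
if __name__ == "__main__":
    lkj = LKJCholesky(3, concentration=0.5)
    L = lkj.sample()
    corr = L @ L.transpose(-1, -2)
    print(corr)
    # Unit diagonal (up to floating-point error).
    print(torch.allclose(corr.diagonal(dim1=-2, dim2=-1), torch.ones(3)))
    # Smallest eigenvalue should be non-negative (up to floating-point error).
    print(torch.linalg.eigvalsh(corr).min())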
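

# A quick numerical sanity check of `log_prob` (illustrative, not part of the
# upstream source): for dim=2 the Cholesky factor is fully determined by the
# single free entry r = L[1, 0], via L = [[1, 0], [r, sqrt(1 - r^2)]], so the
# density implied by `log_prob` should integrate to ~1 over r in (-1, 1). The
# grid size and concentration value below are arbitrary.
if __name__ == "__main__":
    eta = 2.5
    lkj2 = LKJCholesky(2, concentration=eta)
    r = torch.linspace(-1 + 1e-4, 1 - 1e-4, 20001, dtype=torch.float64)
    L2 = torch.zeros(r.numel(), 2, 2, dtype=torch.float64)
    L2[:, 0, 0] = 1.0
    L2[:, 1, 0] = r
    L2[:, 1, 1] = (1 - r**2).sqrt()
    density = lkj2.log_prob(L2).exp()
    # Trapezoidal rule over the grid of r values; should print a value close to 1.0.
    print(torch.trapezoid(density, r))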