
Source code for torch.distributions.lkj_cholesky

"""
This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).

Original copyright notice:

# Copyright: Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""

import math

import torch
from torch.distributions import constraints, Beta
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all


class LKJCholesky(Distribution):
    r"""
    LKJ distribution for lower Cholesky factor of correlation matrices.
    The distribution is controlled by ``concentration`` parameter :math:`\eta`
    to make the probability of the correlation matrix :math:`M` generated from
    a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
    when ``concentration == 1``, we have a uniform distribution over Cholesky
    factors of correlation matrices. Note that this distribution samples the
    Cholesky factor of correlation matrices and not the correlation matrices
    themselves and thereby differs slightly from the derivations in [1] for
    the `LKJCorr` distribution. For sampling, this uses the Onion method from
    [1] Section 3.

    L ~ LKJCholesky(dim, concentration)
    X = L @ L' ~ LKJCorr(dim, concentration)

    Example::

        >>> l = LKJCholesky(3, 0.5)
        >>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
        tensor([[ 1.0000,  0.0000,  0.0000],
                [ 0.3516,  0.9361,  0.0000],
                [-0.1899,  0.4748,  0.8593]])

    Args:
        dim (int): dimension of the matrices
        concentration (float or Tensor): concentration/shape parameter of the
            distribution (often referred to as eta)

    **References**

    [1] `Generating random correlation matrices based on vines and extended
    onion method`, Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
    """
    arg_constraints = {'concentration': constraints.positive}
    support = constraints.corr_cholesky

    def __init__(self, dim, concentration=1., validate_args=None):
        if dim < 2:
            raise ValueError(
                f'Expected dim to be an integer greater than or equal to 2. Found dim={dim}.')
        self.dim = dim
        self.concentration, = broadcast_all(concentration)
        batch_shape = self.concentration.size()
        event_shape = torch.Size((dim, dim))
        # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
        marginal_conc = self.concentration + 0.5 * (self.dim - 2)
        offset = torch.arange(self.dim - 1, dtype=self.concentration.dtype,
                              device=self.concentration.device)
        offset = torch.cat([offset.new_zeros((1,)), offset])
        beta_conc1 = offset + 0.5
        beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
        self._beta = Beta(beta_conc1, beta_conc0)
        super(LKJCholesky, self).__init__(batch_shape, event_shape, validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LKJCholesky, _instance)
        batch_shape = torch.Size(batch_shape)
        new.dim = self.dim
        new.concentration = self.concentration.expand(batch_shape)
        new._beta = self._beta.expand(batch_shape + (self.dim,))
        super(LKJCholesky, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
        # - This vectorizes the for loop and also works for heterogeneous eta.
        # - Same algorithm generalizes to n=1.
        # - The procedure is simplified since we are sampling the Cholesky factor of
        #   the correlation matrix instead of the correlation matrix itself. As such,
        #   we only need to generate `w`.
        y = self._beta.sample(sample_shape).unsqueeze(-1)
        u_normal = torch.randn(self._extended_shape(sample_shape),
                               dtype=y.dtype,
                               device=y.device).tril(-1)
        u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
        # Replace NaNs in first row
        u_hypersphere[..., 0, :].fill_(0.)
        w = torch.sqrt(y) * u_hypersphere
        # Fill diagonal elements; clamp for numerical stability
        eps = torch.finfo(w.dtype).tiny
        diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
        w += torch.diag_embed(diag_elems)
        return w

    def log_prob(self, value):
        # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
        # The probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # Additionally, the Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i
        if self._validate_args:
            self._validate_sample(value)
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        order = torch.arange(2, self.dim + 1)
        order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Compute normalization constant (page 1999 of [1])
        dm1 = self.dim - 1
        alpha = self.concentration + 0.5 * dm1
        denominator = torch.lgamma(alpha) * dm1
        numerator = torch.mvlgamma(alpha - 0.5, dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * dm1 * math.log(math.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized_log_pdf - normalize_term
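The following usage sketch is not part of the module; it illustrates the shapes produced by ``sample`` and ``log_prob`` and checks that ``L @ L.T`` behaves like a correlation matrix. The batch sizes and tolerances below are arbitrary choices for illustration.

    import torch
    from torch.distributions import LKJCholesky

    # Batched example: three concentrations, 4x4 matrices (values chosen arbitrarily).
    dist = LKJCholesky(4, concentration=torch.tensor([0.5, 1.0, 2.0]))
    L = dist.sample(torch.Size([1000]))            # shape: (1000, 3, 4, 4)
    corr = L @ L.transpose(-2, -1)                 # correlation matrices

    # Each sample is lower triangular, and its product has a unit diagonal.
    assert torch.allclose(L, L.tril())
    assert torch.allclose(corr.diagonal(dim1=-2, dim2=-1),
                          torch.ones(4), atol=1e-5)

    # log_prob reduces over the (4, 4) event shape, leaving the batch dimensions.
    print(dist.log_prob(L).shape)                  # torch.Size([1000, 3])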
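To make the row-by-row structure of the Onion construction in ``sample`` easier to follow, here is a non-vectorized restatement for a single matrix and a scalar ``concentration``. It simply replays, one row at a time, the Beta concentrations built in ``__init__`` (``offset + 0.5`` and ``marginal_conc - 0.5 * offset``); the function name is illustrative, and this is a sketch of the same construction rather than an alternative API.

    import torch
    from torch.distributions import Beta

    def onion_sample_single(dim, concentration):
        # Illustrative, non-vectorized replay of LKJCholesky.sample for one matrix.
        # Row i draws the squared norm y of its below-diagonal part from a Beta
        # distribution with the same concentrations as in __init__, spreads it
        # over a uniformly random direction, and sets the diagonal so the row
        # has unit norm.
        concentration = torch.as_tensor(concentration, dtype=torch.get_default_dtype())
        marginal_conc = concentration + 0.5 * (dim - 2)
        L = torch.zeros(dim, dim)
        L[0, 0] = 1.0
        for i in range(1, dim):
            offset = i - 1                      # as in __init__: offset = [0, 0, 1, ..., dim - 2]
            y = Beta(offset + 0.5, marginal_conc - 0.5 * offset).sample()
            u = torch.randn(i)
            u = u / u.norm()                    # uniform direction on the (i - 1)-sphere
            L[i, :i] = torch.sqrt(y) * u
            L[i, i] = torch.clamp(1 - y, min=torch.finfo(L.dtype).tiny).sqrt()
        return L

    print(onion_sample_single(3, 2.0))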
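As a small sanity check on ``log_prob`` (again an illustrative sketch, not part of the module): for ``dim == 2`` the only free entry is ``r = L[1, 0]``, and the formula in ``log_prob`` reduces to ``(concentration - 1) * log(1 - r**2)`` minus the normalizer ``0.5 * log(pi) + lgamma(concentration) - lgamma(concentration + 0.5)``. The concentration value and tolerances below are arbitrary.

    import math
    import torch
    from torch.distributions import LKJCholesky

    eta = 2.5                                   # arbitrary concentration for the check
    dist = LKJCholesky(2, eta)
    L = dist.sample()
    r = L[1, 0]                                 # the single free entry when dim == 2

    closed_form = ((eta - 1) * torch.log1p(-r ** 2)
                   + math.lgamma(eta + 0.5) - math.lgamma(eta)
                   - 0.5 * math.log(math.pi))
    print(torch.allclose(dist.log_prob(L), closed_form, rtol=1e-4, atol=1e-4))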
