Shortcuts

Source code for torch.distributions.lowrank_multivariate_normal

# mypy: allow-untyped-defs
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ["LowRankMultivariateNormal"]


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    """
    m = W.size(-1)
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, :: m + 1] += 1  # add identity matrix to K
    return torch.linalg.cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
        -1
    )


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        """
        Validate the parameter shapes, broadcast them to a common batch shape and
        precompute the Cholesky factor of the capacitance matrix.

        Raises:
            ValueError: if ``loc`` has fewer than one dimension, ``cov_factor`` has
                fewer than two, the event dimensions of ``cov_factor`` or
                ``cov_diag`` disagree with ``loc``, or the batch shapes cannot be
                broadcast together.
        """
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        # Give loc and cov_diag a trailing singleton dimension so that they line up
        # with cov_factor's trailing (rank) dimension when broadcasting.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        # Drop the helper dimension again (cov_factor keeps its rank dimension).
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        # The derived quantities below are computed from the parameters as passed
        # in (before broadcasting) and only expanded to the batch shape on return.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new distribution whose parameters are expanded to ``batch_shape``.

        The unbroadcasted parameters and the cached capacitance factor are shared
        with ``self`` rather than recomputed, and argument validation is skipped
        when initializing the new instance.
        """
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean of the distribution, i.e. ``loc``."""
        return self.loc

    @property
    def mode(self):
        """Mode of the distribution, i.e. ``loc``."""
        return self.loc

    @lazy_property
    def variance(self):
        """Diagonal of the covariance matrix, expanded to ``batch_shape + event_shape``."""
        # diag(W @ W.T + D) = row-wise sum of squares of W, plus D.
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1)
            + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        """Lower Cholesky factor of the covariance matrix, expanded to the batch shape."""
        # The following identity is used to increase the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Dense covariance ``cov_factor @ cov_factor.T + diag(cov_diag)``, expanded to the batch shape."""
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self):
        """Inverse of the covariance matrix, expanded to the batch shape."""
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        # With C = L @ L.T and A = inv(L) @ W.T @ inv(D), the correction term
        # inv(D) @ W @ inv(C) @ W.T @ inv(D) equals A.T @ A.
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparameterized samples of shape ``sample_shape + batch_shape + event_shape``.

        Samples are formed as ``loc + W @ eps_W + sqrt(D) * eps_D`` from two
        independent standard-normal draws, so no factorization of the full
        covariance matrix is needed.
        """
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        """
        Return the log density evaluated at ``value``.

        The Mahalanobis term and the log-determinant are both computed from the
        cached capacitance Cholesky factor.
        """
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        """Return the entropy of the distribution, expanded to ``batch_shape`` when batched."""
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources