Shortcuts

# Source code for torch.distributions.lowrank_multivariate_normal

import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

def _batch_capacitance_tril(W, D):
r"""
Computes Cholesky of :math:I + W.T @ inv(D) @ W for a batch of matrices :math:W
and a batch of vectors :math:D.
"""
m = W.size(-1)
Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
K = torch.matmul(Wt_Dinv, W).contiguous()
K.view(-1, m * m)[:, ::m + 1] += 1  # add identity matrix to K

def _batch_lowrank_logdet(W, D, capacitance_tril):
r"""
Uses "matrix determinant lemma"::
log|W @ W.T + D| = log|C| + log|D|,
where :math:C is the capacitance matrix :math:I + W.T @ inv(D) @ W, to compute
the log determinant.
"""
return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)

def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
r"""
Uses "Woodbury matrix identity"::
inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
where :math:C is the capacitance matrix :math:I + W.T @ inv(D) @ W, to compute the squared
Mahalanobis distance :math:x.T @ inv(W @ W.T + D) @ x.
"""
Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
mahalanobis_term1 = (x.pow(2) / D).sum(-1)
mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
return mahalanobis_term1 - mahalanobis_term2

class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:

        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[-1] << cov_factor.shape[-2]` thanks to the
        `Woodbury matrix identity <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_
        and the `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {"loc": constraints.real,
                       "cov_factor": constraints.real,
                       "cov_diag": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
                             .format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))

        # Treat loc and cov_diag as column vectors so they broadcast against the
        # (n, rank) cov_factor matrices and all three end up with one batch shape.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
        except RuntimeError as e:
            raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
                             .format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        # Keep the original (un-broadcast) parameters: every derived quantity below is
        # computed on these and only expanded to the full batch shape at the end.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
                                                        validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Returns a new instance whose parameters are expanded to ``batch_shape``.

        The un-broadcast parameters and the capacitance Cholesky factor are shared
        (not copied) with ``self``.
        """
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @lazy_property
    def variance(self):
        # diag(W @ W.T + D)[i] = sum_j W[i, j] ** 2 + D[i]
        return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
                + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to increase the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2))
        # Add the identity matrix to K (K is a fresh matmul result, safe to modify).
        K.diagonal(dim1=-2, dim2=-1).add_(1)
        scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
        return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        # Dense W @ W.T + diag(D).
        covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
                                          self._unbroadcasted_cov_factor.transpose(-1, -2))
                             + torch.diag_embed(self._unbroadcasted_cov_diag))
        return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                        self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix. With C = L @ L.T and
        # A = inv(L) @ W.T @ inv(D), the subtracted term equals A.T @ A.
        Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
                   / self._unbroadcasted_cov_diag.unsqueeze(-2))
        A = torch.triangular_solve(Wt_Dinv, self._capacitance_tril, upper=False)[0]
        precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
                            - torch.matmul(A.transpose(-1, -2), A))
        return precision_matrix.expand(self._batch_shape + self._event_shape +
                                       self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draws a reparameterized sample ``loc + W @ eps_W + sqrt(D) * eps_D``.

        ``eps_W`` and ``eps_D`` are independent standard-normal draws; their sum has
        covariance ``W @ W.T + diag(D)``.
        """
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
                + self._unbroadcasted_cov_diag.sqrt() * eps_D)

    def log_prob(self, value):
        """Log density ``-0.5 * (n * log(2*pi) + log|Sigma| + Mahalanobis(value - loc))``."""
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
                                       self._unbroadcasted_cov_diag,
                                       diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        """Differential entropy ``0.5 * (n * (1 + log(2*pi)) + log|Sigma|)``, per batch element."""
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)


## Docs

Access comprehensive developer documentation for PyTorch

View Docs

## Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

## Resources

Find development resources and get your questions answered

View Resources