import math
from numbers import Number
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
def _get_batch_shape(bmat, bvec):
r"""
Given a batch of matrices and a batch of vectors, compute the combined `batch_shape`.
"""
try:
vec_shape = torch._C._infer_size(bvec.shape, bmat.shape[:-1])
except RuntimeError:
raise ValueError("Incompatible batch shapes: vector {}, matrix {}".format(bvec.shape, bmat.shape))
return torch.Size(vec_shape[:-1])
def _batch_mv(bmat, bvec):
    r"""
    Batched matrix-vector product for batches whose shapes broadcast but may differ.

    `bmat` holds :math:`n \times n` matrices and `bvec` holds length :math:`n`
    vectors; each may carry any number of leading batch dimensions, which only
    need to be broadcast-compatible (not identical).

    Returns a tensor of shape ``batch_shape + (n,)``, where `batch_shape` is the
    broadcast of the two inputs' batch shapes.
    """
    size = bvec.size(-1)
    shape = _get_batch_shape(bmat, bvec)
    # `torch.bmm` only accepts 3-D operands, so broadcast both inputs to the
    # common batch shape and then flatten every batch dimension into one.
    mats = bmat.expand(shape + (size, size)).reshape((-1, size, size))
    cols = bvec.unsqueeze(-1).expand(shape + (size, 1)).reshape((-1, size, 1))
    product = torch.bmm(mats, cols)
    return product.view(shape + (size,))
def _batch_potrf_lower(bmat):
    r"""
    Cholesky-factor every matrix in a batch of arbitrary shape.

    Each trailing ``(n, n)`` matrix ``C`` is factored as ``C = L L^T`` with ``L``
    lower-triangular (``potrf(upper=False)``). The result has the same shape
    as `bmat`.

    NOTE(review): `Tensor.potrf` only exists in older PyTorch releases; it was
    later superseded by `torch.cholesky`.
    """
    size = bmat.size(-1)
    flat_mats = bmat.reshape((-1, size, size))
    lower_factors = [mat.potrf(upper=False) for mat in flat_mats]
    stacked = torch.stack(lower_factors)
    return stacked.view(bmat.shape)
def _batch_diag(bmat):
r"""
Returns the diagonals of a batch of square matrices.
"""
return bmat.reshape(bmat.shape[:-2] + (-1,))[..., ::bmat.size(-1) + 1]
def _batch_inverse(bmat):
r"""
Returns the inverses of a batch of square matrices.
"""
n = bmat.size(-1)
flat_bmat = bmat.reshape(-1, n, n)
flat_inv_bmat = torch.stack([m.inverse() for m in flat_bmat], 0)
return flat_inv_bmat.view(bmat.shape)
def _batch_mahalanobis(L, x):
r"""
Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
Accepts batches for both L and x.
"""
# TODO: use `torch.potrs` or similar once a backwards pass is implemented.
flat_L = L.unsqueeze(0).reshape((-1,) + L.shape[-2:])
L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(L.shape)
return (x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1)
class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        -0.2102
        -0.5429
        [torch.FloatTensor of size 2]

    Args:
        loc (Tensor): mean of the distribution. Its last dimension is the
            event size ``n``; any leading dimensions are batch dimensions.
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    The matrix argument must have trailing shape ``(n, n)``. Its leading batch
    dimensions are broadcast against those of :attr:`loc`, and :attr:`loc` is
    expanded to the resulting ``batch_shape + (n,)``.

    Raises:
        ValueError: if not exactly one matrix argument is given, if ``loc`` is
            zero-dimensional, if the matrix argument is not a (batch of)
            ``n x n`` matrices, or if the batch shapes cannot be broadcast.

    Note:
        Exactly one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` must be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition
        (a :attr:`precision_matrix` is inverted first).
    """
    arg_constraints = {'loc': constraints.real_vector,
                       'covariance_matrix': constraints.positive_definite,
                       'precision_matrix': constraints.positive_definite,
                       'scale_tril': constraints.lower_cholesky}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
        if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
            raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
        # A 0-dim loc would yield an empty event_shape and a broken distribution.
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = torch.Size(loc.shape[-1:])
        if scale_tril is not None:
            self._check_matrix_arg("scale_tril matrix", scale_tril, event_shape)
            self.scale_tril = scale_tril
            batch_shape = _get_batch_shape(scale_tril, loc)
        elif covariance_matrix is not None:
            self._check_matrix_arg("covariance_matrix", covariance_matrix, event_shape)
            self.covariance_matrix = covariance_matrix
            batch_shape = _get_batch_shape(covariance_matrix, loc)
        else:
            self._check_matrix_arg("precision_matrix", precision_matrix, event_shape)
            self.precision_matrix = precision_matrix
            self.covariance_matrix = _batch_inverse(precision_matrix)
            batch_shape = _get_batch_shape(precision_matrix, loc)
        # Broadcast loc to the full batch so `mean` has shape
        # batch_shape + event_shape even when only the matrix is batched.
        self.loc = loc.expand(batch_shape + event_shape)
        super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    @staticmethod
    def _check_matrix_arg(name, matrix, event_shape):
        """Raise ValueError unless `matrix` is a (batch of) square matrix matching `event_shape`."""
        if matrix.dim() < 2:
            raise ValueError("{} must be at least two-dimensional, "
                             "with optional leading batch dimensions".format(name))
        n = event_shape[0]
        # Without this check a size-1 loc would silently broadcast against an
        # n x n matrix and produce a wrong event_shape.
        if matrix.shape[-2:] != (n, n):
            raise ValueError("{} must have trailing shape ({}, {}) to match the last "
                             "dimension of loc, but got shape {}".format(name, n, n, tuple(matrix.shape)))

    @lazy_property
    def scale_tril(self):
        """Lower-triangular Cholesky factor of `covariance_matrix`."""
        return _batch_potrf_lower(self.covariance_matrix)

    @lazy_property
    def covariance_matrix(self):
        """Covariance rebuilt from `scale_tril` as L L^T."""
        return torch.matmul(self.scale_tril, self.scale_tril.transpose(-1, -2))

    @lazy_property
    def precision_matrix(self):
        """Inverse covariance, computed as L^{-T} L^{-1} from `scale_tril`."""
        # TODO: use `torch.potri` on `scale_tril` once a backwards pass is implemented.
        scale_tril_inv = _batch_inverse(self.scale_tril)
        return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv)

    @property
    def mean(self):
        return self.loc

    @property
    def variance(self):
        """Diagonal of the covariance, broadcast to batch_shape + event_shape."""
        # `_batch_diag` reshapes (copying if needed), so it also works for
        # non-contiguous batched covariance matrices, where `.view` would fail.
        return _batch_diag(self.covariance_matrix).expand(self._batch_shape + self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized sample: loc + L @ eps with eps drawn from a standard normal."""
        shape = self._extended_shape(sample_shape)
        eps = self.loc.new(*shape).normal_()
        return self.loc + _batch_mv(self.scale_tril, eps)

    def log_prob(self, value):
        """Log density: -0.5 * (Mahalanobis^2 + n log 2pi) - log|det L|."""
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self.scale_tril, diff)
        # L is triangular, so log|det L| is the sum of log|diagonal entries|.
        log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)
        return -0.5 * (M + self._event_shape[0] * math.log(2 * math.pi)) - log_det

    def entropy(self):
        """Differential entropy: 0.5 * n * (1 + log 2pi) + log|det L|, per batch element."""
        log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)
        H = 0.5 * (1.0 + math.log(2 * math.pi)) * self._event_shape[0] + log_det
        if len(self._batch_shape) == 0:
            return H
        return H.expand(self._batch_shape)