# Source code for torch._lowrank
"""Implement various linear algebra algorithms for low rank matrices."""
__all__ = ["svd_lowrank", "pca_lowrank"]
from typing import Optional, Tuple
import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function
def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`. The difference :math:`A - M` is never
    formed explicitly; only the products of :math:`A` and :math:`M`
    with :math:`q`-column matrices are computed and subtracted.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable
              results, reset the seed for the pseudorandom number
              generator.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. ``None`` is
                               treated as the default value 2, which
                               in most cases is more than enough.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    Returns::
        Tensor: the matrix :math:`Q` taken from the QR decomposition
                of the last iterate.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
    """
    niter = 2 if niter is None else niter
    # Complex inputs keep their dtype; otherwise defer to the helper.
    dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
    matmul = _utils.matmul

    # Random test matrix with q columns used to sample the range of A - M.
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # The adjoints do not change between iterations, so build them once
    # instead of once per loop pass.
    A_H = A.mH
    M_H = None if M is None else M.mH

    # The following code could be made faster using torch.geqrf + torch.ormqr
    # but geqrf is not differentiable

    X = matmul(A, R)
    if M is not None:
        X = X - matmul(M, R)
    Q = torch.linalg.qr(X).Q
    for _ in range(niter):
        # Half step with the adjoint: (A - M)^H Q, then re-orthonormalize.
        X = matmul(A_H, Q)
        if M_H is not None:
            X = X - matmul(M_H, Q)
        Q = torch.linalg.qr(X).Q
        # Half step with the matrix itself: (A - M) Q, then re-orthonormalize.
        X = matmul(A, Q)
        if M is not None:
            X = X - matmul(M, Q)
        Q = torch.linalg.qr(X).Q
    return Q
[docs]def svd_lowrank(
A: Tensor,
q: Optional[int] = 6,
niter: Optional[int] = 2,
M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
batches of matrices, or a sparse matrix :math:`A` such that
:math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
SVD is computed for the matrix :math:`A - M`.
.. note:: The implementation is based on the Algorithm 5.1 from
Halko et al., 2009.
.. note:: For an adequate approximation of a k-rank matrix
:math:`A`, where k is not known in advance but could be
estimated, the number of :math:`Q` columns, q, can be
choosen according to the following criteria: in general,
:math:`k <= q <= min(2*k, m, n)`. For large low-rank
matrices, take :math:`q = k + 5..10`. If k is
relatively small compared to :math:`min(m, n)`, choosing
:math:`q = k + 0..2` may be sufficient.
.. note:: This is a randomized method. To obtain repeatable results,
set the seed for the pseudorandom number generator
.. note:: In general, use the full-rank SVD implementation
:func:`torch.linalg.svd` for dense matrices due to its 10x
higher performance characteristics. The low-rank SVD
will be useful for huge sparse matrices that
:func:`torch.linalg.svd` cannot handle.
Args::
A (Tensor): the input tensor of size :math:`(*, m, n)`
q (int, optional): a slightly overestimated rank of A.
niter (int, optional): the number of subspace iterations to
conduct; niter must be a nonnegative
integer, and defaults to 2
M (Tensor, optional): the input tensor's mean of size
:math:`(*, m, n)`, which will be broadcasted
to the size of A in this function.
References::
- Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
structure with randomness: probabilistic algorithms for
constructing approximate matrix decompositions,
arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
`arXiv <https://arxiv.org/abs/0909.4061>`_).
"""
if not torch.jit.is_scripting():
tensor_ops = (A, M)
if not set(map(type, tensor_ops)).issubset(
(torch.Tensor, type(None))
) and has_torch_function(tensor_ops):
return handle_torch_function(
svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
)
return _svd_lowrank(A, q=q, niter=niter, M=M)
def _svd_lowrank(
A: Tensor,
q: Optional[int] = 6,
niter: Optional[int] = 2,
M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
# Algorithm 5.1 in Halko et al., 2009
q = 6 if q is None else q
m, n = A.shape[-2:]
matmul = _utils.matmul
if M is not None:
M = M.broadcast_to(A.size())
# Assume that A is tall
if m < n:
A = A.mH
if M is not None:
M = M.mH
Q = get_approximate_basis(A, q, niter=niter, M=M)
B = matmul(Q.mH, A)
if M is not None:
B = B - matmul(Q.mH, M)
U, S, Vh = torch.linalg.svd(B, full_matrices=False)
V = Vh.mH
U = Q.matmul(U)
if m < n:
U, V = V, U
return U, S, V
def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depend on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m,
                           n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Raises:
        ValueError: if ``q`` is negative or greater than ``min(m, n)``,
            if ``niter`` is negative, or if ``center`` is True and
            ``A`` is a sparse tensor that is not 2-dimensional.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """

    if not torch.jit.is_scripting():
        # Dispatch through __torch_function__ for anything that is not
        # an exact torch.Tensor.
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Only the sparse branch builds new tensors, so the dtype is
        # resolved here rather than on every code path.
        dtype = _utils.get_floating_dtype(A)
        # Column means of A, as a 1-dimensional sparse tensor.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # Reshape c into an (n, 1) sparse column: row 0 of the new index
        # matrix holds the column indices of c, row 1 stays zero.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # (n, 1) @ (1, m) then transpose: an (m, n) matrix whose every
        # row holds the column means. It is passed as M so that A itself
        # is not modified by the centering.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        # Strided input: subtract the column means directly.
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)