"""Implement various linear algebra algorithms for low rank matrices."""__all__=["svd_lowrank","pca_lowrank"]fromtypingimportOptional,Tupleimporttorchfromtorchimport_linalg_utilsas_utils,Tensorfromtorch.overridesimporthandle_torch_function,has_torch_functiondefget_approximate_basis(A:Tensor,q:int,niter:Optional[int]=2,M:Optional[Tensor]=None,)->Tensor:"""Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` approximates :math:`A - M`. without instantiating any tensors of the size of :math:`A` or :math:`M`. .. note:: The implementation is based on the Algorithm 4.4 from Halko et al., 2009. .. note:: For an adequate approximation of a k-rank matrix :math:`A`, where k is not known in advance but could be estimated, the number of :math:`Q` columns, q, can be choosen according to the following criteria: in general, :math:`k <= q <= min(2*k, m, n)`. For large low-rank matrices, take :math:`q = k + 5..10`. If k is relatively small compared to :math:`min(m, n)`, choosing :math:`q = k + 0..2` may be sufficient. .. note:: To obtain repeatable results, reset the seed for the pseudorandom number generator Args:: A (Tensor): the input tensor of size :math:`(*, m, n)` q (int): the dimension of subspace spanned by :math:`Q` columns. niter (int, optional): the number of subspace iterations to conduct; ``niter`` must be a nonnegative integer. In most cases, the default value 2 is more than enough. M (Tensor, optional): the input tensor's mean of size :math:`(*, m, n)`. References:: - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding structure with randomness: probabilistic algorithms for constructing approximate matrix decompositions, arXiv:0909.4061 [math.NA; math.PR], 2009 (available at `arXiv <http://arxiv.org/abs/0909.4061>`_). 
"""niter=2ifniterisNoneelseniterdtype=_utils.get_floating_dtype(A)ifnotA.is_complex()elseA.dtypematmul=_utils.matmulR=torch.randn(A.shape[-1],q,dtype=dtype,device=A.device)# The following code could be made faster using torch.geqrf + torch.ormqr# but geqrf is not differentiableX=matmul(A,R)ifMisnotNone:X=X-matmul(M,R)Q=torch.linalg.qr(X).Qforiinrange(niter):X=matmul(A.mH,Q)ifMisnotNone:X=X-matmul(M.mH,Q)Q=torch.linalg.qr(X).QX=matmul(A,Q)ifMisnotNone:X=X-matmul(M,Q)Q=torch.linalg.qr(X).QreturnQ
def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
    SVD is computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: the triple ``(U, S, V)``.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """
    if not torch.jit.is_scripting():
        tensor_ops = (A, M)
        # Dispatch through __torch_function__ only when an argument is
        # something other than a plain torch.Tensor or None and actually
        # provides an override.
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)
def _svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute the low-rank SVD ``(U, S, V)`` of ``A`` (or of ``A - M``).

    This is the implementation behind :func:`svd_lowrank`; it performs no
    ``__torch_function__`` dispatch.

    Args:
        A (Tensor): the input tensor; its last two dimensions are read
                    as ``(m, n)``.

        q (int, optional): number of basis columns requested from
                           :func:`get_approximate_basis`; ``None`` is
                           treated as 6.

        niter (int, optional): number of subspace iterations, forwarded
                               to :func:`get_approximate_basis`.

        M (Tensor, optional): tensor subtracted from ``A``; it is
                              broadcast to ``A.size()`` first.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: the triple ``(U, S, V)``.
    """
    # Algorithm 5.1 in Halko et al., 2009

    q = 6 if q is None else q
    m, n = A.shape[-2:]
    matmul = _utils.matmul
    if M is not None:
        M = M.broadcast_to(A.size())

    # Assume that A is tall
    if m < n:
        # Work on the conjugate transpose; the factors are swapped back
        # before returning.
        A = A.mH
        if M is not None:
            M = M.mH

    Q = get_approximate_basis(A, q, niter=niter, M=M)
    # Project onto the basis: B = Q^H (A - M), computed term by term so
    # that A - M is never formed explicitly.
    B = matmul(Q.mH, A)
    if M is not None:
        B = B - matmul(Q.mH, M)
    U, S, Vh = torch.linalg.svd(B, full_matrices=False)
    V = Vh.mH
    # Map the left singular vectors of B back through Q.
    U = Q.matmul(U)

    if m < n:
        # Undo the earlier transposition: the roles of U and V are swapped.
        U, V = V, U

    return U, S, V
def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depends on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m,
                           n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: the triple ``(U, S, V)``.

    Raises:
        ValueError: if ``q`` is negative or greater than ``min(m, n)``,
                    if ``niter`` is negative, or if ``center`` is True
                    and ``A`` is a sparse tensor that is not
                    2-dimensional.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).

    """
    if not torch.jit.is_scripting():
        # Dispatch through __torch_function__ only when A is something
        # other than a plain torch.Tensor and provides an override.
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # The floating dtype is only needed to build the mean tensors below.
        dtype = _utils.get_floating_dtype(A)
        # Column means of A, as a sparse vector.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # reshape c into a sparse (n, 1) column: the first index row holds the
        # original column indices, the second index row stays zero.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # (n, 1) @ (1, m), transposed: every row of M holds the column means.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        # A itself is left sparse; the mean is passed separately as M.
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.