torch.linalg.lu_solve
- torch.linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) → Tensor
Computes the solution of a square system of linear equations with a unique solution given an LU decomposition.
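Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, this function computes the solution $X \in \mathbb{K}^{n \times k}$ of the linear system associated to $A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}$, which is defined as

$$AX = B$$

where $A$ is given factorized as returned by lu_factor().

If left=False, this function returns the matrix $X \in \mathbb{K}^{n \times k}$ that solves the system

$$XA = B, \qquad A \in \mathbb{K}^{k \times k}, \; B \in \mathbb{K}^{n \times k}.$$

If adjoint=True (and left=True), given an LU factorization of $A$, this function returns the $X \in \mathbb{K}^{n \times k}$ that solves the system

$$A^{\text{H}}X = B, \qquad A \in \mathbb{K}^{n \times n}, \; B \in \mathbb{K}^{n \times k},$$

where $A^{\text{H}}$ is the conjugate transpose when $A$ is complex, and the transpose when $A$ is real-valued. The left=False case is analogous: with adjoint=True it solves $XA^{\text{H}} = B$.

Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if the inputs are batches of matrices then the output has the same batch dimensions.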
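The solve above is driven by the factorization $PA = LU$ that lu_factor() produces. As a rough, dependency-free illustration of what such a solve involves, the sketch below factors a small real matrix with partial pivoting, then applies the row interchanges, a forward substitution with the unit lower-triangular $L$ and a back substitution with $U$; its adjoint branch solves $A^{\text{T}}x = b$ for the real case. The helper names lu_factor_py and lu_solve_py are hypothetical, pivots here are 0-based (PyTorch returns LAPACK-style 1-based pivots), only one right-hand side vector is handled, and this is not PyTorch's actual implementation.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def lu_factor_py(A: Matrix) -> Tuple[Matrix, List[int]]:
    """Hypothetical helper: LU with partial pivoting, packed in one matrix.

    U sits on and above the diagonal, the multipliers of the unit
    lower-triangular L sit strictly below it, and piv[j] is the
    (0-based) row swapped with row j at step j, so that P A = L U.
    """
    n = len(A)
    LU = [row[:] for row in A]  # work on a copy of A
    piv = list(range(n))
    for j in range(n):
        # Pick the row with the largest |entry| in column j as the pivot.
        p = max(range(j, n), key=lambda i: abs(LU[i][j]))
        if LU[p][j] == 0.0:
            raise ValueError("matrix is singular")
        piv[j] = p
        LU[j], LU[p] = LU[p], LU[j]  # swap whole rows, including stored L
        for i in range(j + 1, n):
            LU[i][j] /= LU[j][j]  # multiplier, stored in the L part
            for c in range(j + 1, n):
                LU[i][c] -= LU[i][j] * LU[j][c]
    return LU, piv


def lu_solve_py(LU: Matrix, piv: List[int], b: List[float],
                adjoint: bool = False) -> List[float]:
    """Hypothetical helper: solve A x = b (or A^T x = b) from lu_factor_py."""
    n = len(LU)
    x = list(b)
    if not adjoint:
        # A x = b  <=>  L U x = P b
        for i, p in enumerate(piv):  # x <- P b, swaps in factorization order
            x[i], x[p] = x[p], x[i]
        for i in range(n):  # forward substitution: L y = P b (unit diagonal)
            for j in range(i):
                x[i] -= LU[i][j] * x[j]
        for i in reversed(range(n)):  # back substitution: U x = y
            for j in range(i + 1, n):
                x[i] -= LU[i][j] * x[j]
            x[i] /= LU[i][i]
    else:
        # A^T = U^T L^T P, so solve U^T w = b, then L^T z = w, then x = P^T z.
        for i in range(n):  # U^T is lower triangular
            for j in range(i):
                x[i] -= LU[j][i] * x[j]
            x[i] /= LU[i][i]
        for i in reversed(range(n)):  # L^T is unit upper triangular
            for j in range(i + 1, n):
                x[i] -= LU[j][i] * x[j]
        for i in reversed(range(n)):  # P^T: undo the swaps in reverse order
            p = piv[i]
            x[i], x[p] = x[p], x[i]
    return x


A = [[2.0, 1.0, 1.0], [4.0, -6.0, 0.0], [-2.0, 7.0, 2.0]]
LU, piv = lu_factor_py(A)
print(lu_solve_py(LU, piv, [5.0, -2.0, 9.0]))  # [1.0, 1.0, 2.0]
```

Factoring once and reusing LU and piv for each new right-hand side is the same split that lu_factor() and lu_solve() expose: the $O(n^3)$ elimination is paid once, and every subsequent solve costs only $O(n^2)$ substitutions.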
- Parameters
LU (Tensor) – tensor of shape (*, n, n) (or (*, k, k) if left=False) where * is zero or more batch dimensions as returned by lu_factor().
pivots (Tensor) – tensor of shape (*, n) (or (*, k) if left=False) where * is zero or more batch dimensions as returned by lu_factor().
B (Tensor) – right-hand side tensor of shape (*, n, k).
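To make the shape conventions concrete, the sketch below solves a batch of two independent 3×3 systems with two right-hand sides each, so * = (2,), n = 3 and k = 2. The float64 dtype, the sizes and the seed are arbitrary choices for illustration; B is given the same batch shape as LU, so no broadcasting is involved.

```python
import torch

torch.manual_seed(0)

# A batch of 2 systems: A has shape (*, n, n), B has shape (*, n, k).
A = torch.randn(2, 3, 3, dtype=torch.float64)
B = torch.randn(2, 3, 2, dtype=torch.float64)

LU, pivots = torch.linalg.lu_factor(A)
print(LU.shape, pivots.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 3])

X = torch.linalg.lu_solve(LU, pivots, B)
print(X.shape)  # torch.Size([2, 3, 2]): same batch dimensions as the inputs

# Each batch element solves its own system A[i] @ X[i] = B[i].
print(torch.allclose(A @ X, B))
```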
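- Keyword Arguments
left (bool, optional) – whether to solve the system $AX = B$ or $XA = B$. Default: True.
adjoint (bool, optional) – whether to solve the system $AX = B$ or $A^{\text{H}}X = B$. Default: False.
out (Tensor, optional) – output tensor. Ignored if None. Default: None.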
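The two keyword arguments are related: $XA = B$ is equivalent to $A^{\text{H}}X^{\text{H}} = B^{\text{H}}$, so a left=False solve can be cross-checked against an adjoint=True solve of the conjugate-transposed right-hand side. The sketch below, assuming a PyTorch version that provides torch.linalg.lu_solve and the Tensor.mH attribute, does this for a random complex matrix, which also exercises the conjugate (not plain) transpose; the sizes and seed are arbitrary.

```python
import torch

torch.manual_seed(0)

# X A = B with A of shape (k, k) and B of shape (n, k), in complex double.
k, n = 4, 3
A = torch.randn(k, k, dtype=torch.complex128)
B = torch.randn(n, k, dtype=torch.complex128)
LU, pivots = torch.linalg.lu_factor(A)

# Direct route: left=False solves X A = B.
X_right = torch.linalg.lu_solve(LU, pivots, B, left=False)

# Same system rewritten as A^H X^H = B^H, solved with adjoint=True,
# then conjugate-transposed back to recover X.
X_via_adjoint = torch.linalg.lu_solve(LU, pivots, B.mH, adjoint=True).mH

print(torch.allclose(X_right @ A, B))
print(torch.allclose(X_right, X_via_adjoint))
```

For real inputs the same identity holds with .mT in place of .mH, since the adjoint reduces to the plain transpose.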
Examples:
>>> import torch
>>> A = torch.randn(3, 3)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> B = torch.randn(3, 2)
>>> X = torch.linalg.lu_solve(LU, pivots, B)
>>> torch.allclose(A @ X, B)
True

>>> B = torch.randn(3, 3, 2)   # Broadcasting rules apply: A is broadcasted
>>> X = torch.linalg.lu_solve(LU, pivots, B)
>>> torch.allclose(A @ X, B)
True

>>> B = torch.randn(3, 5, 3)
>>> X = torch.linalg.lu_solve(LU, pivots, B, left=False)
>>> torch.allclose(X @ A, B)
True

>>> B = torch.randn(3, 3, 4)   # Now solve for A^T
>>> X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
>>> torch.allclose(A.mT @ X, B)
True