torch.linalg.lu
- torch.linalg.lu(A, *, pivot=True, out=None)
Computes the LU decomposition with partial pivoting of a matrix.
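Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, the LU decomposition with partial pivoting of a matrix $A \in \mathbb{K}^{m \times n}$ is defined as

$$A = PLU \qquad P \in \mathbb{K}^{m \times m},\ L \in \mathbb{K}^{m \times k},\ U \in \mathbb{K}^{k \times n}$$

where $k = \min(m, n)$, $P$ is a permutation matrix, $L$ is lower triangular with ones on the diagonal and $U$ is upper triangular.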
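To make the definition concrete, here is a minimal pure-Python sketch of Gaussian elimination with partial pivoting on nested lists. It is an illustration only, not how PyTorch computes the factorization. The names lu_sketch and matmul are hypothetical, and the pivot flag merely imitates the keyword argument of the same name: with pivot=False no rows are swapped and P is returned empty, and a zero pivot raises ZeroDivisionError (where, as noted for pivot=False, PyTorch may instead return inf or NaN).

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(X: Matrix, Y: Matrix) -> Matrix:
    """Plain matrix product of two nested-list matrices."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def lu_sketch(A: Matrix, pivot: bool = True) -> Tuple[Matrix, Matrix, Matrix]:
    """Return (P, L, U) with A == P @ L @ U (up to rounding) for an m x n matrix A.

    P is m x m, L is m x k with unit diagonal, U is k x n, where k = min(m, n).
    With pivot=False no rows are swapped and P is returned as an empty list.
    """
    m, n = len(A), len(A[0])
    k = min(m, n)
    U = [[float(v) for v in row] for row in A]  # working copy, reduced in place
    L = [[0.0] * k for _ in range(m)]
    perm = list(range(m))  # perm[i] = index of the row of A now stored in row i
    for j in range(k):
        if pivot:
            # Partial pivoting: bring the entry of largest magnitude in column j to row j.
            p = max(range(j, m), key=lambda i: abs(U[i][j]))
            U[j], U[p] = U[p], U[j]
            L[j], L[p] = L[p], L[j]
            perm[j], perm[p] = perm[p], perm[j]
        L[j][j] = 1.0
        for i in range(j + 1, m):
            if U[i][j] == 0.0:
                continue  # nothing to eliminate; L[i][j] stays 0
            # Without pivoting, U[j][j] may be 0 here and this raises ZeroDivisionError.
            factor = U[i][j] / U[j][j]
            L[i][j] = factor
            U[i][j] = 0.0
            for c in range(j + 1, n):
                U[i][c] -= factor * U[j][c]
    U = U[:k]  # rows k..m-1 (if any) are all zero after elimination
    if not pivot:
        return [], L, U
    # Row i of L @ U equals row perm[i] of A, so P sends row i to row perm[i].
    P = [[0.0] * m for _ in range(m)]
    for i, r in enumerate(perm):
        P[r][i] = 1.0
    return P, L, U
```

For A = [[1, 2], [3, 4], [5, 6]] the sketch picks 5 as the first pivot, so P is not the identity; P, L and U come out as 3×3, 3×2 and 2×2, matching k = min(3, 2) = 2.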
If pivot=False and A is on GPU, then the LU decomposition without pivoting is computed:

$$A = LU \qquad L \in \mathbb{K}^{m \times k},\ U \in \mathbb{K}^{k \times n}$$

When pivot=False, the returned matrix P will be empty. The LU decomposition without pivoting may not exist if any leading principal submatrix of A is singular. In this case, the output matrix may contain inf or NaN.

Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions.

See also
torch.linalg.solve() solves a system of linear equations using the LU decomposition with partial pivoting.

Warning
The LU decomposition is almost never unique, as often there are different permutation matrices that can yield different LU decompositions. As such, different platforms, like SciPy, or inputs on different devices, may produce different valid decompositions.
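The non-uniqueness is easy to check by hand. In the 2×2 integer example below (chosen for illustration, not output from PyTorch or SciPy), both candidate pivots in the first column have magnitude 1, so partial pivoting alone does not decide whether to swap the rows, and both resulting triples satisfy A = PLU. The matmul helper is a hypothetical name for a plain nested-list product.

```python
from typing import List

Matrix = List[List[int]]


def matmul(X: Matrix, Y: Matrix) -> Matrix:
    """Plain matrix product of two nested-list matrices."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


A = [[1, 2],
     [1, 3]]

# Both entries in column 0 have magnitude 1, so partial pivoting may keep row 0 ...
P1 = [[1, 0], [0, 1]]
L1 = [[1, 0], [1, 1]]
U1 = [[1, 2], [0, 1]]

# ... or swap the two rows first.
P2 = [[0, 1], [1, 0]]
L2 = [[1, 0], [1, 1]]
U2 = [[1, 3], [0, -1]]

for P, L, U in [(P1, L1, U1), (P2, L2, U2)]:
    assert matmul(matmul(P, L), U) == A  # each triple is a valid decomposition

print(U1 != U2)  # True: same A, different factors
```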
Warning
Gradient computations are only supported if the input matrix is full-rank. If this condition is not met, no error will be thrown, but the gradient may not be finite. This is because the LU decomposition with pivoting is not differentiable at these points.
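- Parameters
A (Tensor) – tensor of shape (*, m, n) where * is zero or more batch dimensions.
- Keyword Arguments
pivot (bool, optional) – Controls whether to compute the LU decomposition with partial pivoting or no pivoting. Default: True.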
out (tuple, optional) – output tuple of three tensors. Ignored if None. Default: None.
- Returns
A named tuple (P, L, U).
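The shapes of the returned tensors follow from the definition: for A of shape (*, m, n) and k = min(m, n), P has shape (*, m, m), L has shape (*, m, k) and U has shape (*, k, n). The hypothetical helper below only encodes that bookkeeping. The shape (0,) it reports for P when pivot=False is an assumption, inferred from P printing as tensor([]) in the examples.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def lu_result_shapes(a_shape: Shape, pivot: bool = True) -> Tuple[Shape, Shape, Shape]:
    """Hypothetical helper: shapes of (P, L, U) for an input of shape (*, m, n)."""
    *batch_dims, m, n = a_shape  # needs at least two dimensions
    batch = tuple(batch_dims)
    k = min(m, n)
    # Assumption: with pivot=False, P is an empty 1-D tensor of shape (0,).
    p_shape = batch + (m, m) if pivot else (0,)
    return p_shape, batch + (m, k), batch + (k, n)
```

For example, lu_result_shapes((2, 5, 7), pivot=False) gives ((0,), (2, 5, 5), (2, 5, 7)), so L @ U has the same shape as A.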
Examples:
```python
>>> import torch
>>> A = torch.randn(3, 2)
>>> P, L, U = torch.linalg.lu(A)
>>> P
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.]])
>>> L
tensor([[1.0000, 0.0000],
        [0.5007, 1.0000],
        [0.0633, 0.9755]])
>>> U
tensor([[0.3771, 0.0489],
        [0.0000, 0.9644]])
>>> torch.dist(A, P @ L @ U)
tensor(5.9605e-08)

>>> A = torch.randn(2, 5, 7, device="cuda")
>>> P, L, U = torch.linalg.lu(A, pivot=False)
>>> P
tensor([], device='cuda:0')
>>> torch.dist(A, L @ U)
tensor(1.0376e-06, device='cuda:0')
```
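As a rough illustration of why the factorization is useful for torch.linalg.solve(), the pure-Python sketch below (not PyTorch's code) solves A x = b from the three factors. Since A = PLU and P is a permutation matrix, A x = b is equivalent to L U x = P^T b, which takes one permutation, a forward substitution and a back substitution. The name lu_solve_sketch is hypothetical, and the sketch assumes a square, nonsingular A given as nested lists.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def lu_solve_sketch(P: Matrix, L: Matrix, U: Matrix, b: Vector) -> Vector:
    """Solve (P @ L @ U) x = b for square factors with unit-diagonal L."""
    n = len(b)
    # P is a permutation matrix, so its inverse is its transpose: c = P^T b.
    c = [sum(P[r][i] * b[r] for r in range(n)) for i in range(n)]
    # Forward substitution with unit-diagonal L: L y = c.
    y = [0.0] * n
    for i in range(n):
        y[i] = c[i] - sum(L[i][t] * y[t] for t in range(i))
    # Back substitution with upper triangular U: U x = y.
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(U[i][t] * x[t] for t in range(i + 1, n))) / U[i][i]
    return x
```

For instance, with the factors P = [[0, 1], [1, 0]], L = [[1, 0], [1, 1]], U = [[1, 3], [0, -1]] of A = [[1, 2], [1, 3]] and b = [5, 7], lu_solve_sketch returns [1.0, 2.0]. Factoring once and reusing P, L and U for several right-hand sides is what makes this approach attractive.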