
# torch.lu_solve

torch.lu_solve(b, LU_data, LU_pivots, *, out=None)

Returns the solution $x$ of the linear system $Ax = b$, computed from the partially pivoted LU factorization of A returned by torch.linalg.lu_factor().
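To show what "solving with an LU factorization" means, the sketch below unpacks the compact factorization into explicit $P$, $L$, $U$ with $A = PLU$ and solves the system with two triangular solves. It then checks that the result agrees with torch.lu_solve(). The helpers torch.lu_unpack() and torch.linalg.solve_triangular(), the fixed seed, and the identity shift used to keep the random matrices well conditioned are illustrative choices, not part of this page.

```python
import torch

torch.manual_seed(0)

# A batch of two 3x3 systems, each with two right-hand sides.
# The added 3*I is only there to keep these random matrices well conditioned.
A = torch.randn(2, 3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
b = torch.randn(2, 3, 2, dtype=torch.float64)

# Factor once: A = P @ L @ U, stored compactly as LU plus a pivot vector.
LU, pivots = torch.linalg.lu_factor(A)

# Expand the compact factorization into explicit P, L, U matrices.
P, L, U = torch.lu_unpack(LU, pivots)

# Solve P L U x = b in two triangular steps:
#   1) L y = P^T b   (L is unit lower triangular)
#   2) U x = y       (U is upper triangular)
y = torch.linalg.solve_triangular(L, P.mT @ b, upper=False, unitriangular=True)
x_manual = torch.linalg.solve_triangular(U, y, upper=True)

# The same system solved directly from the compact factorization.
x_lu = torch.lu_solve(b, LU, pivots)
print(torch.allclose(x_manual, x_lu))
```

Because the factorization is computed once, the same `LU` and `pivots` can be reused for any number of right-hand sides without refactoring A.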

This function supports float, double, cfloat and cdouble dtypes for input.
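A quick check across the four supported dtypes might look like the sketch below. It assumes small diagonally shifted random matrices and loose residual tolerances chosen per precision. These thresholds are illustrative, not guarantees from the documentation.

```python
import torch

torch.manual_seed(0)

# One case per supported dtype; tolerances are loose enough for single precision.
cases = [
    (torch.float, 1e-3),     # float32
    (torch.double, 1e-8),    # float64
    (torch.cfloat, 1e-3),    # complex64
    (torch.cdouble, 1e-8),   # complex128
]

results = {}
for dtype, tol in cases:
    # Adding a scaled identity keeps these random matrices comfortably invertible.
    A = torch.randn(4, 4, dtype=dtype) + 4 * torch.eye(4, dtype=dtype)
    b = torch.randn(4, 1, dtype=dtype)
    LU, pivots = torch.linalg.lu_factor(A)
    x = torch.lu_solve(b, LU, pivots)
    # Record the solution dtype and whether the residual ||A x - b|| is small.
    residual = torch.linalg.norm(A @ x - b).item()
    results[dtype] = (x.dtype, residual < tol)
    print(dtype, results[dtype])
```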

Warning

torch.lu_solve() is deprecated in favor of torch.linalg.lu_solve() and will be removed in a future PyTorch release. Note that the argument order changes: the right-hand side moves from the first position to the last. X = torch.lu_solve(B, LU, pivots) should be replaced with

X = torch.linalg.lu_solve(LU, pivots, B)
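A minimal migration check, assuming a small random float64 system, calls both forms on the same factorization and confirms that they agree. Only the argument order differs.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
B = torch.randn(3, 2, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

# Deprecated form: right-hand side first.
X_old = torch.lu_solve(B, LU, pivots)

# Replacement: factorization first, right-hand side last.
X_new = torch.linalg.lu_solve(LU, pivots, B)

print(torch.allclose(X_old, X_new))
```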

Parameters:
• b (Tensor) – the right-hand side tensor of size $(*, m, k)$, where $*$ is zero or more batch dimensions.

• LU_data (Tensor) – the pivoted LU factorization of A returned by torch.linalg.lu_factor(), of size $(*, m, m)$, where $*$ is zero or more batch dimensions.

• LU_pivots (IntTensor) – the pivots of the LU factorization returned by torch.linalg.lu_factor(), of size $(*, m)$, where $*$ is zero or more batch dimensions. The batch dimensions of LU_pivots must equal the batch dimensions of LU_data.
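The sketch below makes the shape contract concrete. It uses batch dimensions `(4, 2)`, $m = 3$ and $k = 5$, all of which are arbitrary illustrative values. The pivots come back as a 32-bit integer tensor, and the solution has the same shape as `b`.

```python
import torch

torch.manual_seed(0)
batch, m, k = (4, 2), 3, 5  # arbitrary sizes chosen for illustration

A = torch.randn(*batch, m, m, dtype=torch.float64) + m * torch.eye(m, dtype=torch.float64)
b = torch.randn(*batch, m, k, dtype=torch.float64)

LU_data, LU_pivots = torch.linalg.lu_factor(A)
print(LU_data.shape)    # torch.Size([4, 2, 3, 3])  -> (*, m, m)
print(LU_pivots.shape)  # torch.Size([4, 2, 3])     -> (*, m)
print(LU_pivots.dtype)  # torch.int32

x = torch.lu_solve(b, LU_data, LU_pivots)
print(x.shape)          # torch.Size([4, 2, 3, 5])  -> same as b, (*, m, k)
```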

Keyword Arguments:

out (Tensor, optional) – the output tensor.
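As a sketch of the `out` keyword, assuming a preallocated buffer with the same shape and dtype as `b`, the solution can be written into an existing tensor instead of a newly allocated one:

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
b = torch.randn(3, 2, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

# Preallocate a result buffer matching b's shape and dtype.
x = torch.empty_like(b)
torch.lu_solve(b, LU, pivots, out=x)

print(torch.allclose(A @ x, b))
```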

Example:

>>> import torch
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)
tensor(2.8312e-07)

