
# torch.solve

torch.solve(input, A, *, out=None)

Returns the solution $X$ of the system of linear equations $AX = B$, together with the LU factorization of A, as a namedtuple (solution, LU), in that order. The right-hand side $B$ is passed as input.

LU contains the L and U factors of the LU factorization of A.
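As a rough illustration of how L and U can share one packed matrix, the sketch below uses torch.linalg.lu_factor() and torch.lu_unpack() rather than torch.solve() itself. It assumes an installed PyTorch release that provides both functions, and that the LU returned by torch.solve() follows the same packed convention (L strictly below the diagonal with an implicit unit diagonal, U on and above it).

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)

# Packed LU factorization with partial pivoting: the strictly lower part of
# LU holds L (its unit diagonal is implicit), the upper part holds U.
LU, pivots = torch.linalg.lu_factor(A)

# Expand the packed form into explicit permutation, lower and upper factors.
P, L, U = torch.lu_unpack(LU, pivots)

# The explicit factors reproduce A up to floating-point rounding.
reconstruction_error = torch.dist(A, P @ L @ U)
```

Note that torch.solve() does not return the pivots, so its LU output alone is not enough to rebuild A when rows were swapped during factorization.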

torch.solve(B, A) accepts 2D inputs B and A, or inputs that are batches of 2D matrices. If the inputs are batches, the outputs solution and LU are batched as well.

Supports real-valued and complex-valued inputs.
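A minimal sketch of a complex-valued solve follows. It uses torch.linalg.solve(), the replacement named in the deprecation warning on this page, because torch.solve() may not be available in the installed release. The matrices are made up for illustration.

```python
import torch

# Small complex system with a known solution (illustrative values only).
A = torch.tensor([[2 + 1j, 1 + 0j],
                  [0 + 1j, 3 - 1j]], dtype=torch.complex128)
X_true = torch.tensor([[1 + 1j],
                       [2 - 1j]], dtype=torch.complex128)
B = A @ X_true

# torch.linalg.solve takes the coefficient matrix first.
X = torch.linalg.solve(A, B)
```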

Warning

torch.solve() is deprecated in favor of torch.linalg.solve() and will be removed in a future PyTorch release. torch.linalg.solve() has its arguments reversed and does not return the LU factorization of the input. To get the LU factorization see torch.lu(), which may be used with torch.lu_solve() and torch.lu_unpack().

X = torch.solve(B, A).solution should be replaced with

X = torch.linalg.solve(A, B)
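The replacement can be sketched end to end as follows. The test matrix is made strictly diagonally dominant only so that it is guaranteed to be invertible in this example; this construction is an assumption of the sketch, not a requirement of the API.

```python
import torch

torch.manual_seed(0)

# Entries of torch.rand lie in [0, 1), so adding 5 * I makes every diagonal
# entry larger than the sum of the other entries in its row (invertible A).
A = torch.rand(5, 5, dtype=torch.float64) + 5 * torch.eye(5, dtype=torch.float64)
B = torch.randn(5, 3, dtype=torch.float64)

# Old call:  X = torch.solve(B, A).solution   (right-hand side first)
# New call:  coefficient matrix first, right-hand side second.
X = torch.linalg.solve(A, B)

# Residual of A X = B; should be close to zero.
residual = torch.dist(B, A @ X)
```

Swapping the arguments by mistake raises a shape error only when B is not square, so it is worth checking the residual after migrating.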


Note

Regardless of the strides of the inputs, the returned matrices solution and LU are laid out transposed (column-major), i.e. with strides like B.contiguous().mT.stride() and A.contiguous().mT.stride() respectively.
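To show what such strides look like, the sketch below evaluates the expression from the note on an ordinary tensor. It does not call torch.solve(); it assumes a PyTorch release in which the .mT attribute is available, and uses a square matrix for simplicity.

```python
import torch

B = torch.randn(3, 3)

row_major = B.contiguous()       # default (C-style) layout
col_major = B.contiguous().mT    # same storage, last two strides swapped

row_strides = row_major.stride()   # (3, 1)
col_strides = col_major.stride()   # (1, 3)

# .contiguous() copies a transposed-stride tensor back into row-major layout.
repacked = col_major.contiguous()
```

Code that depends on a row-major memory layout can call .contiguous() on the returned tensors, at the cost of a copy.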

Parameters
• input (Tensor) – input matrix $B$ of size $(*, m, k)$, where $*$ is zero or more batch dimensions.

• A (Tensor) – input square matrix $A$ of size $(*, m, m)$, where $*$ is zero or more batch dimensions.

Keyword Arguments

out ((Tensor, Tensor), optional) – optional output tuple.

Example:

>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35,  -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
3.6386)
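The batched example above can be approximated with torch.linalg.solve() in releases where torch.solve() is no longer available. This is a sketch under the same batch shapes; the diagonal shift is an assumption added here so that every matrix in the batch is guaranteed to be invertible.

```python
import torch

torch.manual_seed(0)

# Batch of 2 x 3 x 1 matrices of size 4 x 4. torch.rand entries lie in
# [0, 1), so adding 4 * I makes each matrix strictly diagonally dominant.
eye = torch.eye(4, dtype=torch.float64)
A = torch.rand(2, 3, 1, 4, 4, dtype=torch.float64) + 4 * eye
B = torch.randn(2, 3, 1, 4, 6, dtype=torch.float64)

# One call solves every system in the batch; note the (A, B) argument order.
X = torch.linalg.solve(A, B)

# Residual over the whole batch; should be close to zero.
residual = torch.dist(B, A.matmul(X))
```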

