torch.solve
torch.solve(input, A, *, out=None)
This function returns the solution to the system of linear equations represented by AX = B and the LU factorization of A, in order, as a namedtuple (solution, LU).
LU contains the L and U factors of the LU factorization of A.
torch.solve(B, A) accepts 2D inputs B, A or inputs that are batches of 2D matrices. If the inputs are batches, the function returns batched outputs solution, LU.
Supports real-valued and complex-valued inputs.
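The batching and dtype behaviour described above can be sketched as follows. Because torch.solve() is deprecated (see the warning on this page), this sketch assumes a PyTorch version that provides torch.linalg.solve() and uses it instead; note that its argument order is (A, B). The matrices are made-up, well-conditioned test data, not values from this page.

```python
import torch

torch.manual_seed(0)

# Hypothetical batch of well-conditioned 4x4 systems:
# identity plus a small random perturbation.
A = torch.eye(4, dtype=torch.float64) + 0.1 * torch.randn(2, 3, 4, 4, dtype=torch.float64)
B = torch.randn(2, 3, 4, 6, dtype=torch.float64)

# One call solves every system in the batch; argument order is (A, B).
X = torch.linalg.solve(A, B)
print(X.shape)               # same batch and matrix shape as B
print(torch.dist(B, A @ X))  # residual, close to zero

# Complex-valued inputs are handled the same way.
A_c = A + 0.1j * torch.randn_like(A)
B_c = B.to(torch.complex128)
X_c = torch.linalg.solve(A_c, B_c)
print(X_c.dtype)
```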
Warning
torch.solve() is deprecated in favor of torch.linalg.solve() and will be removed in a future PyTorch release. torch.linalg.solve() has its arguments reversed and does not return the LU factorization of the input. To get the LU factorization, see torch.lu(), which may be used with torch.lu_solve() and torch.lu_unpack().
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B)
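A minimal migration sketch for code that needs both the solution and the LU factors. It assumes a recent PyTorch release in which torch.linalg.lu_factor() and torch.linalg.lu_solve() are available; these are used here as the newer counterparts of torch.lu() and torch.lu_solve() named in the warning, which is an assumption of this sketch rather than something this page states. The matrices are made-up illustration data.

```python
import torch

# Hypothetical small, well-conditioned system A X = B.
A = torch.tensor([[4.0, 1.0, 0.0],
                  [1.0, 5.0, 2.0],
                  [0.0, 2.0, 6.0]], dtype=torch.float64)
B = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]], dtype=torch.float64)

# Replacement for `torch.solve(B, A).solution` (arguments are reversed).
X = torch.linalg.solve(A, B)

# The LU factorization is no longer returned; compute it separately.
LU, pivots = torch.linalg.lu_factor(A)

# The packed factors can be reused to solve for further right-hand sides.
X_lu = torch.linalg.lu_solve(LU, pivots, B)

# Unpack into explicit permutation, lower, and upper factors: A = P @ L @ U.
P, L, U = torch.lu_unpack(LU, pivots)

print(torch.dist(B, A @ X))  # residual, close to zero
```

Factoring once and calling the LU-based solve repeatedly is the usual reason to keep the factors around; when only a single solve is needed, torch.linalg.solve(A, B) on its own is enough.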
Note
Irrespective of the original strides, the returned matrices solution and LU will be transposed, i.e. with strides like B.contiguous().transpose(-1, -2).stride() and A.contiguous().transpose(-1, -2).stride() respectively.
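To make the stride expression in the note concrete, the sketch below evaluates it for a hypothetical 4x6 tensor B without calling torch.solve(). It only illustrates what a transposed-stride layout looks like and how to obtain a row-major copy if downstream code expects one.

```python
import torch

# Hypothetical right-hand side, stored row-major by default.
B = torch.randn(4, 6)
print(B.stride())                 # (6, 1)

# The expression quoted in the note: the strides are swapped.
Bt = B.contiguous().transpose(-1, -2)
print(Bt.stride())                # (1, 6)
print(Bt.is_contiguous())         # False

# .contiguous() copies the data into a row-major layout when required.
Bt_rowmajor = Bt.contiguous()
print(Bt_rowmajor.is_contiguous())  # True
```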
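- Parameters
  - input (Tensor): input matrix B of size (*, m, k), where * is zero or more batch dimensions.
  - A (Tensor): input square matrix of size (*, m, m), where * is zero or more batch dimensions.
- Keyword Arguments
  - out ((Tensor, Tensor), optional): optional output tuple.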
Example:
>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35, -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 * 7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 * 3.6386)