torch.linalg.solve

torch.linalg.solve(A, B, *, out=None) → Tensor

Computes the solution of a square system of linear equations with a unique solution.

Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, this function computes the solution $X \in \mathbb{K}^{n \times k}$ of the linear system associated to $A \in \mathbb{K}^{n \times n}$, $B \in \mathbb{K}^{n \times k}$, which is defined as

$$AX = B$$

This system of linear equations has a unique solution if and only if $A$ is invertible. This function assumes that $A$ is invertible.

Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if the inputs are batches of matrices then the output has the same batch dimensions.

Letting * be zero or more batch dimensions,

  • If A has shape (*, n, n) and B has shape (*, n) (a batch of vectors) or shape (*, n, k) (a batch of matrices or “multiple right-hand sides”), this function returns X of shape (*, n) or (*, n, k) respectively.

  • Otherwise, if A has shape (*, n, n) and B has shape (n,) or (n, k), B is broadcast to have shape (*, n) or (*, n, k) respectively. This function then returns the solution of the resulting batch of systems of linear equations.

Note

This function computes X = A.inverse() @ B in a faster and more numerically stable way than performing the computations separately.
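As a rough illustration of this note, the sketch below solves the same small system both ways and compares the results. The matrix and right-hand side are made-up values chosen so that the exact solution is (1, 2, 3); they are not taken from the PyTorch documentation, and the variable names are illustrative only.

```python
import torch

# A small, well-conditioned system chosen for illustration (assumed values).
A = torch.tensor([[4.0, 1.0, 0.0],
                  [1.0, 3.0, 1.0],
                  [0.0, 1.0, 2.0]], dtype=torch.float64)
B = torch.tensor([[6.0], [10.0], [8.0]], dtype=torch.float64)

X_solve = torch.linalg.solve(A, B)  # solves AX = B without forming the inverse
X_inv = A.inverse() @ B             # forms the inverse explicitly, then multiplies

# On this system both approaches agree up to rounding error.
print(torch.allclose(X_solve, X_inv))

# Residual norms ||AX - B|| give a rough check of each result's accuracy.
print(torch.linalg.norm(A @ X_solve - B).item())
print(torch.linalg.norm(A @ X_inv - B).item())
```

On a well-conditioned system like this one the two results should be nearly identical; any accuracy difference is more likely to show up as the matrix becomes ill-conditioned.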

Note

When inputs are on a CUDA device, this function synchronizes that device with the CPU.

Parameters
  • A (Tensor) – tensor of shape (*, n, n) where * is zero or more batch dimensions.

  • B (Tensor) – right-hand side tensor of shape (*, n) or (*, n, k) or (n,) or (n, k) according to the rules described above.

Keyword Arguments

out (Tensor, optional) – output tensor. Ignored if None. Default: None.
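A minimal sketch of the out keyword argument, assuming a preallocated tensor whose shape and dtype already match the solution. The diagonal system below uses made-up values for illustration.

```python
import torch

# Diagonal system chosen for illustration: 2*x0 = 2 and 4*x1 = 8.
A = torch.tensor([[2.0, 0.0],
                  [0.0, 4.0]])
b = torch.tensor([2.0, 8.0])

# Preallocate with the shape and dtype the solution will have.
out = torch.empty(2)
torch.linalg.solve(A, b, out=out)

# out now holds the solution of Ax = b, which is (1, 2) for this system.
print(out)
```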

Raises

RuntimeError – if the A matrix is not invertible or any matrix in a batched A is not invertible.
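The failure mode can be sketched with an exactly singular matrix, as below. The values are made up for illustration. A matrix that is only close to singular may not raise and can instead return an inaccurate result, so this check should not be read as a general test of invertibility.

```python
import torch

# Rank-deficient matrix: the second row is twice the first (assumed values).
A_singular = torch.tensor([[1.0, 2.0],
                           [2.0, 4.0]])
b = torch.tensor([1.0, 1.0])

try:
    torch.linalg.solve(A_singular, b)
    raised = False
except RuntimeError as err:
    # The documented exception type for a non-invertible A.
    raised = True
    print(f"solve failed: {err}")
```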

Examples:

>>> import torch
>>> A = torch.randn(3, 3)
>>> b = torch.randn(3)
>>> x = torch.linalg.solve(A, b)
>>> torch.allclose(A @ x, b)
True
>>> A = torch.randn(2, 3, 3)
>>> B = torch.randn(2, 3, 4)
>>> X = torch.linalg.solve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(A @ X, B)
True

>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(3, 1)
>>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1)
>>> x.shape
torch.Size([2, 3, 1])
>>> torch.allclose(A @ x, b)
True
>>> b = torch.randn(3)
>>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3)
>>> x.shape
torch.Size([2, 3])
>>> Ax = A @ x.unsqueeze(-1)
>>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax))
True
