torch.linalg.tensorsolve
torch.linalg.tensorsolve(A, B, dims=None, *, out=None) → Tensor
Computes the solution X to the system torch.tensordot(A, X) = B.
If m is the product of the first B.ndim dimensions of A and n is the product of the rest of the dimensions, this function expects m and n to be equal.
The returned tensor X satisfies tensordot(A, X, dims=X.ndim) == B. X has shape A.shape[B.ndim:].
If dims is specified, the dimensions of A listed in dims are moved to the end before solving:
A = movedim(A, dims, range(-len(dims), 0))
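The effect of dims can be checked by hand. The following is a minimal sketch, not the library's implementation: it assumes that moving the listed dimensions to the trailing positions of A and then calling tensorsolve without dims gives the same result. The shapes, the seed, and the names M, A_moved, and A_manual are illustrative choices.

```python
import torch

torch.manual_seed(0)

# Well-conditioned (6, 6) system, folded so that B.shape = (6,) and X.shape = (3, 2).
M = torch.randn(6, 6, dtype=torch.double) + 6 * torch.eye(6, dtype=torch.double)
A_moved = M.reshape(6, 3, 2)
B = torch.randn(6, dtype=torch.double)

# Store the last dimension of A_moved first, giving A the shape (2, 6, 3).
A = A_moved.movedim(-1, 0)
dims = (0,)

# Let tensorsolve move dimension 0 of A to the end.
X = torch.linalg.tensorsolve(A, B, dims=dims)

# Manual equivalent: move the listed dimensions to the trailing positions first.
A_manual = torch.movedim(A, dims, tuple(range(-len(dims), 0)))
X_manual = torch.linalg.tensorsolve(A_manual, B)

match = torch.allclose(X, X_manual)
```

Here A_manual has shape (6, 3, 2), so B.shape matches its leading dimension and X takes the trailing shape (3, 2).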
Supports inputs of float, double, cfloat and cdouble dtypes.
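One way to read the m == n requirement is that A folds a square (m, m) matrix. The sketch below illustrates that reading under stated assumptions; it is not a description of how the function is implemented internally. It compares tensorsolve against torch.linalg.solve on the flattened system, using cdouble inputs. The shapes, the seed, and the diagonal shift that keeps the matrix well conditioned are illustrative choices.

```python
import torch

torch.manual_seed(0)

# Illustrative shapes: B.shape = (6, 4) and X.shape = (2, 3, 4).
b_shape, x_shape = (6, 4), (2, 3, 4)
m = 6 * 4  # product of the first B.ndim dimensions of A (equals 2 * 3 * 4)

# Build a well-conditioned (m, m) complex matrix, then fold it into A.
M = torch.randn(m, m, dtype=torch.cdouble) + m * torch.eye(m, dtype=torch.cdouble)
A = M.reshape(b_shape + x_shape)
B = torch.randn(b_shape, dtype=torch.cdouble)

X = torch.linalg.tensorsolve(A, B)

# Same solve on the flattened system: an (m, m) matrix against a length-m vector.
X_flat = torch.linalg.solve(A.reshape(m, m), B.flatten()).reshape(x_shape)

same = torch.allclose(X, X_flat)
residual_ok = torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
```

The output keeps the input dtype, so X is cdouble here.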
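See also
torch.linalg.tensorinv() computes the multiplicative inverse of torch.tensordot().
- Parameters
A (Tensor) – tensor to solve for. Its shape must satisfy prod(A.shape[:B.ndim]) == prod(A.shape[B.ndim:]).
B (Tensor) – tensor of shape A.shape[:B.ndim].
dims (Tuple[int], optional) – dimensions of A to be moved. If None, no dimensions are moved. Default: None.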
- Keyword Arguments
out (Tensor, optional) – output tensor. Ignored if None. Default: None.
- Raises
RuntimeError – if the reshaped A.view(m, m), with m as above, is not invertible, or if the product of the first B.ndim dimensions of A is not equal to the product of the rest of the dimensions.
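Both failure modes can be triggered with small inputs. This is a hedged sketch: the helper try_tensorsolve is hypothetical, and it assumes the singular-matrix error is RuntimeError or a subclass of it, so that a single except clause covers both cases. The exact error messages vary between PyTorch versions and are not shown.

```python
import torch

def try_tensorsolve(A, B):
    """Return the solution, or a message string if tensorsolve raises RuntimeError."""
    try:
        return torch.linalg.tensorsolve(A, B)
    except RuntimeError as err:
        return f"RuntimeError: {err}"

# Shape mismatch: prod(A.shape[:B.ndim]) = 6 but prod(A.shape[B.ndim:]) = 4.
mismatch = try_tensorsolve(torch.randn(2, 3, 4), torch.randn(2, 3))

# Singular system: an all-zero A flattens to a non-invertible (6, 6) matrix.
singular = try_tensorsolve(torch.zeros(2, 3, 6), torch.randn(2, 3))
```

In both calls the helper returns a string rather than a tensor, which makes it easy to tell a failed solve from a successful one.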
Examples:
>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True
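The relationship to torch.linalg.tensorinv() mentioned under "See also" can be sketched as follows. The example assumes a well-conditioned A built from a shifted random matrix; the shapes and seed are illustrative. Solving directly is usually preferable to forming the inverse, so this is meant only to show that the two results agree.

```python
import torch

torch.manual_seed(0)

m = 2 * 3 * 4
# Shift the diagonal so the flattened (m, m) matrix is comfortably invertible.
M = torch.randn(m, m, dtype=torch.double) + m * torch.eye(m, dtype=torch.double)
A = M.reshape(6, 4, 2, 3, 4)
B = torch.randn(6, 4, dtype=torch.double)

X = torch.linalg.tensorsolve(A, B)

# tensorinv with ind=B.ndim returns a tensor of shape A.shape[ind:] + A.shape[:ind].
A_inv = torch.linalg.tensorinv(A, ind=B.ndim)

# Contracting the inverse with B over all of B's dimensions reproduces X.
X_via_inv = torch.tensordot(A_inv, B, dims=B.ndim)
agree = torch.allclose(X, X_via_inv)
```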