torch.bmm
torch.bmm(input, mat2, *, out=None) → Tensor
Performs a batch matrix-matrix product of matrices stored in input and mat2.
input and mat2 must be 3-D tensors each containing the same number of matrices.
If input is a (b × n × m) tensor and mat2 is a (b × m × p) tensor, out will be a (b × n × p) tensor. That is, out[i] = input[i] @ mat2[i] for each batch index i.
This operator supports TensorFloat32.
On certain ROCm devices, when using float16 inputs, this function uses a different precision for the backward pass.
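The following is a minimal sketch of these semantics, assuming only a standard PyTorch install. It checks that torch.bmm agrees with two other approaches: multiplying each pair of matrices separately with torch.mm, and the equivalent torch.einsum contraction. The names batch1 and batch2 and the sizes b, n, m, p are illustrative only. The example uses float64 so that the comparisons do not depend on rounding differences between kernels.

```python
import torch

# Illustrative sizes: b matrices of shape (n, m) times b matrices of shape (m, p).
b, n, m, p = 4, 2, 3, 5
batch1 = torch.randn(b, n, m, dtype=torch.float64)
batch2 = torch.randn(b, m, p, dtype=torch.float64)

res = torch.bmm(batch1, batch2)  # shape (b, n, p)

# Reference 1: multiply each pair of matrices independently, then stack.
ref_loop = torch.stack([torch.mm(batch1[i], batch2[i]) for i in range(b)])

# Reference 2: the same per-batch contraction written with einsum.
ref_einsum = torch.einsum("bnm,bmp->bnp", batch1, batch2)

print(res.shape)                        # torch.Size([4, 2, 5])
print(torch.allclose(res, ref_loop))    # True
print(torch.allclose(res, ref_einsum))  # True
```

The loop version makes the per-index definition explicit. torch.bmm computes the same result in a single batched call.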
Note
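This function does not broadcast. For broadcasting matrix products, see torch.matmul().
Parameters
input (Tensor) – the first batch of matrices to be multiplied
mat2 (Tensor) – the second batch of matrices to be multiplied
Keyword Arguments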
out (Tensor, optional) – the output tensor.
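The following is a short sketch of the out keyword, assuming a preallocated buffer that already has the (b × n × p) result shape. The names batch1, batch2 and ret are illustrative. It checks that the returned tensor shares storage with out, so the product was written into the buffer rather than into a newly allocated tensor.

```python
import torch

batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)

# Preallocate a result buffer with the (b, n, p) shape bmm will produce.
out = torch.empty(10, 3, 5)
ret = torch.bmm(batch1, batch2, out=out)

# The returned tensor points at the same storage as `out`.
print(ret.data_ptr() == out.data_ptr())              # True
print(torch.allclose(out, torch.bmm(batch1, batch2)))  # True
```

Passing out can be useful when the same buffer is reused across many calls, for example inside a loop.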
Example:
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
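The note that torch.bmm does not broadcast can be illustrated with a small sketch, assuming a second operand whose batch dimension has size 1. The names batch1 and single are illustrative. torch.bmm rejects the mismatched batch sizes, and the sketch assumes it does so with a RuntimeError. torch.matmul broadcasts the size-1 batch dimension instead. Expanding the operand explicitly with Tensor.expand lets torch.bmm produce the same result.

```python
import torch

batch1 = torch.randn(10, 3, 4, dtype=torch.float64)
single = torch.randn(1, 4, 5, dtype=torch.float64)  # batch dimension of size 1

# torch.bmm requires both inputs to hold the same number of matrices.
try:
    torch.bmm(batch1, single)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# torch.matmul broadcasts the size-1 batch dimension across all 10 matrices.
res = torch.matmul(batch1, single)

# Explicitly expanding the batch dimension (a view, no copy) makes bmm accept it.
res_expand = torch.bmm(batch1, single.expand(10, -1, -1))

print(bmm_failed)                        # True
print(res.shape)                         # torch.Size([10, 3, 5])
print(torch.allclose(res, res_expand))   # True
```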