torch.bmm

torch.bmm(input, mat2, *, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors each containing the same number of matrices.

If input is a $(b \times n \times m)$ tensor and mat2 is a $(b \times m \times p)$ tensor, out will be a $(b \times n \times p)$ tensor.

$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$
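The per-batch definition above can be checked directly. The following is a minimal sketch, not part of the official reference: the names a, b, and expected are illustrative, and the tolerance is an assumption suited to float32 on CPU.

```python
import torch

torch.manual_seed(0)

a = torch.randn(10, 3, 4)  # b=10 matrices of shape n=3 x m=4
b = torch.randn(10, 4, 5)  # b=10 matrices of shape m=4 x p=5

out = torch.bmm(a, b)

# Multiply each pair of matrices on its own, then stack along the batch dimension.
expected = torch.stack([a[i] @ b[i] for i in range(a.size(0))])

# The two results should agree up to floating-point rounding.
matches = torch.allclose(out, expected, atol=1e-5)
print(out.shape, matches)
```

The loop is shown only to make the definition concrete; the single torch.bmm call is the idiomatic form.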

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs, this function will use a different precision for the backward pass.

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
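To illustrate the difference, the sketch below passes a batch of size 1 as the second operand. It assumes torch.bmm reports the batch-size mismatch as a RuntimeError; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)

batch = torch.randn(10, 3, 4)
single = torch.randn(1, 4, 5)  # batch dimension of size 1

# torch.bmm requires both operands to hold the same number of matrices.
try:
    torch.bmm(batch, single)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# torch.matmul broadcasts the size-1 batch dimension across all 10 matrices.
broadcast_result = torch.matmul(batch, single)

# To use torch.bmm, expand the batch dimension explicitly first.
bmm_result = torch.bmm(batch, single.expand(10, -1, -1))

print(bmm_failed, broadcast_result.shape)
```

Expanding explicitly keeps the shape contract of torch.bmm visible at the call site, while torch.matmul is more convenient when operand ranks vary.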

Parameters:
  • input (Tensor) – the first batch of matrices to be multiplied

  • mat2 (Tensor) – the second batch of matrices to be multiplied

Keyword Arguments:
  • out (Tensor, optional) – the output tensor.
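A short sketch of the out keyword argument, assuming a preallocated buffer with the expected (b × n × p) shape and matching dtype; the name buf is illustrative.

```python
import torch

torch.manual_seed(0)

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)

# Preallocate the result buffer with the shape the product will have.
buf = torch.empty(10, 3, 5)

# The product is written into buf, and the returned tensor shares its storage.
ret = torch.bmm(a, b, out=buf)
same_storage = ret.data_ptr() == buf.data_ptr()

print(same_storage, buf.shape)
```

Reusing a buffer this way can avoid repeated allocations when the same product shape is computed in a loop.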

Example:

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
