class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]

Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.

$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.

Parameters
• embed_dim – Total dimension of the model.

• num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

• dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

• bias – If True, adds a learnable bias to the input and output projection layers. Default: True.

• add_bias_kv – If True, adds a learnable bias to the key and value sequences at dim=0. Default: False.

• add_zero_attn – If True, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

• kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

• vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

• batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
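When keys and values have feature sizes that differ from embed_dim, the kdim and vdim arguments can be set accordingly. A minimal runnable sketch (the dimensions below are illustrative, not part of the API):

```python
import torch
import torch.nn as nn

# Illustrative sizes: embed_dim must be divisible by num_heads.
embed_dim, num_heads = 16, 4
kdim, vdim = 8, 12  # keys and values may have different feature sizes

mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim,
                            batch_first=True)

batch, tgt_len, src_len = 2, 5, 7
query = torch.randn(batch, tgt_len, embed_dim)  # (N, L, E_q)
key = torch.randn(batch, src_len, kdim)         # (N, S, E_k)
value = torch.randn(batch, src_len, vdim)       # (N, S, E_v)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # (N, L, E)  -> torch.Size([2, 5, 16])
print(attn_weights.shape)  # (N, L, S)  -> torch.Size([2, 5, 7])
```

Internally the key and value projections map kdim and vdim back to embed_dim, so the output shape depends only on embed_dim.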

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)[source]
Parameters
• query – Query embeddings of shape $(L, N, E_q)$ when batch_first=False or $(N, L, E_q)$ when batch_first=True, where $L$ is the target sequence length, $N$ is the batch size, and $E_q$ is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.

• key – Key embeddings of shape $(S, N, E_k)$ when batch_first=False or $(N, S, E_k)$ when batch_first=True, where $S$ is the source sequence length, $N$ is the batch size, and $E_k$ is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

• value – Value embeddings of shape $(S, N, E_v)$ when batch_first=False or $(N, S, E_v)$ when batch_first=True, where $S$ is the source sequence length, $N$ is the batch size, and $E_v$ is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

• key_padding_mask – If specified, a mask of shape $(N, S)$ indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.

• need_weights – If True, returns attn_output_weights in addition to attn_output. Default: True.

• attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape $(L, S)$ or $(N\cdot\text{num\_heads}, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.
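The two mask arguments serve different purposes: key_padding_mask hides whole source positions (e.g. padding tokens) for every query, while attn_mask restricts which (query, key) pairs may interact, such as a causal mask. A sketch of both, with illustrative sizes:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4  # illustrative sizes
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

batch, seq_len = 2, 5
x = torch.randn(batch, seq_len, embed_dim)

# key_padding_mask, shape (N, S): True marks key positions to ignore.
key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
key_padding_mask[1, 3:] = True  # last two positions of sample 1 are padding

# attn_mask, shape (L, S): a causal mask; True blocks attending to the future.
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                       diagonal=1)

attn_output, attn_weights = mha(x, x, x,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask)
# Masked key positions receive exactly zero attention weight.
print(attn_weights[1, :, 3:].sum())
```

A float attn_mask works additively instead: 0.0 leaves a position unchanged and -inf forbids it, which allows soft biases as well as hard masking.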

Outputs:
• attn_output – Attention outputs of shape $(L, N, E)$ when batch_first=False or $(N, L, E)$ when batch_first=True, where $L$ is the target sequence length, $N$ is the batch size, and $E$ is the embedding dimension embed_dim.

• attn_output_weights – Attention output weights of shape $(N, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. Only returned when need_weights=True.
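The output shapes can be checked directly. A sketch with the default batch_first=False layout (the sizes are illustrative); note that each row of the returned weights is a softmax distribution over source positions, so it sums to 1:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2  # illustrative sizes
mha = nn.MultiheadAttention(embed_dim, num_heads)  # batch_first=False

tgt_len, src_len, batch = 3, 4, 2
query = torch.randn(tgt_len, batch, embed_dim)  # (L, N, E_q)
key = torch.randn(src_len, batch, embed_dim)    # (S, N, E_k)
value = torch.randn(src_len, batch, embed_dim)  # (S, N, E_v)

attn_output, attn_weights = mha(query, key, value, need_weights=True)
print(attn_output.shape)   # (L, N, E)  -> torch.Size([3, 2, 8])
print(attn_weights.shape)  # (N, L, S)  -> torch.Size([2, 3, 4])

# Weights are a probability distribution over the source dimension.
print(torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch, tgt_len)))

# With need_weights=False, the second return value is None and the
# weight computation can be skipped.
attn_output, attn_weights = mha(query, key, value, need_weights=False)
print(attn_weights)  # None
```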