MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.
$\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O$ where $head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.
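The formula above can be sketched directly in a few lines. This is only an illustration under stated assumptions, not the module's actual implementation: `Attention` is taken to be the scaled dot-product attention of the referenced paper, $\text{softmax}(QK^T/\sqrt{d})V$, the inputs are unbatched, and the helper name `multi_head` and all weight tensors are hypothetical.

```python
import math
import torch

def multi_head(Q, K, V, W_q, W_k, W_v, W_o):
    """Concat(head_1, ..., head_h) W^O for unbatched inputs (hypothetical helper).

    Q: (L, E); K and V: (S, E)
    W_q, W_k, W_v: (h, E, d) per-head projections; W_o: (h * d, E)
    """
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        q, k, v = Q @ Wq_i, K @ Wk_i, V @ Wv_i       # (L, d), (S, d), (S, d)
        scores = q @ k.T / math.sqrt(q.shape[-1])    # (L, S), scaled dot products
        weights = torch.softmax(scores, dim=-1)      # each row sums to 1
        heads.append(weights @ v)                    # head_i: (L, d)
    return torch.cat(heads, dim=-1) @ W_o            # (L, E)

torch.manual_seed(0)
L, S, E, h = 4, 6, 8, 2
d = E // h  # per-head dimension
Q, K, V = torch.randn(L, E), torch.randn(S, E), torch.randn(S, E)
W_q, W_k, W_v = (torch.randn(h, E, d) for _ in range(3))
W_o = torch.randn(h * d, E)

out = multi_head(Q, K, V, W_q, W_k, W_v, W_o)
print(out.shape)  # torch.Size([4, 8])
```

The loop over heads keeps the sketch close to the notation; a practical implementation would typically batch all heads into a single matrix multiplication.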
 Parameters
embed_dim – total dimension of the model.
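num_heads – number of parallel attention heads. embed_dim must be divisible by num_heads; each head has dimension embed_dim // num_heads.
dropout – dropout probability applied to attn_output_weights. Default: 0.0.
bias – if True, adds bias to the input and output projection layers. Default: True.
add_bias_kv – add bias to the key and value sequences at dim=0. Default: False.
add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1. Default: False.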
kdim – total number of features in key. Default: None.
vdim – total number of features in value. Default: None.
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
Note that if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
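The example above leaves `embed_dim`, `num_heads`, and the input tensors undefined. A runnable variant might look like the following; the concrete sizes (16, 4, 5, 7, 2, and the `kdim`/`vdim` values) are arbitrary choices for illustration, and the second half shows one way the `kdim`/`vdim` arguments could be used when key and value have a different feature size than the query.

```python
import torch
from torch import nn

embed_dim, num_heads = 16, 4   # embed_dim must be divisible by num_heads
L, S, N = 5, 7, 2              # target length, source length, batch size

# Default layout (batch_first=False): tensors are (seq, batch, feature).
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)          # torch.Size([5, 2, 16])
print(attn_output_weights.shape)  # torch.Size([2, 5, 7])

# Key and value with their own feature sizes (illustrative values).
kdim, vdim = 12, 20
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)
key_k = torch.randn(S, N, kdim)
value_v = torch.randn(S, N, vdim)
cross_output, _ = cross_attn(query, key_k, value_v)
print(cross_output.shape)         # torch.Size([5, 2, 16])
```

In both cases the output keeps the query's sequence length and `embed_dim` features, regardless of the key and value feature sizes.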

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)

 Parameters
query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.
key_padding_mask – if provided, the specified padding elements in the key will be ignored by the attention. For a binary (boolean) mask, positions where the value is True are ignored. For a byte mask, positions where the value is nonzero are ignored.
need_weights – output attn_output_weights in addition to attn_output. Default: True.
attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all batches, while a 3D mask allows a different mask for each batch entry.
 Shapes for inputs:
query: $(L, N, E)$ where L is the target sequence length, N is the batch size, E is the embedding dimension. $(N, L, E)$ if batch_first is True.
key: $(S, N, E)$ where S is the source sequence length, N is the batch size, E is the embedding dimension. $(N, S, E)$ if batch_first is True.
value: $(S, N, E)$ where S is the source sequence length, N is the batch size, E is the embedding dimension. $(N, S, E)$ if batch_first is True.
key_padding_mask: $(N, S)$ where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the nonzero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the positions with the value of False will be unchanged.
attn_mask: if a 2D mask: $(L, S)$ where L is the target sequence length, S is the source sequence length. If a 3D mask: $(N\cdot\text{num\_heads}, L, S)$ where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend only the unmasked positions. If a ByteTensor is provided, the nonzero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while positions with False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
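The two masks described above can be combined in a single call. The sketch below uses self-attention (query, key, and value are the same tensor) with boolean masks; the sizes, the choice of which position counts as padding, and the causal (upper-triangular) pattern are assumptions made for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads, N, T = 8, 2, 2, 4
mha = nn.MultiheadAttention(embed_dim, num_heads)
x = torch.randn(T, N, embed_dim)  # (seq, batch, feature)

# key_padding_mask: (N, S). True marks a key position to ignore.
# Here the last position of the first sequence is treated as padding.
key_padding_mask = torch.tensor([
    [False, False, False, True],
    [False, False, False, False],
])

# attn_mask: (L, S). True marks a (query, key) pair that may not attend.
# The strict upper triangle blocks attention to later positions.
attn_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x,
                   key_padding_mask=key_padding_mask,
                   attn_mask=attn_mask)

print(weights.shape)     # torch.Size([2, 4, 4])
print(weights[0, :, 3])  # attention paid to the padded key: all zeros
```

Every query row here still has at least one permitted key (position 0 is never masked). A row with every key masked would leave the softmax with nothing to normalize over, so masks are best constructed to avoid that case.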
 Shapes for outputs:
attn_output: $(L, N, E)$ where L is the target sequence length, N is the batch size, E is the embedding dimension. $(N, L, E)$ if batch_first is True.
attn_output_weights: $(N, L, S)$ where N is the batch size, L is the target sequence length, S is the source sequence length.
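A quick way to confirm the output shapes is to run the module with batch_first=True and inspect the results. The sizes below are arbitrary, and the behavior shown for need_weights=False (the second return value is None) is worth verifying against the installed PyTorch version.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, L, S, E = 3, 5, 7, 16
mha = nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)

query = torch.randn(N, L, E)   # (batch, seq, feature)
key = torch.randn(N, S, E)
value = torch.randn(N, S, E)

attn_output, attn_output_weights = mha(query, key, value)
print(attn_output.shape)          # torch.Size([3, 5, 16])
print(attn_output_weights.shape)  # torch.Size([3, 5, 7])

# Each query's weights over the source positions sum to 1
# (dropout is 0.0 by default, so nothing is zeroed out).
row_sums = attn_output_weights.sum(dim=-1)

# Skipping the weights when only the attended output is needed.
output_only, no_weights = mha(query, key, value, need_weights=False)
print(no_weights)                 # None
```

Note that attn_output follows the batch_first layout, while attn_output_weights is $(N, L, S)$ in either layout.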