MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)

Allows the model to jointly attend to information from different representation subspaces. See reference: Attention Is All You Need.
$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$

where $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.

Parameters
embed_dim – total dimension of the model.
num_heads – number of parallel attention heads. embed_dim must be divisible by num_heads.
dropout – dropout probability applied to attn_output_weights. Default: 0.0.
bias – whether to add bias as a module parameter. Default: True.
add_bias_kv – add bias to the key and value sequences at dim=0. Default: False.
add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1. Default: False.
kdim – total number of features in key. Default: None (uses embed_dim).
vdim – total number of features in value. Default: None (uses embed_dim).
Note: if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
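The following sketch, with made-up sizes, shows two of the options above in use. kdim and vdim let the key and value carry feature sizes different from embed_dim (the values 5 and 3 here are arbitrary), while add_bias_kv and add_zero_attn each append one extra source position, so the returned weights cover S + 2 keys instead of S. The output keeps the query's shape in both cases.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, S, N = 4, 6, 2          # target length, source length, batch size
E, H = 8, 2                # embed_dim (divisible by num_heads), num_heads
KDIM, VDIM = 5, 3          # illustrative key/value feature sizes

# Cross-attention where keys and values have their own feature sizes.
cross = nn.MultiheadAttention(E, H, kdim=KDIM, vdim=VDIM)
query = torch.randn(L, N, E)
key = torch.randn(S, N, KDIM)
value = torch.randn(S, N, VDIM)
out, w = cross(query, key, value)
print(out.shape, w.shape)    # torch.Size([4, 2, 8]) torch.Size([2, 4, 6])

# add_bias_kv appends one learned key/value position and add_zero_attn
# appends one all-zero position, so the weights span S + 2 source positions.
extra = nn.MultiheadAttention(E, H, add_bias_kv=True, add_zero_attn=True)
kv = torch.randn(S, N, E)
out2, w2 = extra(query, kv, kv)
print(out2.shape, w2.shape)  # torch.Size([4, 2, 8]) torch.Size([2, 4, 8])
```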
Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
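A sketch that recomputes the MultiHead formula by hand from the module's own parameters, so each factor in the equation can be located. It assumes kdim and vdim are left at None (so the packed in_proj_weight, whose rows stack the query, key, and value projections in that order, and the out_proj layer are used), bias=True, dropout=0.0, and no masks. The helper name manual_multihead and the tensor sizes are illustrative. If the reconstruction is right, the hand-computed output and head-averaged weights match the module's within floating-point tolerance.

```python
import math
import torch
import torch.nn as nn

def manual_multihead(mha, query, key, value):
    """Hypothetical helper: recompute MultiHead(Q, K, V) step by step.

    Assumes kdim == vdim == embed_dim, bias=True, no masks, dropout=0.0.
    """
    E, H = mha.embed_dim, mha.num_heads
    d = E // H                                   # per-head dimension
    W, b = mha.in_proj_weight, mha.in_proj_bias  # rows: [W^Q; W^K; W^V]
    q = query @ W[:E].T + b[:E]                  # (L, N, E)
    k = key @ W[E:2 * E].T + b[E:2 * E]          # (S, N, E)
    v = value @ W[2 * E:].T + b[2 * E:]          # (S, N, E)
    L, N, _ = q.shape
    S = k.shape[0]
    # Split the last dimension into H heads: (N, H, seq, d).
    q = q.reshape(L, N, H, d).permute(1, 2, 0, 3)
    k = k.reshape(S, N, H, d).permute(1, 2, 0, 3)
    v = v.reshape(S, N, H, d).permute(1, 2, 0, 3)
    # head_i = softmax(q_i k_i^T / sqrt(d)) v_i
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
    heads = weights @ v                                   # (N, H, L, d)
    concat = heads.permute(2, 0, 1, 3).reshape(L, N, E)   # Concat(head_1..head_h)
    out = mha.out_proj(concat)                            # multiply by W^O (+ bias)
    return out, weights.mean(dim=1)                       # weights averaged over heads

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
L, S, N = 4, 6, 3
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)

with torch.no_grad():
    attn_output, attn_output_weights = multihead_attn(query, key, value)
    ref_output, ref_weights = manual_multihead(multihead_attn, query, key, value)

print(attn_output.shape, attn_output_weights.shape)  # torch.Size([4, 3, 8]) torch.Size([3, 4, 6])
print(torch.allclose(attn_output, ref_output, atol=1e-5))  # expected: True
```

The scaling by 1/sqrt(d) uses the per-head dimension d = embed_dim / num_heads, which is one reason embed_dim must be divisible by num_heads.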

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)

Parameters
query, key, value – map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details.
key_padding_mask – if provided, the specified padding elements in the key will be ignored by the attention. When given a boolean mask, positions where the value is True are ignored. When given a byte mask, positions with a nonzero value are ignored.
need_weights – if True, also return attn_output_weights, averaged over the heads; if False, None is returned in their place. Default: True.
attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all batches, while a 3D mask allows specifying a different mask for each entry of the batch.
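A minimal sketch of key_padding_mask and need_weights on a toy batch in which the second sequence has two padded source positions (the sizes and the padding layout are made up for illustration). With a BoolTensor mask, the padded keys receive zero weight while each row of weights still sums to 1 over the remaining keys, and need_weights=False returns None in place of the weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, S, N, E, H = 3, 5, 2, 8, 2
mha = nn.MultiheadAttention(E, H)
query = torch.randn(L, N, E)
memory = torch.randn(S, N, E)

# Illustrative padding: the second sequence has only 3 real tokens, so its
# last two source positions are padding. True marks positions to ignore.
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[1, 3:] = True

out, weights = mha(query, memory, memory, key_padding_mask=key_padding_mask)
print(weights[1, :, 3:])   # all zeros: padded keys receive no attention
print(torch.allclose(weights.sum(-1), torch.ones(N, L)))  # rows still sum to 1

# With need_weights=False, the weights slot of the returned tuple is None.
out_only, no_weights = mha(query, memory, memory, need_weights=False)
print(no_weights)          # None
```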
 Shape:
Inputs:
query: $(L, N, E)$ where L is the target sequence length, N is the batch size, E is the embedding dimension.
key: $(S, N, E)$ where S is the source sequence length, N is the batch size, E is the embedding dimension.
value: $(S, N, E)$ where S is the source sequence length, N is the batch size, E is the embedding dimension.
key_padding_mask: $(N, S)$ where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the nonzero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the positions with the value of False will be unchanged.
attn_mask: 2D mask $(L, S)$ where L is the target sequence length, S is the source sequence length; 3D mask $(N \cdot \text{num\_heads}, L, S)$ where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend only to the unmasked positions. If a ByteTensor is provided, the nonzero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while positions with False will be unchanged. If a FloatTensor is provided, it will be added to the attention scores before the softmax.
Outputs:
attn_output: $(L, N, E)$ where L is the target sequence length, N is the batch size, E is the embedding dimension.
attn_output_weights: $(N, L, S)$ where N is the batch size, L is the target sequence length, S is the source sequence length.
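A sketch of attn_mask for causal self-attention, with made-up sizes. It builds the 2D $(L, S)$ BoolTensor mask with torch.triu, checks that the equivalent FloatTensor mask (0 to keep a position, -inf to block it, added to the scores before the softmax) gives the same output, and repeats the 2D mask into the 3D $(N \cdot \text{num\_heads}, L, S)$ form.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, N, E, H = 4, 2, 8, 2
mha = nn.MultiheadAttention(E, H)
x = torch.randn(L, N, E)   # self-attention, so S == L

# Causal mask of shape (L, S): True above the diagonal means position i
# may not attend to any position j > i.
bool_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
out_b, w_b = mha(x, x, x, attn_mask=bool_mask)

# The same constraint as a float mask: 0 keeps a position, -inf removes it.
float_mask = torch.zeros(L, L).masked_fill(bool_mask, float("-inf"))
out_f, w_f = mha(x, x, x, attn_mask=float_mask)
print(torch.allclose(out_b, out_f))   # True
print(w_b[0])                         # zeros above the diagonal

# A 3D mask supplies one (L, S) mask per (batch, head) pair.
mask_3d = bool_mask.repeat(N * H, 1, 1)   # (N * num_heads, L, S)
out_3d, _ = mha(x, x, x, attn_mask=mask_3d)
print(torch.allclose(out_b, out_3d))  # True
```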