- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)¶
Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.
Multi-Head Attention is defined as:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h) W^O$$

where $\mathrm{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.
forward() will use the optimized implementations of scaled_dot_product_attention() when possible.
In addition to support for the new scaled_dot_product_attention() function, for speeding up inference, MHA will use fastpath inference with support for Nested Tensors, iff:
- self attention is being computed (i.e., query, key, and value are the same tensor),
- inputs are batched (3D) with batch_first==True,
- either autograd is disabled (using torch.no_grad) or no tensor argument requires_grad,
- training is disabled (using .eval()),
- add_bias_kv is False,
- add_zero_attn is False,
- batch_first is True and the input is batched,
- kdim and vdim are equal to embed_dim,
- if a NestedTensor is passed, neither key_padding_mask nor attn_mask is passed,
- autocast is disabled.
If the optimized inference fastpath implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.
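A minimal sketch of this fastpath (the embedding size, head count, and sequence lengths below are arbitrary; need_weights=False is passed so attention weights are not materialized):

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 64, 4
>>> mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
>>> # Two sequences of different lengths, packed as a NestedTensor instead of
>>> # being padded to a common length.
>>> query = torch.nested.nested_tensor([torch.randn(5, embed_dim), torch.randn(9, embed_dim)])
>>> with torch.no_grad():
...     # Self-attention in eval mode with autograd disabled and batch_first=True,
...     # so the conditions listed above are satisfied; the output is a NestedTensor.
...     attn_output, _ = mha(query, query, query, need_weights=False)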
embed_dim – Total dimension of the model.
num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
bias – If specified, adds bias to input / output projection layers. Default: True.
add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.
add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).
vdim – Total number of features for values. Default: None (uses vdim=embed_dim).
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
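A slightly fuller, hypothetical example with batch_first=True that also shows the tensor shapes involved (all sizes are made up for illustration):

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 256, 8      # each head works on 256 // 8 = 32 features
>>> batch, tgt_len, src_len = 4, 10, 12
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
>>> query = torch.randn(batch, tgt_len, embed_dim)   # (N, L, E)
>>> key = torch.randn(batch, src_len, embed_dim)     # (N, S, E)
>>> value = torch.randn(batch, src_len, embed_dim)   # (N, S, E)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> attn_output.shape            # (N, L, E)
torch.Size([4, 10, 256])
>>> attn_output_weights.shape    # (N, L, S), averaged across heads by default
torch.Size([4, 10, 12])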
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)¶
query (Tensor) – Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.
key (Tensor) – Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.
value (Tensor) – Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.
key_padding_mask (Optional[Tensor]) – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value.
need_weights (bool) – If specified, returns attn_output_weights in addition to attn_output. Default: True.
attn_mask (Optional[Tensor]) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcast across the batch while a 3D mask allows a different mask for each entry in the batch. Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. If both attn_mask and key_padding_mask are supplied, their types should match. A sketch combining both masks is shown after the outputs below.
is_causal (bool) – If specified, applies a causal mask as attention mask. Default: False. Warning: is_causal provides a hint that attn_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
average_attn_weights (bool) – If true, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).
- Return type: Tuple[Tensor, Optional[Tensor]]
attn_output - Attention outputs of shape (L, E) when input is unbatched, (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, where L is the target sequence length, N is the batch size, and E is the embedding dimension embed_dim.
attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape (L, S) when input is unbatched or (N, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape (num_heads, L, S) when input is unbatched or (N, num_heads, L, S).
The batch_first argument is ignored for unbatched inputs.
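A sketch of a forward() call that combines a boolean causal attn_mask with a key_padding_mask (all sizes are made up for illustration):

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 32, 4
>>> N, L, S = 2, 6, 6
>>> mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
>>> query, key, value = torch.randn(N, L, embed_dim), torch.randn(N, S, embed_dim), torch.randn(N, S, embed_dim)
>>> # Boolean causal mask of shape (L, S): True means the position may not be attended to.
>>> attn_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)
>>> # Boolean padding mask of shape (N, S): mark the last two source positions of the
>>> # second sequence as padding. Both masks are boolean, so their types match.
>>> key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
>>> key_padding_mask[1, -2:] = True
>>> attn_output, attn_weights = mha(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
>>> attn_output.shape     # (N, L, E)
torch.Size([2, 6, 32])
>>> attn_weights.shape    # (N, L, S), averaged across heads
torch.Size([2, 6, 6])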
- merge_masks(attn_mask, key_padding_mask, query)¶
Determine mask type and combine masks if necessary. If only one mask is provided, that mask and the corresponding mask type will be returned. If both masks are provided, they will both be expanded to shape (batch_size, num_heads, seq_len, seq_len), combined with logical or, and mask type 2 will be returned.
:param attn_mask: attention mask of shape (seq_len, seq_len), mask type 0
:param key_padding_mask: padding mask of shape (batch_size, seq_len), mask type 1
:param query: query embeddings of shape (batch_size, seq_len, embed_dim)
Returns: merged_mask: merged mask; mask_type: merged mask type (0, 1, or 2)
- Return type: Tuple[Optional[Tensor], Optional[int]]
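A small sketch of calling merge_masks() directly (normally forward() does this internally; the sizes are arbitrary and the merged mask's dtype follows the input masks):

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 16, 2
>>> batch_size, seq_len = 3, 5
>>> mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
>>> query = torch.randn(batch_size, seq_len, embed_dim)
>>> # Causal attention mask (mask type 0) and padding mask (mask type 1).
>>> attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
>>> key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
>>> key_padding_mask[:, -1] = True   # last position of every sequence is padding
>>> merged_mask, mask_type = mha.merge_masks(attn_mask, key_padding_mask, query)
>>> mask_type                        # both masks supplied, so the merged type is 2
2
>>> merged_mask.shape                # (batch_size, num_heads, seq_len, seq_len)
torch.Size([3, 2, 5, 5])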