
MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

Allows the model to jointly attend to information from different representation subspaces, as described in “Attention Is All You Need”.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.
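
To make the formula concrete, the sketch below recomputes the forward pass by hand and compares it with the module's own output. It assumes kdim and vdim are left at their defaults, and it reads the attributes in_proj_weight, in_proj_bias and out_proj, which are implementation details not described on this page; all sizes are arbitrary illustrative values.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
E, h = 8, 2        # embed_dim and num_heads (illustrative values)
L, S, N = 3, 4, 2  # target length, source length, batch size
d = E // h         # per-head dimension

mha = nn.MultiheadAttention(E, h, batch_first=True)
query = torch.randn(N, L, E)
key = torch.randn(N, S, E)
value = torch.randn(N, S, E)


def split_heads(x):
    # (N, seq, E) -> (N, h, seq, d): each head gets its own d-dim slice of the features
    return x.view(x.size(0), x.size(1), h, d).transpose(1, 2)


def manual_multihead(q, k, v):
    # With kdim == vdim == embed_dim, in_proj_weight stacks the Q, K and V
    # projections row-wise (shape (3E, E)); assumption based on the current implementation.
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    Q = split_heads(F.linear(q, W_q, b_q))
    K = split_heads(F.linear(k, W_k, b_k))
    V = split_heads(F.linear(v, W_v, b_v))
    # Scaled dot-product attention, computed for all heads at once
    weights = (Q @ K.transpose(-2, -1) / math.sqrt(d)).softmax(dim=-1)  # (N, h, L, S)
    heads = weights @ V                                                  # (N, h, L, d)
    # Concat(head_1, ..., head_h), then the output projection W^O
    concat = heads.transpose(1, 2).reshape(q.size(0), q.size(1), E)
    # The module returns weights averaged over heads
    return mha.out_proj(concat), weights.mean(dim=1)


with torch.no_grad():
    out, w = mha(query, key, value)
    ref_out, ref_w = manual_multihead(query, key, value)

print(torch.allclose(out, ref_out, atol=1e-5), torch.allclose(w, ref_w, atol=1e-5))
```

Because dropout defaults to 0.0, the module being in training mode does not affect this comparison.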

Parameters
  • embed_dim – Total dimension of the model.

  • num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads), so embed_dim must be divisible by num_heads.

  • dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

  • bias – If True, adds bias to the input and output projection layers. Default: True.

  • add_bias_kv – If True, adds a learnable bias to the key and value sequences at dim=0. Default: False.

  • add_zero_attn – If True, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

  • kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

  • vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
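
The following sketch, using arbitrary illustrative sizes, shows how several of these constructor arguments affect the shapes the module accepts and returns. The exception type raised for an indivisible embed_dim is not documented on this page, so the check catches both AssertionError and ValueError as an assumption.

```python
import torch
from torch import nn

# Each head gets embed_dim // num_heads features, so 10 // 3 cannot cover all 10 dims.
try:
    nn.MultiheadAttention(embed_dim=10, num_heads=3)
    divisible_check = "accepted"
except (AssertionError, ValueError):  # assumption: either type may be raised
    divisible_check = "rejected"

# kdim / vdim let keys and values have feature sizes different from embed_dim.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, kdim=10, vdim=6, batch_first=True)
q = torch.randn(2, 5, 16)  # (N, L, embed_dim)
k = torch.randn(2, 7, 10)  # (N, S, kdim)
v = torch.randn(2, 7, 6)   # (N, S, vdim)
out, w = mha(q, k, v)
print(divisible_check, tuple(out.shape), tuple(w.shape))

# add_bias_kv and add_zero_attn each append one extra position to the key/value
# sequence, so with both enabled the weights cover S + 2 source positions.
extra = nn.MultiheadAttention(16, 4, add_bias_kv=True, add_zero_attn=True, batch_first=True)
x = torch.randn(2, 7, 16)
_, w_extra = extra(x, x, x)
print(tuple(w_extra.shape))
```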

Examples:

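>>> import torch
>>> multihead_attn = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4)
>>> query = key = value = torch.randn(5, 2, 16)  # (seq, batch, embed_dim)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)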
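
As a sketch of what batch_first changes: two modules sharing the same weights (copied here with load_state_dict) should give the same result when the input is simply transposed between the (seq, batch, feature) and (batch, seq, feature) layouts. The sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_first = nn.MultiheadAttention(embed_dim=16, num_heads=4)  # expects (L, N, E)
batch_first = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)  # expects (N, L, E)
batch_first.load_state_dict(seq_first.state_dict())  # identical weights in both modules

x = torch.randn(5, 2, 16)   # (seq, batch, feature)
xb = x.transpose(0, 1)      # (batch, seq, feature)
out_sf, _ = seq_first(x, x, x)
out_bf, _ = batch_first(xb, xb, xb)
# Transposing the batch-first output back should recover the seq-first output
print(torch.allclose(out_sf, out_bf.transpose(0, 1), atol=1e-5))
```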

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
Parameters
  • query – Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.

  • key – Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

  • value – Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

  • key_padding_mask – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.

  • need_weights – If True, returns attn_output_weights in addition to attn_output. Default: True.

  • attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N·num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask is broadcast across the batch, while a 3D mask allows a different mask for each entry in the batch and each head. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values are added to the attention scores before the softmax.
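
A minimal sketch of both mask arguments on a self-attention call, with illustrative sizes: a boolean causal attn_mask built with torch.triu, and a boolean key_padding_mask marking the last two keys of the second batch entry as padding. The returned weights should be zero wherever attention is blocked, and a float mask holding -inf at the same positions should give the same output. The masks are chosen so every query keeps at least one visible key; a fully masked row would make the softmax undefined (typically NaN).

```python
import torch
from torch import nn

torch.manual_seed(0)
N, L, E = 2, 4, 8  # batch size, sequence length (L == S here), embed_dim
mha = nn.MultiheadAttention(E, num_heads=2, batch_first=True)
x = torch.randn(N, L, E)

# Boolean causal mask of shape (L, S): True above the diagonal means "may not attend"
causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
# Boolean key_padding_mask of shape (N, S): batch entry 1 has only 2 real tokens
padding = torch.tensor([[False, False, False, False],
                        [False, False, True, True]])

out, w = mha(x, x, x, attn_mask=causal, key_padding_mask=padding)
print(w[0])  # lower-triangular: position i only attends to positions <= i
print(w[1])  # additionally zero in the two padded key columns

# A float mask is added to the scores; -inf acts like True in a boolean mask.
float_causal = torch.zeros(L, L).masked_fill(causal, float("-inf"))
out_bool, _ = mha(x, x, x, attn_mask=causal)
out_float, _ = mha(x, x, x, attn_mask=float_causal)
print(torch.allclose(out_bool, out_float, atol=1e-6))
```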

Outputs:
  • attn_output - Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, where L is the target sequence length, N is the batch size, and E is the embedding dimension embed_dim.

  • attn_output_weights - Attention output weights of shape (N, L, S), averaged across heads, where N is the batch size, L is the target sequence length, and S is the source sequence length. Only returned when need_weights=True; otherwise None is returned in its place.
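
A short sketch checking these output shapes for the default batch_first=False layout, with illustrative sizes L=3, S=5, N=2 and E=16. Each row of the returned weights is a softmax distribution over the S source positions (averaged over heads), so it sums to 1, and need_weights=False still yields a 2-tuple whose second element is None.

```python
import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)  # batch_first=False by default
query = torch.randn(3, 2, 16)        # (L, N, E)
key = value = torch.randn(5, 2, 16)  # (S, N, E)

attn_output, attn_output_weights = mha(query, key, value)
print(attn_output.shape, attn_output_weights.shape)  # (L, N, E) and (N, L, S)
print(attn_output_weights.sum(dim=-1))  # each row sums to 1 over the S keys

# With need_weights=False the second element of the returned tuple is None.
attn_output, no_weights = mha(query, key, value, need_weights=False)
print(no_weights)
```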
