
torchtext.nn.modules.multiheadattention

MultiheadAttentionContainer

class torchtext.nn.modules.multiheadattention.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[source]
__init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[source]

A multi-head attention container

Parameters
  • nhead – the number of heads in the multiheadattention model

  • in_proj_container – A container of multi-head in-projection linear layers (such as torch.nn.Linear).

  • attention_layer – The custom attention layer. The input sent from the MHA container to the attention layer has shape (…, L, N * H, E / H) for query and (…, S, N * H, E / H) for key/value, while the output of the attention layer is expected to have shape (…, L, N * H, E / H). The attention_layer needs to support broadcasting if users want the overall MultiheadAttentionContainer to broadcast (a sketch of a compatible custom layer follows this parameter list).

  • out_proj – The multi-head out-projection layer (such as torch.nn.Linear).

  • batch_first – If True, then the input and output tensors are provided as (…, N, L, E); see the batch-first sketch after the example below. Default: False
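
The shape contract above can be satisfied by any module whose forward accepts (query, key, value, attn_mask=None, bias_k=None, bias_v=None) and returns (attn_output, attn_output_weights). As a hypothetical illustration (not part of torchtext), such a custom layer might look as follows; it could be passed as attention_layer in place of ScaledDotProduct() in the example below. For brevity it only handles strictly three-dimensional inputs and ignores attn_mask, bias_k and bias_v:
>>> import torch
>>> class UnscaledDotProduct(torch.nn.Module):
        # Hypothetical example layer (not part of torchtext): dot-product attention
        # without the 1/sqrt(d) scaling. query: (L, N * H, E / H),
        # key/value: (S, N * H, E / H); returns output (L, N * H, E / H)
        # and weights (N * H, L, S), as the container expects.
        def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None):
            weights = torch.softmax(torch.einsum('lbe,sbe->bls', query, key), dim=-1)
            output = torch.einsum('bls,sbe->lbe', weights, value)
            return output, weights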

Examples::
>>> import torch
>>> from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim))
>>> query = torch.rand((21, bsz, embed_dim))
>>> key = value = torch.rand((16, bsz, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
torch.Size([21, 64, 10])
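
A minimal sketch (reusing the imports and variables from the example above) of the same container built with batch_first=True, so that query/key/value are passed as (N, L, E) rather than (L, N, E):
>>> MHA_bf = MultiheadAttentionContainer(num_heads,
                                         in_proj_container,
                                         ScaledDotProduct(),
                                         torch.nn.Linear(embed_dim, embed_dim),
                                         batch_first=True)
>>> q_b = torch.rand((bsz, 21, embed_dim))        # (N, L, E)
>>> k_b = v_b = torch.rand((bsz, 16, embed_dim))  # (N, S, E)
>>> attn_output, attn_weights = MHA_bf(q_b, k_b, v_b)
>>> print(attn_output.shape)
torch.Size([64, 21, 10])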
forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, bias_v: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]
Parameters
  • query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.

  • attn_mask, bias_k and bias_v – keyword arguments passed to the attention layer. See their definitions in the attention layer (e.g. ScaledDotProduct); a usage sketch for attn_mask follows the shape notes below.

Shape:
  • Inputs:

  • query: \((..., L, N, E)\)

  • key: \((..., S, N, E)\)

  • value: \((..., S, N, E)\)

  • attn_mask, bias_k and bias_v: same as the shapes of the corresponding arguments in the attention layer.

  • Outputs:

  • attn_output: \((..., L, N, E)\)

  • attn_output_weights: \((N * H, L, S)\)

Note: It is optional for the query/key/value inputs to have more than three dimensions (for broadcasting purposes). The MultiheadAttentionContainer module will operate on the last three dimensions,

where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension.
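
The attn_mask keyword is passed through to the attention layer, so with ScaledDotProduct it is expected to be a bool tensor of shape \((N * H, L, S)\), where True marks positions that are not allowed to attend. A minimal sketch reusing MHA, query, key and value from the example above (the mask pattern itself is illustrative):
>>> tgt_len, src_len = query.size(0), key.size(0)   # 21, 16
>>> attn_mask = torch.zeros((num_heads * bsz, tgt_len, src_len), dtype=torch.bool)
>>> attn_mask[:, :, -1] = True                       # e.g. block attention to the last key position
>>> attn_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask)
>>> print(attn_weights.shape)
torch.Size([320, 21, 16])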

InProjContainer

class torchtext.nn.modules.multiheadattention.InProjContainer(query_proj, key_proj, value_proj)[source]
__init__(query_proj, key_proj, value_proj)[source]

An in-projection container that projects query/key/value in MultiheadAttention. This module is applied before the projected query/key/value are reshaped into multiple heads. See the linear layers (bottom) of Multi-head Attention in Fig 2 of the Attention Is All You Need paper. Also check the usage example in torchtext.nn.MultiheadAttentionContainer.

Parameters
  • query_proj – a proj layer for query. A typical projection layer is torch.nn.Linear.

  • key_proj – a proj layer for key. A typical projection layer is torch.nn.Linear.

  • value_proj – a proj layer for value. A typical projection layer is torch.nn.Linear.

forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Projects the input sequences using the in-projection layers. query/key/value are simply passed to the forward methods of query_proj/key_proj/value_proj, respectively.

Parameters

query, key, value – sequences to be projected

Examples::
>>> import torch
>>> from torchtext.nn import InProjContainer
>>> embed_dim, bsz = 10, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> q = torch.rand((5, bsz, embed_dim))
>>> k = v = torch.rand((6, bsz, embed_dim))
>>> q, k, v = in_proj_container(q, k, v)

ScaledDotProduct

class torchtext.nn.modules.multiheadattention.ScaledDotProduct(dropout=0.0, batch_first=False)[source]
__init__(dropout=0.0, batch_first=False)[source]

Processes a projected query and key-value pair to apply scaled dot product attention.

Parameters
  • dropout (float) – probability of dropping an attention weight.

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False

Examples::
>>> import torch
>>> from torchtext.nn import ScaledDotProduct
>>> SDP = ScaledDotProduct(dropout=0.1)
>>> q = torch.randn(21, 256, 3)
>>> k = v = torch.randn(21, 256, 3)
>>> attn_output, attn_weights = SDP(q, k, v)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, bias_v: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]

Uses a scaled dot product with the projected key-value pair to update the projected query.

Parameters
  • query (Tensor) – Projected query

  • key (Tensor) – Projected key

  • value (Tensor) – Projected value

  • attn_mask (BoolTensor, optional) – 3D mask that prevents attention to certain positions (see the sketch below).

  • bias_k and bias_v (Tensor, optional) – one more key and value sequence to be added at the sequence dim (dim=-3). These are used for incremental decoding. Users should provide non-None values for both arguments in order to activate them.

Shape:
  • query: \((..., L, N * H, E / H)\)

  • key: \((..., S, N * H, E / H)\)

  • value: \((..., S, N * H, E / H)\)

  • attn_mask: \((N * H, L, S)\), positions with True are not allowed to attend while False values will be unchanged.

  • bias_k and bias_v: \((1, N * H, E / H)\)

  • Output: \((..., L, N * H, E / H)\), \((N * H, L, S)\)

Note: It is optional for the query/key/value inputs to have more than three dimensions (for broadcasting purposes). The ScaledDotProduct module will operate on the last three dimensions.

where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension.
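
A minimal, self-contained sketch (variable names are illustrative) of feeding ScaledDotProduct inputs that are already split into heads, i.e. reshaped from \((L, N, E)\) to \((L, N * H, E / H)\), together with a boolean attn_mask:
>>> import torch
>>> from torchtext.nn import ScaledDotProduct
>>> embed_dim, num_heads, bsz, tgt_len, src_len = 10, 5, 64, 21, 16
>>> head_dim = embed_dim // num_heads
>>> SDP = ScaledDotProduct(dropout=0.1)
>>> # split (L, N, E) into (L, N * H, E / H); likewise for key/value
>>> q = torch.randn(tgt_len, bsz, embed_dim).reshape(tgt_len, bsz * num_heads, head_dim)
>>> k = v = torch.randn(src_len, bsz, embed_dim).reshape(src_len, bsz * num_heads, head_dim)
>>> attn_mask = torch.zeros(bsz * num_heads, tgt_len, src_len, dtype=torch.bool)
>>> attn_mask[:, :, -1] = True   # e.g. disallow attention to the last source position
>>> attn_output, attn_weights = SDP(q, k, v, attn_mask=attn_mask)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([21, 320, 2]) torch.Size([320, 21, 16])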
