torchtext.nn.modules.multiheadattention
MultiheadAttentionContainer
class torchtext.nn.modules.multiheadattention.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)
__init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)
A multi-head attention container.
- Parameters
nhead – the number of heads in the multi-head attention model.
in_proj_container – a container of multi-head in-projection linear layers (a.k.a. nn.Linear).
attention_layer – the custom attention layer. The MHA container passes the query to the attention layer with shape (…, L, N * H, E / H) and the key/value with shape (…, S, N * H, E / H); the attention layer's output is expected to have shape (…, L, N * H, E / H). The attention layer must support broadcasting if the overall MultiheadAttentionContainer is to support broadcasting.
out_proj – the multi-head out-projection layer (a.k.a. nn.Linear).
batch_first – if True, the input and output tensors are provided as (…, N, L, E). Default: False.
- Examples::
>>> import torch
>>> from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
...                                     torch.nn.Linear(embed_dim, embed_dim),
...                                     torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
...                                   in_proj_container,
...                                   ScaledDotProduct(),
...                                   torch.nn.Linear(embed_dim, embed_dim))
>>> query = torch.rand((21, bsz, embed_dim))
>>> key = value = torch.rand((16, bsz, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
torch.Size([21, 64, 10])
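The attention_layer contract above depends on how heads are folded into the batch axis. The following pure-Python sketch (no torch; `split_heads` and `merge_heads` are hypothetical helpers, not torchtext functions) shows, for a single sequence position, the row-major reshape from (N, E) to (N * H, E / H) that we assume the container performs: row n * H + h holds the h-th slice of E / H features of batch element n.

```python
from typing import List

Matrix = List[List[float]]


def split_heads(x: Matrix, nhead: int) -> Matrix:
    """Hypothetical helper: reshape one position's (N, E) slice into (N * H, E / H).

    Row n * nhead + h of the result holds features [h * d, (h + 1) * d) of
    batch element n, where d = E // nhead (what a row-major reshape produces).
    """
    embed_dim = len(x[0])
    if embed_dim % nhead != 0:
        raise ValueError("embed_dim must be divisible by nhead")
    head_dim = embed_dim // nhead
    # Outer loop over batch rows, inner loop over heads -> batch-major ordering.
    return [row[h * head_dim:(h + 1) * head_dim] for row in x for h in range(nhead)]


def merge_heads(x: Matrix, nhead: int) -> Matrix:
    """Hypothetical inverse of split_heads: join each group of nhead rows back into one row."""
    if len(x) % nhead != 0:
        raise ValueError("number of rows must be divisible by nhead")
    return [
        [value for chunk in x[n:n + nhead] for value in chunk]
        for n in range(0, len(x), nhead)
    ]


# Two batch elements (N = 2), embedding dimension E = 4, two heads (H = 2).
x = [[0, 1, 2, 3],
     [10, 11, 12, 13]]
heads = split_heads(x, nhead=2)
print(heads)  # [[0, 1], [2, 3], [10, 11], [12, 13]]
assert merge_heads(heads, nhead=2) == x
```

Keeping all heads of one batch element in adjacent rows is why the attention layer can treat N * H as a single batch dimension and the container can undo the split with a plain reshape.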
forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, bias_v: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor]
- Parameters
query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.
attn_mask, bias_k and bias_v – keyword arguments passed to the attention layer. See the definitions in the attention layer.
- Shape:
Inputs:
query: \((..., L, N, E)\)
key: \((..., S, N, E)\)
value: \((..., S, N, E)\)
attn_mask, bias_k and bias_v: same shapes as the corresponding arguments of the attention layer.
Outputs:
attn_output: \((..., L, N, E)\)
attn_output_weights: \((N * H, L, S)\)
- Note: It’s optional to have query/key/value inputs with more than three dimensions (for broadcasting).
The MultiheadAttentionContainer module operates on the last three dimensions.
- where L is the target length, S is the source length, H is the number of attention heads,
N is the batch size, and E is the embedding dimension.
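The shape contract above can be checked with a small helper. This is a sketch under stated assumptions: `mha_output_shapes` is hypothetical, handles only 3-D inputs, assumes the in- and out-projections keep E unchanged (as in the constructor example), and assumes batch_first only changes the layout of attn_output, not of the weights.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def mha_output_shapes(query: Shape, key: Shape, value: Shape, nhead: int,
                      batch_first: bool = False) -> Tuple[Shape, Shape]:
    """Hypothetical helper: return (attn_output shape, attn_output_weights shape)."""
    if not (len(query) == len(key) == len(value) == 3):
        raise ValueError("this sketch only handles 3-D inputs")
    if batch_first:
        # (N, L, E) -> (L, N, E) so the rest of the logic sees the default layout.
        query, key, value = [(s[1], s[0], s[2]) for s in (query, key, value)]
    tgt_len, bsz, embed_dim = query
    src_len, key_bsz, _ = key
    if key[:2] != value[:2]:
        raise ValueError("key and value must share (S, N)")
    if key_bsz != bsz:
        raise ValueError("query and key must share the batch size N")
    # Assumption: the in-projections keep E, so E itself must split evenly into heads.
    if embed_dim % nhead != 0:
        raise ValueError("E must be divisible by the number of heads")
    out = (bsz, tgt_len, embed_dim) if batch_first else (tgt_len, bsz, embed_dim)
    weights = (bsz * nhead, tgt_len, src_len)  # (N * H, L, S)
    return out, weights


print(mha_output_shapes((21, 64, 10), (16, 64, 10), (16, 64, 10), nhead=5))
# ((21, 64, 10), (320, 21, 16))
```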
InProjContainer
class torchtext.nn.modules.multiheadattention.InProjContainer(query_proj, key_proj, value_proj)
__init__(query_proj, key_proj, value_proj)
An in-projection container that projects the query/key/value in MultiheadAttention. This module is applied before the projected query/key/value are reshaped into multiple heads. See the linear layers (bottom) of Multi-Head Attention in Figure 2 of the Attention Is All You Need paper. Also see the usage example in torchtext.nn.MultiheadAttentionContainer.
- Parameters
query_proj – a projection layer for the query. A typical projection layer is torch.nn.Linear.
key_proj – a projection layer for the key. A typical projection layer is torch.nn.Linear.
value_proj – a projection layer for the value. A typical projection layer is torch.nn.Linear.
forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Projects the input sequences using the in-projection layers. query/key/value are passed to the forward functions of query_proj/key_proj/value_proj, respectively.
- Parameters
query, key, value – sequences to be projected
- Examples::
>>> import torch
>>> from torchtext.nn import InProjContainer
>>> embed_dim, bsz = 10, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
...                                     torch.nn.Linear(embed_dim, embed_dim),
...                                     torch.nn.Linear(embed_dim, embed_dim))
>>> q = torch.rand((5, bsz, embed_dim))
>>> k = v = torch.rand((6, bsz, embed_dim))
>>> q, k, v = in_proj_container(q, k, v)
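Because InProjContainer only forwards each input to its own projection, its routing can be sketched without torch. `InProjSketch` is a hypothetical stand-in, not the torchtext class, and the plain functions below stand in for torch.nn.Linear layers.

```python
from typing import Any, Callable, List, Tuple

Projection = Callable[[Any], Any]


class InProjSketch:
    """Hypothetical stand-in for InProjContainer: one projection per input."""

    def __init__(self, query_proj: Projection, key_proj: Projection,
                 value_proj: Projection) -> None:
        self.query_proj = query_proj
        self.key_proj = key_proj
        self.value_proj = value_proj

    def __call__(self, query: Any, key: Any, value: Any) -> Tuple[Any, Any, Any]:
        # Each input goes only through its own projection; nothing is shared or reshaped here.
        return self.query_proj(query), self.key_proj(key), self.value_proj(value)


# Toy "projections" on plain lists, standing in for torch.nn.Linear layers.
def double(xs: List[int]) -> List[int]:
    return [2 * x for x in xs]


def negate(xs: List[int]) -> List[int]:
    return [-x for x in xs]


def identity(xs: List[int]) -> List[int]:
    return list(xs)


proj = InProjSketch(double, negate, identity)
q, k, v = proj([1, 2], [3, 4], [5, 6])
print(q, k, v)  # [2, 4] [-3, -4] [5, 6]
```

Since the head split happens later in MultiheadAttentionContainer, any callable that maps (…, L, N, E) to a tensor whose last dimension divides evenly by the number of heads could, in principle, serve as a projection here.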
ScaledDotProduct
class torchtext.nn.modules.multiheadattention.ScaledDotProduct(dropout=0.0, batch_first=False)
__init__(dropout=0.0, batch_first=False)
Processes a projected query and key-value pair to apply scaled dot product attention.
- Parameters
dropout (float) – probability of dropping an attention weight.
batch_first – if True, the input and output tensors are provided as (batch, seq, feature). Default: False.
- Examples::
>>> import torch
>>> import torchtext
>>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
>>> q = torch.randn(21, 256, 3)
>>> k = v = torch.randn(21, 256, 3)
>>> attn_output, attn_weights = SDP(q, k, v)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
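For one slice along the N * H axis, scaled dot product attention reduces to weights = softmax(q · k / sqrt(d)) and output = weights · v, with d = E / H. The pure-Python sketch below (hypothetical `scaled_dot_product`, dropout omitted) illustrates that formula on plain lists; it is not the torchtext implementation, which operates on batched tensors.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product(query: Matrix, key: Matrix, value: Matrix) -> Tuple[Matrix, Matrix]:
    """Single-head sketch: query is L x d, key and value are S x d.

    Returns (output L x d, weights L x S). Dropout on the weights is omitted.
    """
    head_dim = len(query[0])
    scale = head_dim ** -0.5  # same as dividing every score by sqrt(d)
    # One softmax-normalized row of S weights per query position.
    weights = [
        softmax([scale * sum(qi * ki for qi, ki in zip(q, k)) for k in key])
        for q in query
    ]
    # Each output row is the weighted sum of the value rows.
    output = [
        [sum(w * v[j] for w, v in zip(row, value)) for j in range(len(value[0]))]
        for row in weights
    ]
    return output, weights


q = [[1.0, 0.0]]
k = [[0.0, 0.0], [0.0, 0.0]]  # identical keys -> equal attention weights
v = [[1.0, 2.0], [3.0, 4.0]]
out, w = scaled_dot_product(q, k, v)
print(w)    # [[0.5, 0.5]]
print(out)  # [[2.0, 3.0]]
```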
forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, bias_v: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor]
Uses a scaled dot product with the projected key-value pair to update the projected query.
- Parameters
query (Tensor) – Projected query
key (Tensor) – Projected key
value (Tensor) – Projected value
attn_mask (BoolTensor, optional) – 3D mask that prevents attention to certain positions.
bias_k and bias_v (Tensor, optional) – one additional key and value sequence appended along the sequence dimension (dim=-3). They are used for incremental decoding. Both arguments must be non-None for them to take effect.
- Shape:
query: \((..., L, N * H, E / H)\)
key: \((..., S, N * H, E / H)\)
value: \((..., S, N * H, E / H)\)
- attn_mask: \((N * H, L, S)\); positions with True are not allowed to attend, while positions with False are unchanged.
- bias_k and bias_v: \((1, N * H, E / H)\)
Output: \((..., L, N * H, E / H)\), \((N * H, L, S)\)
- Note: It’s optional to have query/key/value inputs with more than three dimensions (for broadcasting).
The ScaledDotProduct module operates on the last three dimensions.
where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension.
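The attn_mask and bias_k/bias_v semantics above can be illustrated on a single head. In this hypothetical sketch (`masked_attention` is not a torchtext function), masked scores are replaced with a large negative number before the softmax so that they receive numerically zero weight, and the bias row is appended to key and value only when both are given. We also assume the mask gains one extra, unmasked column for the appended bias position; that detail is not stated in this documentation.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]
Mask = List[List[bool]]

NEG_LARGE = -1e9  # large negative score; math.exp() of it underflows to 0.0


def masked_attention(query: Matrix, key: Matrix, value: Matrix,
                     attn_mask: Optional[Mask] = None,
                     bias_k: Optional[List[float]] = None,
                     bias_v: Optional[List[float]] = None) -> Tuple[Matrix, Matrix]:
    """Single-head sketch of attn_mask and bias_k/bias_v handling.

    attn_mask is L x S; True marks a position that may NOT be attended to.
    """
    if bias_k is not None and bias_v is not None:
        # Append one extra key/value row along the sequence axis (S -> S + 1).
        key = key + [bias_k]
        value = value + [bias_v]
        if attn_mask is not None:
            # Assumption: the new bias column is left unmasked (False).
            attn_mask = [row + [False] for row in attn_mask]
    scale = len(query[0]) ** -0.5
    weights = []
    for i, q in enumerate(query):
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in key]
        if attn_mask is not None:
            # Blocked positions get a score so low that softmax gives them ~0 weight.
            scores = [NEG_LARGE if blocked else s for s, blocked in zip(scores, attn_mask[i])]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    output = [
        [sum(w * v[j] for w, v in zip(row, value)) for j in range(len(value[0]))]
        for row in weights
    ]
    return output, weights


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
out, w = masked_attention(q, k, v, attn_mask=[[True, False]])
print(w)  # [[0.0, 1.0]]
```

Each query row must keep at least one unmasked position; otherwise the softmax spreads its weight over the blocked positions instead of producing a meaningful distribution.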