# Source code for torch.nn.modules.transformer
# mypy: allow-untyped-defs
import copy
from typing import Optional, Any, Union, Callable
import torch
import warnings
from torch import Tensor
from .. import functional as F
from .module import Module
from .activation import MultiheadAttention
from .container import ModuleList
from ..init import xavier_uniform_
from .dropout import Dropout
from .linear import Linear
from .normalization import LayerNorm
__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
def _generate_square_subsequent_mask(
sz: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
r"""Generate a square causal mask for the sequence.
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
"""
if device is None:
device = torch.device('cpu')
if dtype is None:
dtype = torch.float32
return torch.triu(
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
diagonal=1,
)
def _get_seq_len(
src: Tensor,
batch_first: bool
) -> Optional[int]:
if src.is_nested:
return None
else:
src_size = src.size()
if len(src_size) == 2:
# unbatched: S, E
return src_size[0]
else:
# batched: B, S, E if batch_first else S, B, E
seq_len_pos = 1 if batch_first else 0
return src_size[seq_len_pos]
[docs]class Transformer(Module):
r"""A transformer model.
User is able to modify the attributes as needed. The architecture
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
Processing Systems, pages 6000-6010.
Args:
d_model: the number of expected features in the encoder/decoder inputs (default=512).
nhead: the number of heads in the multiheadattention models (default=8).
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of encoder/decoder intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
other attention and feedforward operations, otherwise after. Default: ``False`` (after).
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
bias. Default: ``True``.
Examples::
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = torch.rand((10, 32, 512))
>>> tgt = torch.rand((20, 32, 512))
>>> out = transformer_model(src, tgt)
Note: A full example to apply nn.Transformer module for the word language model is available in
https://github.com/pytorch/examples/tree/master/word_language_model
"""
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
if custom_encoder is not None:
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first,
bias, **factory_kwargs)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
if custom_decoder is not None:
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first,
bias, **factory_kwargs)
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
self.batch_first = batch_first
[docs] def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
memory_is_causal: bool = False) -> Tensor:
r"""Take in and process masked source/target sequences.
.. note::
If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
not allowed to participate in the attention,
which is the opposite of the definition for :attr:`attn_mask`
in :func:`torch.nn.functional.scaled_dot_product_attention`.
Args:
src: the sequence to the encoder (required).
tgt: the sequence to the decoder (required).
src_mask: the additive mask for the src sequence (optional).
tgt_mask: the additive mask for the tgt sequence (optional).
memory_mask: the additive mask for the encoder output (optional).
src_key_padding_mask: the Tensor mask for src keys per batch (optional).
tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
src_is_causal: If specified, applies a causal mask as ``src_mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``src_is_causal`` provides a hint that ``src_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
memory_is_causal: If specified, applies a causal mask as
``memory_mask``.
Default: ``False``.
Warning:
``memory_is_causal`` provides a hint that
``memory_mask`` is the causal mask. Providing incorrect
hints can result in incorrect execution, including
forward and backward compatibility.
Shape:
- src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
`(N, S, E)` if `batch_first=True`.
- tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
`(N, T, E)` if `batch_first=True`.
- src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
- tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
- memory_mask: :math:`(T, S)`.
- src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
- tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
- memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
positions. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
[src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
the attention. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
`(N, T, E)` if `batch_first=True`.
Note: Due to the multi-head attention architecture in the transformer model,
the output sequence length of a transformer is same as the input sequence
(i.e. target) length of the decoder.
where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
batch size, :math:`E` is the feature number
Examples:
>>> # xdoctest: +SKIP
>>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
"""
is_batched = src.dim() == 3
if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
is_causal=src_is_causal)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
return output
[docs] @staticmethod
def generate_square_subsequent_mask(
sz: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
r"""Generate a square causal mask for the sequence.
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
"""
return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
def _reset_parameters(self):
r"""Initiate parameters in the transformer model."""
for p in self.parameters():
if p.dim() > 1:
xavier_uniform_(p)
[docs]class TransformerEncoder(Module):
r"""TransformerEncoder is a stack of N encoder layers.
Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ['norm']
def __init__(
self,
encoder_layer: "TransformerEncoderLayer",
num_layers: int,
norm: Optional[Module] = None,
enable_nested_tensor: bool = True,
mask_check: bool = True
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
# this attribute saves the value providedat object construction
self.enable_nested_tensor = enable_nested_tensor
# this attribute controls whether nested tensors are used
self.use_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
enc_layer = "encoder_layer"
why_not_sparsity_fast_path = ''
if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
elif encoder_layer.norm_first :
why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
elif not encoder_layer.self_attn.batch_first:
why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
"(use batch_first for better inference performance)")
elif not encoder_layer.self_attn._qkv_same_embed_dim:
why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
elif encoder_layer.self_attn.in_proj_bias is None:
why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
elif not encoder_layer.activation_relu_or_gelu:
why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
elif encoder_layer.self_attn.num_heads % 2 == 1:
why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
if enable_nested_tensor and why_not_sparsity_fast_path:
warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
self.use_nested_tensor = False
[docs] def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
is_causal: Optional[bool] = None) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
is_causal: If specified, applies a causal mask as ``mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``is_causal`` provides a hint that ``mask`` is the
causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
Shape:
see the docs in :class:`~torch.nn.Transformer`.
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
other_type=F._none_or_dtype(mask),
other_name="mask",
target_type=src.dtype
)
mask = F._canonical_mask(
mask=mask,
mask_name="mask",
other_type=None,
other_name="",
target_type=src.dtype,
check_other=False,
)
output = src
convert_to_nested = False
first_layer = self.layers[0]
src_key_padding_mask_for_layers = src_key_padding_mask
why_not_sparsity_fast_path = ''
str_first_layer = "self.layers[0]"
batch_first = first_layer.self_attn.batch_first
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
if not is_fastpath_enabled:
why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
elif not hasattr(self, "use_nested_tensor"):
why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
elif not self.use_nested_tensor:
why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
elif first_layer.training:
why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
elif not src.dim() == 3:
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
elif src_key_padding_mask is None:
why_not_sparsity_fast_path = "src_key_padding_mask was None"
elif (((not hasattr(self, "mask_check")) or self.mask_check)
and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
elif output.is_nested:
why_not_sparsity_fast_path = "NestedTensor input is not supported"
elif mask is not None:
why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
elif torch.is_autocast_enabled():
why_not_sparsity_fast_path = "autocast is enabled"
if not why_not_sparsity_fast_path:
tensor_args = (
src,
first_layer.self_attn.in_proj_weight,
first_layer.self_attn.in_proj_bias,
first_layer.self_attn.out_proj.weight,
first_layer.self_attn.out_proj.bias,
first_layer.norm1.weight,
first_layer.norm1.bias,
first_layer.norm2.weight,
first_layer.norm2.bias,
first_layer.linear1.weight,
first_layer.linear1.bias,
first_layer.linear2.weight,
first_layer.linear2.bias,
)
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
if torch.overrides.has_torch_function(tensor_args):
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
elif src.device.type not in _supported_device_type:
why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
convert_to_nested = True
output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
src_key_padding_mask_for_layers = None
seq_len = _get_seq_len(src, batch_first)
is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
for mod in self.layers:
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
if convert_to_nested:
output = output.to_padded_tensor(0., src.size())
if self.norm is not None:
output = self.norm(output)
return output
[docs]class TransformerDecoder(Module):
r"""TransformerDecoder is a stack of N decoder layers.
Args:
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
"""
__constants__ = ['norm']
def __init__(
self,
decoder_layer: "TransformerDecoderLayer",
num_layers: int,
norm: Optional[Module] = None
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
[docs] def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
memory_is_causal: bool = False) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
memory_is_causal: If specified, applies a causal mask as
``memory mask``.
Default: ``False``.
Warning:
``memory_is_causal`` provides a hint that
``memory_mask`` is the causal mask. Providing incorrect
hints can result in incorrect execution, including
forward and backward compatibility.
Shape:
see the docs in :class:`~torch.nn.Transformer`.
"""
output = tgt
seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
for mod in self.layers:
output = mod(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal,
memory_is_causal=memory_is_causal)
if self.norm is not None:
output = self.norm(output)
return output
[docs]class TransformerEncoderLayer(Module):
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This standard encoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
TransformerEncoderLayer can handle either traditional torch.tensor inputs,
or Nested Tensor inputs. Derived classes are expected to similarly accept
both input formats. (Not all combinations of inputs are currently
supported by TransformerEncoderLayer while Nested Tensor is in prototype
state.)
If you are implementing a custom layer, you may derive it either from
the Module or TransformerEncoderLayer class. If your custom layer
supports both torch.Tensors and Nested Tensors inputs, make its
implementation a derived class of TransformerEncoderLayer. If your custom
Layer supports only torch.Tensor inputs, derive its implementation from
Module.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first: if ``True``, layer norm is done prior to attention and feedforward
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
bias. Default: ``True``.
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
Alternatively, when ``batch_first`` is ``True``:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
Fast path:
forward() will use a special optimized implementation described in
`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
conditions are met:
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
argument ``requires_grad``
- training is disabled (using ``.eval()``)
- batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
- activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
- at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
- if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
nor ``src_key_padding_mask`` is passed
- the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
unless the caller has manually modified one without modifying the other)
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
passed for ``src`` to represent padding more efficiently than using a padding
mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
returned, and an additional speedup proportional to the fraction of the input that
is padding can be expected.
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
https://arxiv.org/abs/2205.14135
"""
__constants__ = ['norm_first']
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
bias=bias, batch_first=batch_first,
**factory_kwargs)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
# We can't test self.activation in forward() in TorchScript,
# so stash some information about it instead.
if activation is F.relu or isinstance(activation, torch.nn.ReLU):
self.activation_relu_or_gelu = 1
elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
self.activation_relu_or_gelu = 2
else:
self.activation_relu_or_gelu = 0
self.activation = activation
def __setstate__(self, state):
super().__setstate__(state)
if not hasattr(self, 'activation'):
self.activation = F.relu
[docs] def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
is_causal: bool = False) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
is_causal: If specified, applies a causal mask as ``src mask``.
Default: ``False``.
Warning:
``is_causal`` provides a hint that ``src_mask`` is the
causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
Shape:
see the docs in :class:`~torch.nn.Transformer`.
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
other_type=F._none_or_dtype(src_mask),
other_name="src_mask",
target_type=src.dtype
)
src_mask = F._canonical_mask(
mask=src_mask,
mask_name="src_mask",
other_type=None,
other_name="",
target_type=src.dtype,
check_other=False,
)
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
why_not_sparsity_fast_path = ''
if not is_fastpath_enabled:
why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
elif not src.dim() == 3:
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
elif self.training:
why_not_sparsity_fast_path = "training is enabled"
elif not self.self_attn.batch_first:
why_not_sparsity_fast_path = "self_attn.batch_first was not True"
elif self.self_attn.in_proj_bias is None:
why_not_sparsity_fast_path = "self_attn was passed bias=False"
elif not self.self_attn._qkv_same_embed_dim:
why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
elif not self.activation_relu_or_gelu:
why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
elif not (self.norm1.eps == self.norm2.eps):
why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
elif self.self_attn.num_heads % 2 == 1:
why_not_sparsity_fast_path = "num_head is odd"
elif torch.is_autocast_enabled():
why_not_sparsity_fast_path = "autocast is enabled"
if not why_not_sparsity_fast_path:
tensor_args = (
src,
self.self_attn.in_proj_weight,
self.self_attn.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.norm1.weight,
self.norm1.bias,
self.norm2.weight,
self.norm2.bias,
self.linear1.weight,
self.linear1.bias,
self.linear2.weight,
self.linear2.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
if torch.overrides.has_torch_function(tensor_args):
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
elif not all((x.device.type in _supported_device_type) for x in tensor_args):
why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
f"{_supported_device_type}")
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_sparsity_fast_path:
merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
return torch._transformer_encoder_layer_fwd(
src,
self.self_attn.embed_dim,
self.self_attn.num_heads,
self.self_attn.in_proj_weight,
self.self_attn.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.activation_relu_or_gelu == 2,
self.norm_first,
self.norm1.eps,
self.norm1.weight,
self.norm1.bias,
self.norm2.weight,
self.norm2.bias,
self.linear1.weight,
self.linear1.bias,
self.linear2.weight,
self.linear2.bias,
merged_mask,
mask_type,
)
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
x = self.norm2(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False, is_causal=is_causal)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
[docs]class TransformerDecoderLayer(Module):
r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
This standard decoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first: if ``True``, layer norm is done prior to self attention, multihead
attention and feedforward operations, respectively. Otherwise it's done after.
Default: ``False`` (after).
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
bias. Default: ``True``.
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = decoder_layer(tgt, memory)
Alternatively, when ``batch_first`` is ``True``:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
>>> memory = torch.rand(32, 10, 512)
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)
"""
__constants__ = ['norm_first']
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
bias=bias, **factory_kwargs)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
bias=bias, **factory_kwargs)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super().__setstate__(state)
[docs] def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
tgt_is_causal: bool = False,
memory_is_causal: bool = False,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
Default: ``False``.
Warning:
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
memory_is_causal: If specified, applies a causal mask as
``memory mask``.
Default: ``False``.
Warning:
``memory_is_causal`` provides a hint that
``memory_mask`` is the causal mask. Providing incorrect
hints can result in incorrect execution, including
forward and backward compatibility.
Shape:
see the docs in :class:`~torch.nn.Transformer`.
"""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = tgt
if self.norm_first:
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
x = self.norm3(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False)[0]
return self.dropout1(x)
# multihead attention block
def _mha_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
x = self.multihead_attn(x, mem, mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False)[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
def _get_clones(module, N):
    """Return a ``ModuleList`` holding ``N`` deep copies of ``module``.

    Every entry is an independent ``copy.deepcopy`` of ``module``, so the
    clones do not share parameter tensors with each other or with ``module``.
    """
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return ModuleList(clones)
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Translate an activation name into its functional implementation.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` for ``"relu"``, ``F.gelu`` for ``"gelu"``.

    Raises:
        RuntimeError: if ``activation`` is any other value.
    """
    for name, fn in (("relu", F.relu), ("gelu", F.gelu)):
        if activation == name:
            return fn
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
def _detect_is_causal_mask(
        mask: Optional[Tensor],
        is_causal: Optional[bool] = None,
        size: Optional[int] = None,
) -> bool:
    """Return whether the given attention mask is causal.

    Args:
        mask: attention mask to inspect, or ``None``. Its trailing two dims are
            compared against the square causal mask; any leading (batch) dims
            are broadcast, so every batch slice must equal the causal mask.
        is_causal: caller-supplied hint. If not ``None`` it is returned as is
            (``True`` -> ``True``, ``False`` -> ``False``) and ``mask`` is not
            inspected.
        size: if not ``None``, check whether the mask is a causal mask of the
            provided size. Otherwise the size is taken from ``mask.size(-2)``.

    Returns:
        ``True`` only if the hint is ``True``, or no hint was given and ``mask``
        equals the causal mask produced by ``_generate_square_subsequent_mask``.

    Warning:
    If ``is_causal`` is not ``None``, its value will be returned as is. If a
    user supplies an incorrect ``is_causal`` hint,

    ``is_causal=False`` when the mask is in fact a causal attention mask
    may lead to reduced performance relative to what would be achievable
    with ``is_causal=True``;
    ``is_causal=True`` when the mask is in fact not a causal attention mask
    may lead to incorrect and unpredictable execution - in some scenarios,
    a causal mask may be applied based on the hint, in other execution
    scenarios the specified mask may be used. The choice may not appear
    to be deterministic, in that a number of factors like alignment,
    hardware SKU, etc influence the decision whether to use a mask or
    rely on the hint.
    """
    # Prevent type refinement
    make_causal = (is_causal is True)

    if is_causal is None and mask is not None:
        sz = size if size is not None else mask.size(-2)
        # Only the trailing (sz, sz) dims must match. Leading batch dims are
        # handled by broadcasting the elementwise comparison, which is why
        # ``torch.equal`` (identical shapes only) is not used. Checking the
        # shape first also avoids building the reference mask when it cannot match.
        if mask.dim() >= 2 and mask.size(-2) == sz and mask.size(-1) == sz:
            causal_comparison = _generate_square_subsequent_mask(
                sz, device=mask.device, dtype=mask.dtype)
            make_causal = bool((mask == causal_comparison).all())
        else:
            make_causal = False

    return make_causal