Source code for torchtune.modules.transformer
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention
from torchtune.modules.attention_utils import _MaskType
from torchtune.utils._logging import deprecated
class TransformerSelfAttentionLayer(nn.Module):
"""
Transformer layer derived from the Llama2 model. Normalization is applied before the attention **and** FF layer.
Args:
attn (MultiHeadAttention): Attention module.
mlp (nn.Module): Feed-forward module.
sa_norm (Optional[nn.Module]): Normalization to be applied before self-attention.
mlp_norm (Optional[nn.Module]): Normalization to be applied before the feed-forward layer.
sa_scale (Optional[nn.Module]): Module to scale self-attention output.
mlp_scale (Optional[nn.Module]): Module to scale the feed-forward output.
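
    Example (an illustrative sketch, not a recipe-ready configuration; assumes ``attn`` is an
    already-constructed :class:`~torchtune.modules.MultiHeadAttention` with ``embed_dim=64``
    and ``mlp`` is any feed-forward module mapping ``embed_dim -> embed_dim``)::

        >>> layer = TransformerSelfAttentionLayer(
        ...     attn=attn,
        ...     mlp=mlp,
        ...     sa_norm=nn.LayerNorm(64),
        ...     mlp_norm=nn.LayerNorm(64),
        ... )
        >>> x = torch.randn(2, 16, 64)
        >>> layer(x).shape
        torch.Size([2, 16, 64])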
"""
def __init__(
self,
attn: MultiHeadAttention,
mlp: nn.Module,
*,
sa_norm: Optional[nn.Module] = None,
mlp_norm: Optional[nn.Module] = None,
sa_scale: Optional[nn.Module] = None,
mlp_scale: Optional[nn.Module] = None,
) -> None:
super().__init__()
self.attn = attn
self.mlp = mlp
self.sa_norm = sa_norm or nn.Identity()
self.mlp_norm = mlp_norm or nn.Identity()
self.sa_scale = sa_scale or nn.Identity()
self.mlp_scale = mlp_scale or nn.Identity()
    def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: int,
decoder_max_seq_len: int,
) -> None:
"""Setup key value caches for attention calculation.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (int): this parameter is ignored in this layer.
decoder_max_seq_len (int): maximum cache sequence length.
"""
self.attn.setup_cache(batch_size, dtype, max_seq_len=decoder_max_seq_len)
    def caches_are_setup(self) -> bool:
"""
Check if the key value caches are setup on ``self.attn``.
See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`.
"""
return self.attn.kv_cache is not None
    def caches_are_enabled(self) -> bool:
"""
Checks if the key value caches on ``self.attn`` are enabled.
        See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
"""
return self.attn.cache_enabled
    def forward(
self,
x: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
input_pos: Optional[torch.Tensor] = None,
**kwargs: Dict,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor with shape
[batch_size x seq_length x embed_dim]
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. Either:
                A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
                or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-caching with encoder/decoder layers.
A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
is used by default.
A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
Default is None.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.
**kwargs (Dict): transformer layer inputs not relevant to self attention.
Returns:
torch.Tensor: output tensor with same shape as input
[batch_size x seq_length x embed_dim]
"""
# Input tensor and attention output have the same shape
# [b, s, d]
# Norm applied before self-attention
h = self.sa_norm(x)
attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
# Residual connection; shape: [batch_size, seq_length, embed_dim]
h = self.sa_scale(attn_out) + x
# Norm applied before the feedforward layer
mlp_out = self.mlp(self.mlp_norm(h))
# Residual connection; shape: [batch_size, seq_length, embed_dim]
out = h + self.mlp_scale(mlp_out)
return out
class TransformerCrossAttentionLayer(nn.Module):
"""
Cross attention Transformer layer following the same conventions as the TransformerSelfAttentionLayer.
Normalization is applied before the attention **and** FF layer.
Args:
attn (MultiHeadAttention): Attention module.
mlp (nn.Module): Feed-forward module.
ca_norm (Optional[nn.Module]): Normalization to be applied before cross-attention.
mlp_norm (Optional[nn.Module]): Normalization to be applied before the feed-forward layer.
ca_scale (Optional[nn.Module]): Module to scale cross-attention output.
mlp_scale (Optional[nn.Module]): Module to scale the feed-forward output.
Raises:
AssertionError: if attn.pos_embeddings is set.
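
    Example (an illustrative sketch; assumes ``attn`` is an already-constructed
    :class:`~torchtune.modules.MultiHeadAttention` without positional embeddings and with
    ``embed_dim=64``, and ``mlp`` is any feed-forward module mapping ``embed_dim -> embed_dim``)::

        >>> layer = TransformerCrossAttentionLayer(
        ...     attn=attn,
        ...     mlp=mlp,
        ...     ca_norm=nn.LayerNorm(64),
        ...     mlp_norm=nn.LayerNorm(64),
        ... )
        >>> x = torch.randn(2, 16, 64)
        >>> encoder_input = torch.randn(2, 32, 64)
        >>> layer(x, encoder_input=encoder_input).shape
        torch.Size([2, 16, 64])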
"""
def __init__(
self,
attn: MultiHeadAttention,
mlp: nn.Module,
*,
ca_norm: Optional[nn.Module] = None,
mlp_norm: Optional[nn.Module] = None,
ca_scale: Optional[nn.Module] = None,
mlp_scale: Optional[nn.Module] = None,
) -> None:
super().__init__()
if attn.pos_embeddings is not None:
raise AssertionError(
"Doesn't support positional embeddings for cross attention, \
because q and k are different sequences."
)
self.attn = attn
self.mlp = mlp
self.ca_norm = ca_norm or nn.Identity()
self.mlp_norm = mlp_norm or nn.Identity()
self.ca_scale = ca_scale or nn.Identity()
self.mlp_scale = mlp_scale or nn.Identity()
    def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: int,
decoder_max_seq_len: int,
) -> None:
"""Setup key value caches for attention calculation.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (int): maximum cache sequence length.
decoder_max_seq_len (int): this parameter is ignored in this layer.
"""
self.attn.setup_cache(batch_size, dtype, encoder_max_seq_len)
    def caches_are_setup(self) -> bool:
"""
Check if the key value caches are setup on ``self.attn``.
        See :func:`~torchtune.modules.TransformerDecoder.caches_are_setup`.
"""
return self.attn.kv_cache is not None
    def caches_are_enabled(self) -> bool:
"""
Checks if the key value caches on ``self.attn`` are enabled.
        See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
"""
return self.attn.cache_enabled
def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Some tokens in x may not attend to any encoder inputs
due to the cross attention mask (encoder_mask). This results in
a full row of the attention matrix being masked out.
In the example below, the word "the" is masked from every embedding.
The False value means a token can't attend to an embedding.
.. code-block:: text
|emb||emb||emb|
|The| F F F
|red| T F T
|car| F T T
This results in no inputs into the softmax layer which causes a NaN.
The skip mask is used to mask the outputs of attention and
mlp resulting in the token being skipped.
The above example would result in a skip mask of: [[True], [False], [False]]
which specifies which tokens to fully mask out.
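
        Example of the case above (an illustrative sketch; ``layer`` is assumed to be a
        constructed :class:`~torchtune.modules.TransformerCrossAttentionLayer`)::

            >>> encoder_mask = torch.tensor(
            ...     [[[False, False, False],
            ...       [True, False, True],
            ...       [False, True, True]]]
            ... )
            >>> layer._skip_mask(encoder_mask)[0].tolist()
            [[True], [False], [False]]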
"""
# no skip_mask if no masking
if mask is None:
return None
# negate mask and convert to boolean mask
if mask.dtype == torch.bool:
mask = ~mask
else:
mask = torch.isneginf(mask)
# True where all elements in a row are True
mask = torch.all(mask, dim=-1, keepdim=True)
return mask
    def forward(
self,
x: torch.Tensor,
*,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
**kwargs: Dict,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor with shape
[batch_size x seq_length x embed_dim]
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape
[batch_size x token_sequence x embed_dim]
            encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
                tokens and encoder embeddings. A True value at position ``i,j`` means token ``i`` can attend
                to encoder embedding ``j``. Mask has shape [batch_size x token_sequence x embed_sequence].
Default is None.
**kwargs (Dict): transformer layer inputs not relevant to self attention.
Returns:
torch.Tensor: output tensor with same shape as input
[batch_size x seq_length x embed_dim]
"""
# During decoding, it's possible encoder_input is None because the embeds
# are already stored in the kv cache.
empty_cache = not self.caches_are_enabled() or self.attn.kv_cache.size == 0
        # Skip cross attention when there is no secondary input, as its primary purpose
        # is to attend between x and encoder_input.
if encoder_input is None and empty_cache:
return x
# A mask of tokens (x) with no encoder_input
skip_mask = self._skip_mask(encoder_mask)
if encoder_mask is not None:
# TODO: remove after PyTorch 2.5 is released
# This unmasks the skipped rows to avoid NaNs in SDPA Softmax backward
# This doesn't affect the output since outputs are masked out later
encoder_mask = encoder_mask.masked_fill(skip_mask, True)
# Input tensor and attention output have the same shape
# [b, s, d]
        # Norm applied before cross-attention
# TODO: Add support for sample packing and bring back input_pos
attn_out = self.attn(self.ca_norm(x), encoder_input, mask=encoder_mask)
if skip_mask is not None:
attn_out = attn_out.masked_fill(skip_mask, 0)
# Residual connection; shape: [batch_size, seq_length, embed_dim]
h = self.ca_scale(attn_out) + x
# Norm applied before the feedforward layer
mlp_out = self.mlp(self.mlp_norm(h))
if skip_mask is not None:
mlp_out = mlp_out.masked_fill(skip_mask, 0)
# Residual connection; shape: [batch_size, seq_length, embed_dim]
out = h + self.mlp_scale(mlp_out)
return out
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
"""
Return a list of ``n`` identical layers.
Args:
module (nn.Module): module to be cloned
n (int): number of clones
Returns:
nn.ModuleList: list of ``n`` identical layers
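
    Example (illustrative)::

        >>> blocks = _get_clones(nn.Linear(8, 8), 3)
        >>> len(blocks)
        3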
"""
    # FIXME: copy.deepcopy() is not defined on nn.Module
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
class TransformerDecoder(nn.Module):
"""
Transformer Decoder derived from the Llama2 architecture.
Args:
tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single transformer Decoder layer, an
nn.ModuleList of layers or a list of layers. It is recommended to use an nn.ModuleList.
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value. This is used to setup the
:func:`~torchtune.modules.KVCache`
head_dim (int): embedding dimension for each head in self-attention. This is used
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before final MLP.
output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of
the decoder.
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output
Raises:
        AssertionError: if ``layers`` is not an ``nn.ModuleList``, a list of ``nn.Module``, or a single ``nn.Module``
        AssertionError: if ``layers`` is a single ``nn.Module`` and ``num_layers`` is not set
Note:
        Arg values are checked for correctness (e.g., ``attn_dropout`` belongs to [0,1])
        in the module where they are used. This helps reduce the number of raise
        statements in code and improves readability.
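
    Example (an illustrative sketch, not a recipe-ready model; assumes ``layer`` is an
    already-constructed :class:`~torchtune.modules.TransformerSelfAttentionLayer` with
    ``embed_dim=64``, and uses a toy vocabulary of 128 tokens)::

        >>> decoder = TransformerDecoder(
        ...     tok_embeddings=nn.Embedding(128, 64),
        ...     layers=layer,
        ...     num_layers=2,
        ...     max_seq_len=256,
        ...     num_heads=4,
        ...     head_dim=16,
        ...     norm=nn.LayerNorm(64),
        ...     output=nn.Linear(64, 128),
        ... )
        >>> tokens = torch.randint(0, 128, (2, 16))
        >>> decoder(tokens).shape
        torch.Size([2, 16, 128])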
"""
def __init__(
self,
*,
tok_embeddings: nn.Embedding,
layers: Union[nn.Module, List[nn.Module], nn.ModuleList],
max_seq_len: int,
num_heads: int,
head_dim: int,
norm: nn.Module,
output: Union[nn.Linear, Callable],
num_layers: Optional[int] = None,
output_hidden_states: Optional[List[int]] = None,
) -> None:
super().__init__()
if isinstance(layers, nn.ModuleList):
pass
elif isinstance(layers, list):
layers = nn.ModuleList(layers)
else:
            if not isinstance(layers, nn.Module):
                raise AssertionError("layers must be an nn.ModuleList, a list of nn.Module, or a single nn.Module")
            if num_layers is None:
                raise AssertionError("num_layers must be specified when layers is a single nn.Module")
            layers = _get_clones(layers, num_layers)
self.tok_embeddings = tok_embeddings
self.layers = layers
self.norm = norm
self.output = output
self.output_hidden_states = output_hidden_states or []
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0
# attributes for KV caches during inference
self.encoder_max_cache_seq_len = None
self.decoder_max_cache_seq_len = None
    def set_num_output_chunks(self, num_output_chunks: int) -> None:
"""Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
This should be called before the first forward pass, in the recipe."""
self.num_output_chunks = num_output_chunks
    def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: Optional[int] = None,
decoder_max_seq_len: Optional[int] = None,
):
"""
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
- :class:`~torchtune.modules.model_fusion.FusionLayer` will use ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length.
decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length.
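
        Example (an illustrative sketch; ``decoder`` is assumed to be a decoder-only
        :class:`~torchtune.modules.TransformerDecoder` built from self-attention layers)::

            >>> decoder.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=256)
            >>> decoder.caches_are_setup()
            True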
"""
has_encoder_layers = any(
isinstance(m, TransformerCrossAttentionLayer) for m in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
)
if has_encoder_layers:
if encoder_max_seq_len is not None:
self.encoder_max_cache_seq_len = encoder_max_seq_len
else:
self.encoder_max_cache_seq_len = self.max_seq_len
if has_decoder_layers:
if decoder_max_seq_len is not None:
self.decoder_max_cache_seq_len = decoder_max_seq_len
else:
self.decoder_max_cache_seq_len = self.max_seq_len
for layer in self.layers:
layer.setup_caches(
batch_size,
dtype,
encoder_max_seq_len=self.encoder_max_cache_seq_len,
decoder_max_seq_len=self.decoder_max_cache_seq_len,
)
    def caches_are_setup(self) -> bool:
"""
Check if the key value caches are setup. This means ``setup_caches`` has been called, and
the relevant attention modules in the model have created their ``KVCache``.
"""
return self.layers[0].caches_are_setup()
    def caches_are_enabled(self) -> bool:
"""
Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
attention modules will be "enabled" and all forward passes will update the caches. This behaviour
can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
using :func:`torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False.
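
        Example (an illustrative sketch; assumes ``model`` is a :class:`~torchtune.modules.TransformerDecoder`
        with caches already set up and ``disable_kv_cache`` imported from ``torchtune.modules.common_utils``)::

            >>> model.caches_are_enabled()
            True
            >>> with disable_kv_cache(model):
            ...     print(model.caches_are_enabled())
            False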
"""
return self.layers[0].caches_are_enabled()
    def reset_caches(self):
"""
        Resets KV-cache buffers on relevant attention modules to zero, and resets cache positions to zero,
        without deleting or reallocating cache tensors.
Raises:
RuntimeError: if KV-caches are not setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
setup caches first.
"""
if not self.caches_are_enabled():
raise RuntimeError(
"Key value caches are not setup. Call model.setup_caches first."
)
for layer in self.layers:
layer.reset_cache()
    @torch.compiler.disable
def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
"""
Apply output projection in chunks. This should be applied in conjunction with
:class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
To use this method, you should first call
:func:`~torchtune.modules.TransformerDecoder.set_num_output_chunks`.
Args:
last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
[b, seq_len, embed_dim].
Returns:
List[torch.Tensor]: List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
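
        Example (an illustrative sketch; assumes ``decoder`` has embedding dimension 64 and
        ``decoder.set_num_output_chunks(8)`` has already been called)::

            >>> last_hidden_state = torch.randn(2, 64, 64)
            >>> chunks = decoder.chunked_output(last_hidden_state)
            >>> len(chunks)
            8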
"""
return [
self.output(chunk)
for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
]
def _validate_inputs(
self,
seq_len: int,
mask: Optional[torch.Tensor] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
):
"""
Validates inputs for ``forward``.
Args:
seq_len (int): Input tensor sequence length.
mask (Optional[torch.Tensor]): Attention mask used for inference and for sequence packing.
encoder_input (Optional[torch.Tensor]): Encoder input for cross-attention.
encoder_mask (Optional[torch.Tensor]): Encoder attention mask for cross-embedding attention.
input_pos (Optional[torch.Tensor]): Input tensor position IDs.
Raises:
ValueError: if seq_len of x is bigger than max_seq_len
ValueError: if the model has caches which have been setup with self-attention layers and ``mask`` is not provided.
ValueError: if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided.
            ValueError: if the model has caches which have been setup and ``input_pos`` is not provided.
"""
if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)
if self.caches_are_enabled():
if mask is None:
raise ValueError(
"KV-caches for self-attention layers are setup for inference mode, causal masks must be provided!"
" Use the `mask` arg to provide a causal mask."
)
if encoder_input is not None and encoder_mask is None:
raise ValueError(
"KV-caches for cross-attention/fusion layers are setup for inference mode and you seem to be using"
" encoder_input, causal masks must be provided! Use the `encoder_mask` arg to provide a causal mask."
)
if input_pos is None:
raise ValueError(
"KV-caches are setup for inference mode, input positions must be provided!"
)
    def forward(
self,
tokens: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Args:
tokens (torch.Tensor): input tensor with shape ``[b x s]``
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. This parameter is required during inference if caches have been setup.
Either:
                A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
                or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-caching with encoder/decoder layers.
A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
is used by default.
A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
Default is None.
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape ``[b x s_e x d_e]``
            encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
                tokens and encoder embeddings. A True value at position ``i,j`` means token ``i`` can attend
                to encoder embedding ``j``. Mask has shape ``[b x s x s_e]``. Default is None,
but this is required during inference if the model has been setup with any layers
which use encoder embeddings and caches have been setup.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape ``[b x s]``.
During inference, this indicates the position of the current token.
This parameter is required during inference if caches have been setup. Default is None.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` or a list of layer
output tensors defined by ``output_hidden_states`` with the
final output tensor appended to the list.
Note:
At the very first step of inference, when the model is provided with a prompt,
``input_pos`` should contain the positions of all of the tokens in the prompt.
For a single-batch prompt, or a batch of prompts with identical lengths, this
will be ``torch.arange(prompt_length)``. For a batch of varying-length prompts,
shorter prompts are left-padded and position ids are correspondingly right-shifted,
thus positional ids should be of shape ``[b, padded_prompt_length]``.
This is because we will need to retrieve the positional embeddings for each input id.
In the subsequent steps, if the model has been setup with KV-caches, ``input_pos`` will contain
the position(s) of the current token(s) ``torch.tensor([padded_prompt_length])``. Otherwise,
``input_pos`` will contain all the position ids up to the current token.
Shape notation:
- b: batch size
- s: token sequence length
- s_e: encoder sequence length
- v: vocab size
- d: token embed dim
- d_e: encoder embed dim
- m_s: max seq len
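
        Example (an illustrative sketch of the inference-time convention above; assumes ``decoder``
        has KV-caches set up with ``batch_size=1`` and that ``prefill_mask`` / ``step_mask`` are
        causal masks prepared for the corresponding positions)::

            >>> prompt = torch.randint(0, 128, (1, 8))
            >>> # prefill: pass the positions of every prompt token
            >>> logits = decoder(prompt, input_pos=torch.arange(8).unsqueeze(0), mask=prefill_mask)
            >>> # decode the next token at position 8
            >>> next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            >>> logits = decoder(next_token, input_pos=torch.tensor([[8]]), mask=step_mask)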
"""
# input tensor of shape [b, s]
seq_len = tokens.shape[1]
self._validate_inputs(
seq_len,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)
# shape: [b, s, d]
h = self.tok_embeddings(tokens)
hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)
# shape: [b, s, d]
h = layer(
h,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)
# shape: [b, s, d]
h = self.norm(h)
if self.num_output_chunks > 0:
output = self.chunked_output(h)
else:
# shape: [b, seq_len, out_dim]
output = self.output(h).float()
# Output list if hidden states are requested, otherwise just the output
# TODO: always output a list to have a consistent output type
output = output if not hidden else [*hidden, output]
return output
@deprecated(
msg="Please use torchtune.modules.TransformerDecoder instead. \
If you need an example, see torchtune.models.qwen2._component_builders.py \
on how to use torch.modules.TiedLinear for the output projection."
)
class TiedEmbeddingTransformerDecoder(nn.Module):
"""
Transformer Decoder with tied embedding weight. A key difference between
this class and :class:`~torchtune.modules.TransformerDecoder`
is that the output projection is replaced with token embeddings weights.
Args:
tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module]]): Transformer Decoder layer or a list of layers.
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value. This is used to setup the
:func:`~torchtune.modules.KVCache`
head_dim (int): embedding dimension for each head in self-attention. This is used
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before final MLP.
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output
Raises:
        AssertionError: if ``num_layers`` is set and ``layers`` is a list
        AssertionError: if ``num_layers`` is not set and ``layers`` is an ``nn.Module``
Note:
        Arg values are checked for correctness (e.g., ``attn_dropout`` belongs to [0,1])
        in the module where they are used. This helps reduce the number of raise
        statements in code and improves readability.
"""
def __init__(
self,
*,
tok_embeddings: nn.Embedding,
layers: Union[nn.Module, List[nn.Module]],
max_seq_len: int,
num_heads: int,
head_dim: int,
norm: nn.Module,
num_layers: Optional[int] = None,
output_hidden_states: Optional[List[int]] = None,
) -> None:
super().__init__()
if num_layers is None:
if isinstance(layers, nn.Module):
raise AssertionError(
"If num_layers is undefined, it is assumed that a list of layers is provided."
)
layers = nn.ModuleList(layers)
else:
if not isinstance(layers, nn.Module):
raise AssertionError("num_layers is defined, layers must be a module")
layers = _get_clones(layers, num_layers)
self.tok_embeddings = tok_embeddings
self.layers = layers
self.norm = norm
self.output_hidden_states = output_hidden_states or []
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0
# attributes for KV caches during inference
self.encoder_max_cache_seq_len = None
self.decoder_max_cache_seq_len = None
@torch.compiler.disable
def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
"""
Apply output projection in chunks. This should be applied in conjunction with
:class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
To use this method, you should first call
:func:`~torchtune.modules.TiedEmbeddingTransformerDecoder.set_num_output_chunks`.
Args:
last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
[b, seq_len, embed_dim].
Returns:
List[torch.Tensor]: List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
"""
return [
F.linear(chunk, self.tok_embeddings.weight)
for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
]
def set_num_output_chunks(self, num_output_chunks: int) -> None:
"""Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
This should be called before the first forward pass, in the recipe."""
self.num_output_chunks = num_output_chunks
def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: Optional[int] = None,
decoder_max_seq_len: Optional[int] = None,
):
"""
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
        - :class:`~torchtune.modules.model_fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length.
decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length.
"""
has_encoder_layers = any(
isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
)
if has_encoder_layers:
if encoder_max_seq_len is not None:
self.encoder_max_cache_seq_len = encoder_max_seq_len
else:
self.encoder_max_cache_seq_len = self.max_seq_len
if has_decoder_layers:
if decoder_max_seq_len is not None:
self.decoder_max_cache_seq_len = decoder_max_seq_len
else:
                self.decoder_max_cache_seq_len = self.max_seq_len
for layer in self.layers:
layer.setup_caches(
batch_size,
dtype,
                encoder_max_seq_len=self.encoder_max_cache_seq_len,
                decoder_max_seq_len=self.decoder_max_cache_seq_len,
)
@property
def encoder_caches_are_enabled(self) -> bool:
"""Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`,
        or :class:`~torchtune.modules.model_fusion.FusionLayer` layers which have cache enabled.
"""
return self.encoder_max_cache_seq_len is not None
@property
def decoder_caches_are_enabled(self) -> bool:
"""Check if the key value caches are setup."""
return self.decoder_max_cache_seq_len is not None
def reset_caches(self):
"""Reset the key value caches."""
if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled):
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)
for layer in self.layers:
layer.reset_cache()
def forward(
self,
tokens: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Args:
tokens (torch.Tensor): input tensor with shape [b x s]
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. Either a boolean tensor with shape [b x s x s] or a
:class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value
of True in row i and column j means token i attends to token j. A value of False means
token i does not attend to token j. If no mask is specified, a causal mask
is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed
for document masking in a packed sequence via `create_block_mask
<https://pytorch.org/blog/flexattention/#mask-mods>`_, we use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention.
Default is None.
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e]
encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
tokens and encoder embeddings. A True value at position i,j means token i can attend
to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.
Note: At the very first step of inference, when the model is provided with a prompt,
``input_pos`` would contain the positions of all of the tokens in the prompt
(eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
KV values for each position.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape [b x s x v] or a list of layer
output tensors defined by ``output_hidden_states`` with the
final output tensor appended to the list.
Raises:
ValueError: if seq_len of x is bigger than max_seq_len
            ValueError: if caches are enabled (inference mode) and a mask is not provided
Notation used for tensor shapes:
- b: batch size
- s: token sequence length
- s_e: encoder sequence length
- v: vocab size
- d: token embed dim
- d_e: encoder embed dim
- m_s: max seq len
"""
# input tensor of shape [b, s]
bsz, seq_len = tokens.shape
if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)
# shape: [b, s, d]
h = self.tok_embeddings(tokens)
if self.decoder_caches_are_enabled:
if mask is None:
raise ValueError(
"KV-caches for self-attention layers are setup for inference mode, masks must be provided!"
" Use the `mask` arg to provide a mask."
)
if self.encoder_caches_are_enabled:
if encoder_mask is None:
raise ValueError(
"KV-caches for cross-attention/fusion layers are setup for inference mode, encoder masks must be provided!"
" Use the `encoder_mask` arg to provide an encoder mask."
)
        if (
            self.encoder_caches_are_enabled or self.decoder_caches_are_enabled
        ) and input_pos is None:
raise ValueError(
"KV-caches are setup for inference mode, input positions must be provided!"
)
hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)
# shape: [b, s, d]
h = layer(
h,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)
# shape: [b, s, d]
h = self.norm(h)
if self.num_output_chunks > 0:
output = self.chunked_output(h)
else:
# shape: [b, seq_len, out_dim]
output = F.linear(h, self.tok_embeddings.weight).float()
# Output list if hidden states are requested, otherwise just the output
# TODO: always output a list to have a consistent output type
output = output if not hidden else [*hidden, output]
return output