Source code for torchtune.modules.transformer
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention
from torchtune.modules.attention_utils import _MaskType
class TransformerSelfAttentionLayer(nn.Module):
"""
Transformer layer derived from the Llama2 model. Normalization is applied before the attention **and** FF layer.
Args:
attn (MultiHeadAttention): Attention module.
mlp (nn.Module): Feed-forward module.
sa_norm (Optional[nn.Module]): Normalization to be applied before self-attention.
mlp_norm (Optional[nn.Module]): Normalization to be applied before the feed-forward layer.
sa_scale (Optional[nn.Module]): Module to scale self-attention output.
mlp_scale (Optional[nn.Module]): Module to scale the feed-forward output.
"""
def __init__(
self,
attn: MultiHeadAttention,
mlp: nn.Module,
*,
sa_norm: Optional[nn.Module] = None,
mlp_norm: Optional[nn.Module] = None,
sa_scale: Optional[nn.Module] = None,
mlp_scale: Optional[nn.Module] = None,
) -> None:
super().__init__()
self.attn = attn
self.mlp = mlp
self.sa_norm = sa_norm or nn.Identity()
self.mlp_norm = mlp_norm or nn.Identity()
self.sa_scale = sa_scale or nn.Identity()
self.mlp_scale = mlp_scale or nn.Identity()
def setup_cache(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: int,
decoder_max_seq_len: int,
) -> None:
"""Setup key value caches for attention calculation.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (int): this parameter is ignored in this layer.
decoder_max_seq_len (int): maximum cache sequence length.
"""
self.attn.setup_cache(batch_size, dtype, max_seq_len=decoder_max_seq_len)
@property
def cache_enabled(self) -> bool:
"""Check if the key value caches are setup."""
return self.attn.kv_cache is not None
def forward(
self,
x: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
input_pos: Optional[torch.Tensor] = None,
**kwargs: Dict,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor with shape
[batch_size x seq_length x embed_dim]
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. Either:
A boolean tensor with shape ``[b x s x s]``, or ``[b x s x self.decoder_max_cache_seq_len]``
if using KV-caching with decoder layers.
A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
is used by default.
A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
Default is None.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.
**kwargs (Dict): transformer layer inputs not relevant to self attention.
Returns:
torch.Tensor: output tensor with same shape as input
[batch_size x seq_length x embed_dim]
"""
# Input tensor and attention output have the same shape
# [b, s, d]
# Norm applied before self-attention
h = self.sa_norm(x)
attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
# Residual connection; shape: [batch_size, seq_length, embed_dim]
h = self.sa_scale(attn_out) + x
# Norm applied before the feedforward layer
mlp_out = self.mlp(self.mlp_norm(h))
# Residual connection; shape: [batch_size, seq_length, embed_dim]
out = h + self.mlp_scale(mlp_out)
return out
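# --- Illustrative usage sketch (not part of the library) ----------------------
# Shows how the pre-norm residual structure above is typically wired: norm ->
# attention -> residual, then norm -> MLP -> residual. The MultiHeadAttention
# keyword arguments below (embed_dim, num_heads, num_kv_heads, head_dim and the
# q/k/v/output projections) are assumptions about torchtune's attention module;
# the toy sizes are arbitrary and pos_embeddings is omitted for brevity.
def _example_self_attention_layer(
    embed_dim: int = 64, num_heads: int = 4
) -> "TransformerSelfAttentionLayer":
    head_dim = embed_dim // num_heads
    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,  # plain MHA: same number of kv heads as query heads
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        max_seq_len=128,
    )
    mlp = nn.Sequential(
        nn.Linear(embed_dim, 4 * embed_dim),
        nn.SiLU(),
        nn.Linear(4 * embed_dim, embed_dim),
    )
    return TransformerSelfAttentionLayer(
        attn=attn,
        mlp=mlp,
        sa_norm=nn.LayerNorm(embed_dim),
        mlp_norm=nn.LayerNorm(embed_dim),
    )
# Usage: out = _example_self_attention_layer()(torch.randn(2, 16, 64)); the output
# has the same [batch_size, seq_length, embed_dim] shape as the input.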
class TransformerCrossAttentionLayer(nn.Module):
"""
Cross attention Transformer layer following the same conventions as the TransformerSelfAttentionLayer.
Normalization is applied before the attention **and** FF layer.
Args:
attn (MultiHeadAttention): Attention module.
mlp (nn.Module): Feed-forward module.
ca_norm (Optional[nn.Module]): Normalization to be applied before cross-attention.
mlp_norm (Optional[nn.Module]): Normalization to be applied before the feed-forward layer.
ca_scale (Optional[nn.Module]): Module to scale cross-attention output.
mlp_scale (Optional[nn.Module]): Module to scale the feed-forward output.
Raises:
AssertionError: if attn.pos_embeddings is set.
"""
def __init__(
self,
attn: MultiHeadAttention,
mlp: nn.Module,
*,
ca_norm: Optional[nn.Module] = None,
mlp_norm: Optional[nn.Module] = None,
ca_scale: Optional[nn.Module] = None,
mlp_scale: Optional[nn.Module] = None,
) -> None:
super().__init__()
if attn.pos_embeddings is not None:
raise AssertionError(
"Doesn't support positional embeddings for cross attention, \
because q and k are different sequences."
)
self.attn = attn
self.mlp = mlp
self.ca_norm = ca_norm or nn.Identity()
self.mlp_norm = mlp_norm or nn.Identity()
self.ca_scale = ca_scale or nn.Identity()
self.mlp_scale = mlp_scale or nn.Identity()
def setup_cache(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: int,
decoder_max_seq_len: int,
) -> None:
"""Setup key value caches for attention calculation.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (int): maximum cache sequence length.
decoder_max_seq_len (int): this parameter is ignored in this layer.
"""
self.attn.setup_cache(batch_size, dtype, encoder_max_seq_len)
@property
def cache_enabled(self) -> bool:
"""Check if the key value caches are setup."""
return self.attn.kv_cache is not None
def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Some tokens in x may not attend to any encoder inputs
due to the cross attention mask (encoder_mask). This results in
a full row of the attention matrix being masked out.
In the example below, the word "the" is masked from every embedding.
The False value means a token can't attend to an embedding.
.. code-block:: text
|emb||emb||emb|
|The| F F F
|red| T F T
|car| F T T
This leaves the softmax with no unmasked inputs, which produces NaNs.
The skip mask is used to zero out the attention and
MLP outputs for such tokens, so they are effectively skipped.
The above example would result in a skip mask of: [[True], [False], [False]]
which specifies which tokens to fully mask out.
"""
# no skip_mask if no masking
if mask is None:
return None
# negate mask and convert to boolean mask
if mask.dtype == torch.bool:
mask = ~mask
else:
mask = torch.isneginf(mask)
# True where all elements in a row are True
mask = torch.all(mask, dim=-1, keepdim=True)
return mask
def forward(
self,
x: torch.Tensor,
*,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
**kwargs: Dict,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor with shape
[batch_size x seq_length x embed_dim]
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape
[batch_size x token_sequence x embed_dim]
encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
tokens and encoder embeddings. A True value at position i,j means token i can attend
to encoder embedding j. Mask has shape [batch_size x token_sequence x embed_sequence].
Default is None.
**kwargs (Dict): transformer layer inputs not relevant to cross attention.
Returns:
torch.Tensor: output tensor with same shape as input
[batch_size x seq_length x embed_dim]
"""
# During decoding, it's possible encoder_input is None because the embeds
# are already stored in the kv cache.
empty_cache = not self.cache_enabled or self.attn.kv_cache.size == 0
# Skip cross attention when there is no secondary input, as its primary purpose
# is to attend between x and encoder_input.
if encoder_input is None and empty_cache:
return x
# A mask of tokens (x) with no encoder_input
skip_mask = self._skip_mask(encoder_mask)
if encoder_mask is not None:
# TODO: remove after PyTorch 2.5 is released
# This unmasks the skipped rows to avoid NaNs in SDPA Softmax backward
# This doesn't affect the output since outputs are masked out later
encoder_mask = encoder_mask.masked_fill(skip_mask, True)
# Input tensor and attention output have the same shape
# [b, s, d]
# Norm applied before cross-attention
# TODO: Add support for sample packing and bring back input_pos
attn_out = self.attn(self.ca_norm(x), encoder_input, mask=encoder_mask)
if skip_mask is not None:
attn_out = attn_out.masked_fill(skip_mask, 0)
# Residual connection; shape: [batch_size, seq_length, embed_dim]
h = self.ca_scale(attn_out) + x
# Norm applied before the feedforward layer
mlp_out = self.mlp(self.mlp_norm(h))
if skip_mask is not None:
mlp_out = mlp_out.masked_fill(skip_mask, 0)
# Residual connection; shape: [batch_size, seq_length, embed_dim]
out = h + self.mlp_scale(mlp_out)
return out
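# --- Worked example of the skip-mask logic (not part of the library) ----------
# Reproduces the toy mask from the ``_skip_mask`` docstring: a token whose row in
# ``encoder_mask`` is entirely False may attend to nothing, so it gets True in the
# skip mask and its attention/MLP outputs are zeroed, leaving only the residual.
def _example_skip_mask() -> torch.Tensor:
    encoder_mask = torch.tensor(
        [
            [
                [False, False, False],  # "The": no encoder embedding visible -> skipped
                [True, False, True],    # "red"
                [False, True, True],    # "car"
            ]
        ]
    )
    # Same computation as TransformerCrossAttentionLayer._skip_mask for boolean masks.
    skip_mask = torch.all(~encoder_mask, dim=-1, keepdim=True)
    return skip_mask  # tensor([[[True], [False], [False]]]) with shape [1, 3, 1]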
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
"""
Return a list of ``n`` identical layers.
Args:
module (nn.Module): module to be cloned
n (int): number of clones
Returns:
nn.ModuleList: list of ``n`` identical layers
"""
# FIXME: copy.deepcopy() is not defined on nn.module
return nn.ModuleList([copy.deepcopy(module) for i in range(n)])
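# Quick illustration (not part of the library): because ``_get_clones`` deep-copies
# the module, each clone owns its own parameters rather than sharing them.
def _example_get_clones_are_independent() -> bool:
    clones = _get_clones(nn.Linear(4, 4), 3)
    return clones[0].weight is clones[1].weight  # False: parameters are not shared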
class TransformerDecoder(nn.Module):
"""
Transformer Decoder derived from the Llama2 architecture.
Args:
tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single transformer Decoder layer, an
nn.ModuleList of layers or a list of layers. It is recommended to use an nn.ModuleList.
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value. This is used to setup the
:func:`~torchtune.modules.KVCache`
head_dim (int): embedding dimension for each head in self-attention. This is used
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before the final output projection.
output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of
the decoder.
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output
Raises:
AssertionError: num_layers is set and layers is a list
AssertionError: num_layers is not set and layers is an nn.Module
Note:
Arg values are checked for correctness (e.g. ``attn_dropout`` belongs to [0,1])
in the module where they are used. This helps reduce the number of raise
statements in code and improves readability.
"""
def __init__(
self,
*,
tok_embeddings: nn.Embedding,
layers: Union[nn.Module, List[nn.Module], nn.ModuleList],
max_seq_len: int,
num_heads: int,
head_dim: int,
norm: nn.Module,
output: Union[nn.Linear, Callable],
num_layers: Optional[int] = None,
output_hidden_states: Optional[List[int]] = None,
) -> None:
super().__init__()
if isinstance(layers, nn.ModuleList):
pass
elif isinstance(layers, list):
layers = nn.ModuleList(layers)
else:
if not isinstance(layers, nn.Module):
raise AssertionError("num_layers is defined, layers must be a module")
if num_layers is None:
raise AssertionError("num_layers is not defined, layers must be a list")
layers = _get_clones(layers, num_layers)
self.tok_embeddings = tok_embeddings
self.layers = layers
self.norm = norm
self.output = output
self.output_hidden_states = output_hidden_states or []
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0
# attributes for KV caches during inference
self.encoder_max_cache_seq_len = None
self.decoder_max_cache_seq_len = None
def set_num_output_chunks(self, num_output_chunks: int) -> None:
"""Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
This should be called before the first forward pass, in the recipe."""
self.num_output_chunks = num_output_chunks
def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: Optional[int] = None,
decoder_max_seq_len: Optional[int] = None,
):
"""
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
- :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length.
decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length.
"""
has_encoder_layers = any(
isinstance(m, TransformerCrossAttentionLayer) for m in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
)
if has_encoder_layers:
if encoder_max_seq_len is not None:
self.encoder_max_cache_seq_len = encoder_max_seq_len
else:
self.encoder_max_cache_seq_len = self.max_seq_len
if has_decoder_layers:
if decoder_max_seq_len is not None:
self.decoder_max_cache_seq_len = decoder_max_seq_len
else:
self.decoder_max_cache_seq_len = self.max_seq_len
for layer in self.layers:
layer.setup_cache(
batch_size,
dtype,
encoder_max_seq_len=self.encoder_max_cache_seq_len,
decoder_max_seq_len=self.decoder_max_cache_seq_len,
)
def caches_are_enabled(self) -> bool:
"""Check if the key value caches are setup. This is useful to efficient inference."""
return self.layers[0].cache_enabled
def reset_caches(self):
"""Reset the key value caches."""
if not self.caches_are_enabled():
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)
for layer in self.layers:
layer.reset_cache()
@torch.compiler.disable
def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
"""
Apply output projection in chunks. This should be applied in conjunction with
:class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
To use this method, you should first call
:func:`~torchtune.modules.TransformerDecoder.set_num_output_chunks`.
Args:
last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
[b, seq_len, embed_dim].
Returns:
List[torch.Tensor]: List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
"""
return [
self.output(chunk)
for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
]
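# Example shape walk-through: with ``num_output_chunks=8`` and a last hidden state
# of shape [b, 2048, embed_dim], ``chunked_output`` projects eight [b, 256, embed_dim]
# chunks into eight [b, 256, vocab_size] logit tensors, avoiding one large
# [b, 2048, vocab_size] allocation before the chunked cross-entropy loss.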
def _validate_inputs(
self,
seq_len: int,
mask: Optional[torch.Tensor] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
):
"""
Validates inputs for ``forward``.
Args:
seq_len (int): Input tensor sequence length.
mask (Optional[torch.Tensor]): Attention mask used for inference and for sequence packing.
encoder_input (Optional[torch.Tensor]): Encoder input for cross-attention.
encoder_mask (Optional[torch.Tensor]): Encoder attention mask for cross-embedding attention.
input_pos (Optional[torch.Tensor]): Input tensor position IDs.
Raises:
ValueError: if seq_len of x is bigger than max_seq_len
ValueError: if the model has caches which have been set up with self-attention layers and ``mask`` is not provided.
ValueError: if the model has caches which have been set up with encoder layers and ``encoder_mask`` is not provided.
ValueError: if the model has caches which have been set up and ``input_pos`` is not provided.
"""
if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)
if self.caches_are_enabled():
if mask is None:
raise ValueError(
"KV-caches for self-attention layers are setup for inference mode, causal masks must be provided!"
" Use the `mask` arg to provide a causal mask."
)
if encoder_input is not None and encoder_mask is None:
raise ValueError(
"KV-caches for cross-attention/fusion layers are setup for inference mode and you seem to be using"
" encoder_input, causal masks must be provided! Use the `encoder_mask` arg to provide a causal mask."
)
if input_pos is None:
raise ValueError(
"KV-caches are setup for inference mode, input positions must be provided!"
)
def forward(
self,
tokens: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Args:
tokens (torch.Tensor): input tensor with shape ``[b x s]``
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. This parameter is required during inference if caches have been set up.
Either:
A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.decoder_max_cache_seq_len]`` if using
KV-caching with decoder layers, or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-caching with encoder layers.
A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
is used by default.
A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_. We use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
Default is None.
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape ``[b x s_e x d_e]``
encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
tokens and encoder embeddings. A True value at position ``i,j`` means token ``i`` can attend
to encoder embedding ``j``. Mask has shape ``[b x s x s_e]``. Default is None,
but this is required during inference if the model has been set up with any layers
which use encoder embeddings and caches have been set up.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape ``[b x s]``.
During inference, this indicates the position of the current token.
This parameter is required during inference if caches have been set up. Default is None.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` or a list of layer
output tensors defined by ``output_hidden_states`` with the
final output tensor appended to the list.
Note:
At the very first step of inference, when the model is provided with a prompt,
``input_pos`` should contain the positions of all of the tokens in the prompt.
For a single-batch prompt, or a batch of prompts with identical lengths, this
will be ``torch.arange(prompt_length)``. For a batch of varying-length prompts,
shorter prompts are left-padded and position ids are correspondingly right-shifted,
thus positional ids should be of shape ``[b, padded_prompt_length]``.
This is because we will need to retrieve the positional embeddings for each input id.
In the subsequent steps, if the model has been set up with KV-caches, ``input_pos`` will contain
the position(s) of the current token(s), e.g. ``torch.tensor([padded_prompt_length])``. Otherwise,
``input_pos`` will contain all the position ids up to the current token.
Shape notation:
- b: batch size
- s: token sequence length
- s_e: encoder sequence length
- v: vocab size
- d: token embed dim
- d_e: encoder embed dim
- m_s: max seq len
"""
# input tensor of shape [b, s]
bsz, seq_len = tokens.shape
self._validate_inputs(
seq_len,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)
# shape: [b, s, d]
h = self.tok_embeddings(tokens)
hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)
# shape: [b, s, d]
h = layer(
h,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)
# shape: [b, s, d]
h = self.norm(h)
if self.num_output_chunks > 0:
output = self.chunked_output(h)
else:
# shape: [b, seq_len, out_dim]
output = self.output(h).float()
# Output list if hidden states are requested, otherwise just the output
# TODO: always output a list to have a consistent output type
output = output if not hidden else [*hidden, output]
return output
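# --- Illustrative KV-cache inference sketch (not part of the library) ---------
# Demonstrates the protocol described in ``TransformerDecoder.forward``: set up
# caches, prefill with every prompt position and a [b, s, cache_len] causal mask,
# then decode one token at a time with a single position id. ``decoder`` is any
# decoder-only TransformerDecoder; greedy argmax sampling is used purely to keep
# the sketch short, and the helper name is hypothetical.
@torch.no_grad()
def _example_greedy_decode(
    decoder: "TransformerDecoder", prompt: torch.Tensor, max_new_tokens: int = 8
) -> torch.Tensor:
    bsz, prompt_len = prompt.shape
    total_len = prompt_len + max_new_tokens
    decoder.setup_caches(bsz, torch.float32, decoder_max_seq_len=total_len)
    # Causal mask over the full cache length; row i is the mask for query position i.
    causal = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    # Prefill: all prompt positions at once, mask shape [b, prompt_len, total_len].
    input_pos = torch.arange(prompt_len).unsqueeze(0).expand(bsz, -1)
    logits = decoder(
        prompt,
        mask=causal[None, :prompt_len, :].expand(bsz, -1, -1),
        input_pos=input_pos,
    )
    tokens = prompt
    for step in range(max_new_tokens):
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)  # [b, 1]
        tokens = torch.cat([tokens, next_tok], dim=1)
        pos = prompt_len + step
        # Single-token step: current position id and the matching causal-mask row.
        logits = decoder(
            next_tok,
            mask=causal[None, pos : pos + 1, :].expand(bsz, -1, -1),
            input_pos=torch.full((bsz, 1), pos, dtype=torch.long),
        )
    return tokens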
class TiedEmbeddingTransformerDecoder(nn.Module):
"""
Transformer Decoder with tied embedding weight. A key difference between
this class and :class:`~torchtune.modules.TransformerDecoder`
is that the output projection is replaced with token embeddings weights.
Args:
tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module]]): Transformer Decoder layer or a list of layers.
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value. This is used to setup the
:func:`~torchtune.modules.KVCache`
head_dim (int): embedding dimension for each head in self-attention. This is used
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before the final output projection.
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output
Raises:
AssertionError: num_layers is set and layers is a list
AssertionError: num_layers is not set and layers is an nn.Module
Note:
Arg values are checked for correctness (e.g. ``attn_dropout`` belongs to [0,1])
in the module where they are used. This helps reduce the number of raise
statements in code and improves readability.
"""
def __init__(
self,
*,
tok_embeddings: nn.Embedding,
layers: Union[nn.Module, List[nn.Module]],
max_seq_len: int,
num_heads: int,
head_dim: int,
norm: nn.Module,
num_layers: Optional[int] = None,
output_hidden_states: Optional[List[int]] = None,
) -> None:
super().__init__()
if num_layers is None:
if isinstance(layers, nn.Module):
raise AssertionError(
"If num_layers is undefined, it is assumed that a list of layers is provided."
)
layers = nn.ModuleList(layers)
else:
if not isinstance(layers, nn.Module):
raise AssertionError("num_layers is defined, layers must be a module")
layers = _get_clones(layers, num_layers)
self.tok_embeddings = tok_embeddings
self.layers = layers
self.norm = norm
self.output_hidden_states = output_hidden_states or []
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0
# attributes for KV caches during inference
self.encoder_max_cache_seq_len = None
self.decoder_max_cache_seq_len = None
@torch.compiler.disable
def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
"""
Apply output projection in chunks. This should be applied in conjunction with
:class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
To use this method, you should first call
:func:`~torchtune.modules.TiedEmbeddingTransformerDecoder.set_num_output_chunks`.
Args:
last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
[b, seq_len, embed_dim].
Returns:
List[torch.Tensor]: List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
"""
return [
F.linear(chunk, self.tok_embeddings.weight)
for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
]
def set_num_output_chunks(self, num_output_chunks: int) -> None:
"""Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`.
This should be called before the first forward pass, in the recipe."""
self.num_output_chunks = num_output_chunks
def setup_caches(
self,
batch_size: int,
dtype: torch.dtype,
*,
encoder_max_seq_len: Optional[int] = None,
decoder_max_seq_len: Optional[int] = None,
):
"""
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
- :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length.
decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length.
"""
has_encoder_layers = any(
isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
)
if has_encoder_layers:
if encoder_max_seq_len is not None:
self.encoder_max_cache_seq_len = encoder_max_seq_len
else:
self.encoder_max_cache_seq_len = self.max_seq_len
if has_decoder_layers:
if decoder_max_seq_len is not None:
self.decoder_max_cache_seq_len = decoder_max_seq_len
else:
self.decoder_max_cache_seq_len = self.max_seq_len
for layer in self.layers:
layer.setup_cache(
batch_size,
dtype,
encoder_max_seq_len=self.encoder_max_cache_seq_len,
decoder_max_seq_len=self.decoder_max_cache_seq_len,
)
@property
def encoder_caches_are_enabled(self) -> bool:
"""Checks if there are any :class:`~torchtune.modules.TransformerCrossAttentionLayer`,
or :class:`~torchtune.modules.fusion.FusionLayer` layers which have cache enabled.
"""
return self.encoder_max_cache_seq_len is not None
@property
def decoder_caches_are_enabled(self) -> bool:
"""Check if the key value caches are setup."""
return self.decoder_max_cache_seq_len is not None
def reset_caches(self):
"""Reset the key value caches."""
if not (self.encoder_caches_are_enabled or self.decoder_caches_are_enabled):
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)
for layer in self.layers:
layer.reset_cache()
def forward(
self,
tokens: torch.Tensor,
*,
mask: Optional[_MaskType] = None,
encoder_input: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Args:
tokens (torch.Tensor): input tensor with shape [b x s]
mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
and before the softmax. Either a boolean tensor with shape [b x s x s] or a
:class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value
of True in row i and column j means token i attends to token j. A value of False means
token i does not attend to token j. If no mask is specified, a causal mask
is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed
for document masking in a packed sequence via `create_block_mask
<https://pytorch.org/blog/flexattention/#mask-mods>`_, we use
:func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention.
Default is None.
encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e]
encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
tokens and encoder embeddings. A True value at position i,j means token i can attend
to encoder embedding j. Mask has shape [b x s x s_e]. Default is None.
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.
Note: At the very first step of inference, when the model is provided with a prompt,
``input_pos`` would contain the positions of all of the tokens in the prompt
(eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
KV values for each position.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape [b x s x v] or a list of layer
output tensors defined by ``output_hidden_states`` with the
final output tensor appended to the list.
Raises:
ValueError: if seq_len of x is bigger than max_seq_len
ValueError: if caches have been set up and a required ``mask``, ``encoder_mask``, or ``input_pos`` is not provided
Notation used for tensor shapes:
- b: batch size
- s: token sequence length
- s_e: encoder sequence length
- v: vocab size
- d: token embed dim
- d_e: encoder embed dim
- m_s: max seq len
"""
# input tensor of shape [b, s]
bsz, seq_len = tokens.shape
if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)
# shape: [b, s, d]
h = self.tok_embeddings(tokens)
if self.decoder_caches_are_enabled:
if mask is None:
raise ValueError(
"KV-caches for self-attention layers are setup for inference mode, masks must be provided!"
" Use the `mask` arg to provide a mask."
)
if self.encoder_caches_are_enabled:
if encoder_mask is None:
raise ValueError(
"KV-caches for cross-attention/fusion layers are setup for inference mode, encoder masks must be provided!"
" Use the `encoder_mask` arg to provide an encoder mask."
)
if (
self.encoder_caches_are_enabled or self.decoder_caches_are_enabled
) and input_pos is None:
raise ValueError(
"KV-caches are setup for inference mode, input positions must be provided!"
)
hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)
# shape: [b, s, d]
h = layer(
h,
mask=mask,
encoder_input=encoder_input,
encoder_mask=encoder_mask,
input_pos=input_pos,
)
# shape: [b, s, d]
h = self.norm(h)
if self.num_output_chunks > 0:
output = self.chunked_output(h)
else:
# shape: [b, seq_len, out_dim]
output = F.linear(h, self.tok_embeddings.weight).float()
# Output list if hidden states are requested, otherwise just the output
# TODO: always output a list to have a consistent output type
output = output if not hidden else [*hidden, output]
return output
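# --- Illustrative sketch of weight tying (not part of the library) ------------
# TiedEmbeddingTransformerDecoder reuses the token-embedding matrix as the output
# projection: F.linear(h, emb.weight) maps [b, s, d] hidden states to [b, s, vocab]
# logits without a separate nn.Linear. Toy sizes below are arbitrary.
def _example_tied_output_projection(vocab_size: int = 32, embed_dim: int = 8) -> torch.Size:
    emb = nn.Embedding(vocab_size, embed_dim)
    h = torch.randn(2, 5, embed_dim)   # [b, s, d] hidden states
    logits = F.linear(h, emb.weight)   # weight: [vocab, d] -> logits: [b, s, vocab]
    return logits.shape                # torch.Size([2, 5, 32])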