Source code for torchtune.modules.transformer

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Optional

import torch
from torch import nn, Tensor

from torchtune.modules import CausalSelfAttention, KVCache


class TransformerDecoderLayer(nn.Module):
    """Transformer layer derived from the Llama2 model. Normalization is applied
    before the attention **and** FF layer.

    Args:
        attn (CausalSelfAttention): Attention module.
        mlp (nn.Module): Feed-forward module.
        sa_norm (nn.Module): Normalization to be applied before self-attention.
        mlp_norm (nn.Module): Normalization to be applied before the feed-forward layer.
    """

    def __init__(
        self,
        attn: CausalSelfAttention,
        mlp: nn.Module,
        sa_norm: nn.Module,
        mlp_norm: nn.Module,
    ) -> None:
        super().__init__()
        self.sa_norm = sa_norm
        self.attn = attn
        self.mlp_norm = mlp_norm
        self.mlp = mlp
    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape [batch_size x seq_length x embed_dim]
            mask (Optional[Tensor]): Optional tensor which contains the mask.
                Only used during inference. Default is None.
            input_pos (Optional[Tensor]): Optional tensor which contains the position
                of the current token. This is only used during inference. Default is None.

        Returns:
            Tensor: output tensor with same shape as input
                [batch_size x seq_length x embed_dim]

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - d: embed dim

        TODO:
            - Make position of norm configurable
        """
        # Input tensor and attention output have the same shape
        # [b, s, d]
        # Norm applied before self-attention
        attn_out = self.attn(self.sa_norm(x), mask, input_pos)

        # Residual connection; shape: [b, s, d]
        h = attn_out + x

        # Norm applied before the feedforward layer
        mlp_out = self.mlp(self.mlp_norm(h))

        # Residual connection; shape: [b, s, d]
        out = h + mlp_out
        return out
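
To make the pre-norm residual wiring above concrete, here is a minimal, hypothetical usage sketch. The ``_StubAttention`` module is not part of torchtune; it only mirrors the ``(x, mask, input_pos)`` call signature of ``CausalSelfAttention`` so the layer can be exercised with plain ``torch.nn`` building blocks.

# --- usage sketch (not part of the module source) ---
import torch
from torch import nn


class _StubAttention(nn.Module):
    # Hypothetical stand-in: same call signature as CausalSelfAttention,
    # but just a single linear projection.
    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.kv_cache = None

    def forward(self, x, mask=None, input_pos=None):
        return self.proj(x)


embed_dim = 64
layer = TransformerDecoderLayer(
    attn=_StubAttention(embed_dim),
    mlp=nn.Sequential(
        nn.Linear(embed_dim, 4 * embed_dim),
        nn.SiLU(),
        nn.Linear(4 * embed_dim, embed_dim),
    ),
    sa_norm=nn.LayerNorm(embed_dim),
    mlp_norm=nn.LayerNorm(embed_dim),
)
x = torch.randn(2, 16, embed_dim)  # [b, s, d]
out = layer(x)  # same shape as input: [2, 16, 64]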
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """
    Return a list of ``n`` identical layers.

    Args:
        module (nn.Module): module to be cloned
        n (int): number of clones

    Returns:
        nn.ModuleList: list of ``n`` identical layers
    """
    # FIXME: copy.deepcopy() is not defined on nn.module
    return nn.ModuleList([copy.deepcopy(module) for i in range(n)])
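
Because ``copy.deepcopy`` copies parameters as well as structure, the clones returned by ``_get_clones`` are fully independent modules: updating one does not affect the prototype or its siblings. A small illustrative check (assumed, not from the source):

# --- usage sketch (not part of the module source) ---
import torch
from torch import nn

proto = nn.Linear(4, 4)
clones = _get_clones(proto, n=3)

with torch.no_grad():
    clones[0].weight.zero_()

# Only the first clone was modified; the prototype and the other clones are untouched.
print(torch.equal(clones[0].weight, clones[1].weight))  # False
print(torch.equal(clones[1].weight, proto.weight))      # True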
class TransformerDecoder(nn.Module):
    """
    Transformer Decoder derived from the Llama2 architecture.

    Args:
        tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
            tokens to an embedding space.
        layer (TransformerDecoderLayer): Transformer Decoder layer.
        num_layers (int): Number of Transformer Decoder layers.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`
        num_heads (int): number of query heads. For MHA this is also the number of
            heads for key and value. This is used to set up the
            :func:`~torchtune.modules.KVCache`
        head_dim (int): embedding dimension for each head in self-attention. This is
            used to set up the :func:`~torchtune.modules.KVCache`
        norm (nn.Module): Callable that applies normalization to the output of the
            decoder, before the final output projection.
        output (nn.Linear): Callable that applies a linear transformation to the output
            of the decoder.

    Note:
        Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
        in the module where they are used. This helps reduce the number of raise
        statements in code and improves readability.
    """

    def __init__(
        self,
        tok_embeddings: nn.Embedding,
        layer: TransformerDecoderLayer,
        num_layers: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        norm: nn.Module,
        output: nn.Linear,
    ) -> None:
        super().__init__()

        self.tok_embeddings = tok_embeddings
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm
        self.output = output
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.causal_mask = None

    def setup_caches(self, max_batch_size: int, dtype: torch.dtype) -> None:
        for layer in self.layers:
            layer.attn.kv_cache = KVCache(
                max_batch_size=max_batch_size,
                max_seq_len=self.max_seq_len,
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                dtype=dtype,
            )

        # causal_mask is used during inference to ensure we're attending
        # to the right tokens
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool)
        )
    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape [b x s]
            input_pos (Optional[Tensor]): Optional tensor which contains the position
                of the current token. This is only used during inference. Default is None.

        Note: At the very first step of inference, when the model is provided with a prompt,
            ``input_pos`` would contain the positions of all of the tokens in the prompt
            (eg: ``torch.arange(prompt_length)``). This is because we will need to compute
            the KV values for each position.

        Returns:
            Tensor: output tensor with shape [b x s x v]

        Raises:
            ValueError: if causal_mask is set but input_pos is None

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - v: vocab size
            - d: embed dim
            - m_s: max seq len
        """
        # input tensor of shape [b, s]
        bsz, seq_len = tokens.shape

        # shape: [b, s, d]
        h = self.tok_embeddings(tokens)

        mask = None
        if self.causal_mask is not None:
            if input_pos is None:
                raise ValueError(
                    "Caches are setup, but the position of input token is missing"
                )
            # shape: [1, 1, input_pos_len, m_s]
            # in most cases input_pos_len should be 1
            mask = self.causal_mask[None, None, input_pos]

        for layer in self.layers:
            # shape: [b, s, d]
            h = layer(h, mask, input_pos)

        # shape: [b, s, d]
        h = self.norm(h)

        # shape: [b, s, v]
        output = self.output(h).float()
        return output
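
Finally, a hedged end-to-end sketch of how ``TransformerDecoder`` is typically driven during generation. It assumes the ``llama2`` component builder in ``torchtune.models.llama2`` (which returns a ``TransformerDecoder``); the builder name and keyword arguments may differ across torchtune versions. The prompt is prefilled with all of its positions at once, then tokens are decoded one position at a time against the KV cache.

# --- usage sketch (not part of the module source) ---
import torch
from torchtune.models.llama2 import llama2  # assumed component builder

# Tiny configuration purely for illustration
model = llama2(
    vocab_size=32_000,
    num_layers=2,
    num_heads=4,
    num_kv_heads=4,
    embed_dim=256,
    max_seq_len=128,
)

# Without caches: plain forward pass over a full sequence, input_pos is not needed
tokens = torch.randint(0, 32_000, (1, 16))
logits = model(tokens)  # shape: [1, 16, 32_000]

# With caches: every forward call must now pass input_pos
model.setup_caches(max_batch_size=1, dtype=torch.float32)
prompt = torch.randint(0, 32_000, (1, 8))
logits = model(prompt, input_pos=torch.arange(8))  # prefill positions 0..7
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick, shape [1, 1]
logits = model(next_token, input_pos=torch.tensor([8]))  # decode position 8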
