
Source code for torchtune.models.llama2._component_builders

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

from torch import nn

from torchtune.models.llama2._model_utils import scale_hidden_dim_for_mlp

from torchtune.modules import (
    FeedForward,
    FrozenNF4Linear,
    MultiHeadAttention,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoder,
    TransformerSelfAttentionLayer,
)
from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks

from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
from torchtune.utils._logging import deprecated

"""
Component builders for the Llama2 model and popular variants such as LoRA.

torchtune provides composable building blocks. Builder functions help
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``MultiHeadAttention``
can take either ``nn.Linear`` or ``LoRALinear`` for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
"""


# ------------------ Vanilla Llama2 ------------------


def llama2(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    intermediate_dim: Optional[int] = None,
    norm_eps: float = 1e-5,
    rope_base: float = 10000.0,
) -> TransformerDecoder:
    """
    Build the decoder associated with the Llama2 model. This includes:
    - Token embeddings
    - num_layers number of TransformerSelfAttentionLayer blocks
    - RMS Norm layer applied to the output of the transformer
    - Final projection into token space

    Args:
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
        norm_eps (float): epsilon in RMS norms.
        rope_base (float): base for rotary embeddings. Default: 10000.0

    Returns:
        TransformerDecoder: Instantiation of Llama2 model.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    hidden_dim = (
        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
    )
    rope = RotaryPositionalEmbeddings(
        dim=head_dim, max_seq_len=max_seq_len, base=rope_base
    )

    layers = nn.ModuleList()
    for _ in range(num_layers):
        self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
            pos_embeddings=rope,
            kv_cache=None,
            max_seq_len=max_seq_len,
            attn_dropout=attn_dropout,
        )
        mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
        layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
            mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        )
        layers.append(layer)

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )
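
# Illustrative usage (not part of the original module): a minimal sketch of calling
# ``llama2`` with hyperparameters in the ballpark of Llama2-7B. The helper name and
# the exact values are assumptions for illustration; consult torchtune's official
# size-specific model builders for the canonical configuration.
def _example_build_llama2_7b_like() -> TransformerDecoder:  # hypothetical helper
    return llama2(
        vocab_size=32_000,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,  # MHA: key/value heads match query heads
        embed_dim=4096,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=10000.0,
    )
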
def llama2_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
    """
    Build the MLP layer associated with the Llama model.
    """
    gate_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
    down_proj = (
        nn.Linear(hidden_dim, dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(hidden_dim, dim, bias=False)
    )
    up_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)


# ------------------ LoRA Llama2 ------------------

def lora_llama2(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # llama2 args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    intermediate_dim: Optional[int] = None,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    # Quantization args
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Return a version of Llama2 (an instance of :func:`~torchtune.modules.TransformerDecoder`)
    with LoRA applied based on the passed in configuration.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
        norm_eps (float): epsilon in RMS norms.
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection is not
            supported for quantization currently.

    Returns:
        TransformerDecoder: Instantiation of Llama2 model with LoRA applied to
        a subset of the attention projections in each layer.
    """
    hidden_dim = (
        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
    )
    layers = nn.ModuleList()
    for _ in range(num_layers):
        self_attn = lora_llama2_self_attention(
            lora_modules=lora_attn_modules,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_seq_len=max_seq_len,
            attn_dropout=attn_dropout,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_dora=use_dora,
            quantize_base=quantize_base,
        )

        if apply_lora_to_mlp:
            mlp = lora_llama2_mlp(
                dim=embed_dim,
                hidden_dim=hidden_dim,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                quantize_base=quantize_base,
                use_dora=use_dora,
                lora_dropout=lora_dropout,
            )
        else:
            mlp = llama2_mlp(
                dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
            )

        layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
            mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        )
        layers.append(layer)

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    # TODO: quantize_base is not applied to final output_proj currently.
    adapter_cls = DoRALinear if use_dora else LoRALinear
    output_proj = (
        adapter_cls(
            embed_dim,
            vocab_size,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
        )
        if apply_lora_to_output
        else nn.Linear(embed_dim, vocab_size, bias=False)
    )
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=(embed_dim // num_heads),
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
        # so as to not increase peak memory
        # TODO this is clowny, figure out a better way to get what precision the rest
        # of the model is in
        _register_reparametrize_state_dict_hooks(
            model, dtype=tok_embeddings.weight.dtype
        )

    return model
def lora_llama2_self_attention(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # MultiHeadAttention args
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> MultiHeadAttention:
    """
    Return an instance of :func:`~torchtune.modules.MultiHeadAttention` with LoRA
    applied to a subset of its linear layers.

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        embed_dim (int): embedding dimension for self-attention.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
        quantize_base (bool): Whether to quantize base model parameters for linear layers
            LoRA is being applied to. Default is ``False``.

    Returns:
        MultiHeadAttention: instantiation of self-attention module with LoRA
        applied to a subset of Q, K, V, output projections.

    Raises:
        ValueError: If lora_modules arg is an empty list
    """
    if not lora_modules:
        raise ValueError(
            f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
        )

    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    adapter_cls = DoRALinear if use_dora else LoRALinear
    q_proj = (
        adapter_cls(
            embed_dim,
            num_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "q_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_heads * head_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
        )
    )
    k_proj = (
        adapter_cls(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "k_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
        )
    )
    v_proj = (
        adapter_cls(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "v_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
        )
    )
    output_proj = (
        adapter_cls(
            embed_dim,
            embed_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "output_proj" in lora_modules
        else (
            nn.Linear(embed_dim, embed_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
        )
    )
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
    self_attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=output_proj,
        pos_embeddings=rope,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    return self_attn


def lora_llama2_mlp(
    *,
    dim: int,
    hidden_dim: int,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> FeedForward:
    adapter_cls = DoRALinear if use_dora else LoRALinear
    gate_proj = adapter_cls(
        in_dim=dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    down_proj = adapter_cls(
        in_dim=hidden_dim,
        out_dim=dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    up_proj = adapter_cls(
        in_dim=dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    return FeedForward(
        gate_proj=gate_proj,
        down_proj=down_proj,
        up_proj=up_proj,
    )


# ------------------ Llama2 Classifier ------------------


@deprecated(
    msg="Model-specific classifier builders are deprecated and will be removed in 0.8.0. "
    "Please use `torchtune.modules.classifier_model`, with "
    "`base_model_path=torchtune.models.llama2.llama2` instead."
)
def llama2_classifier(
    num_classes: int,
    *,
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    intermediate_dim: Optional[int] = None,
    norm_eps: float = 1e-5,
) -> TransformerDecoder:
    """
    Build a base Llama2 model with the final projection replaced with a classification layer.

    Args:
        num_classes (int): number of classes for classification.
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
        norm_eps (float): epsilon in RMS norms.

    Returns:
        TransformerDecoder: Instantiation of Llama2 classifier model.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    hidden_dim = (
        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
    )
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

    layers = nn.ModuleList()
    for _ in range(num_layers):
        self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
            pos_embeddings=rope,
            kv_cache=None,
            max_seq_len=max_seq_len,
            attn_dropout=attn_dropout,
        )
        mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
        layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
            mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        )
        layers.append(layer)

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, num_classes, bias=False)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )


@deprecated(
    msg="Model-specific classifier builders, and PEFT-based classifier builders with `apply_lora_to_output=True` "
    "are deprecated and will be removed in 0.8.0. "
    "Please use `torchtune.modules.classifier_model`, with "
    "`base_model_path=torchtune.models.llama2.lora_llama2` and "
    "`apply_lora_to_output=False` instead."
)
def lora_llama2_classifier(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # llama2 classifier args
    num_classes: int,
    # llama2 args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    intermediate_dim: Optional[int] = None,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    # Quantization args
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Return a classifier version of Llama2 (an instance of :func:`~torchtune.modules.TransformerDecoder`)
    with LoRA applied based on the passed in configuration.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        num_classes (int): number of classes for classification.
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
        norm_eps (float): epsilon in RMS norms.
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection is not
            supported for quantization currently.

    Returns:
        TransformerDecoder: Instantiation of Llama2 classifier model with LoRA applied to
        a subset of the attention projections in each layer.
    """
    hidden_dim = (
        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
    )
    layers = nn.ModuleList()
    for _ in range(num_layers):
        self_attn = lora_llama2_self_attention(
            lora_modules=lora_attn_modules,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_seq_len=max_seq_len,
            attn_dropout=attn_dropout,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_dora=use_dora,
            quantize_base=quantize_base,
        )

        if apply_lora_to_mlp:
            mlp = lora_llama2_mlp(
                dim=embed_dim,
                hidden_dim=hidden_dim,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                quantize_base=quantize_base,
                use_dora=use_dora,
                lora_dropout=lora_dropout,
            )
        else:
            mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)

        layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
            mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        )
        layers.append(layer)

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    # TODO: quantize_base is not applied to final output_proj currently.
    adapter_cls = DoRALinear if use_dora else LoRALinear
    output_proj = (
        adapter_cls(
            embed_dim,
            num_classes,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
        )
        if apply_lora_to_output
        else nn.Linear(embed_dim, num_classes, bias=False)
    )
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=(embed_dim // num_heads),
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
        # so as to not increase peak memory
        # TODO this is clowny, figure out a better way to get what precision the rest
        # of the model is in
        _register_reparametrize_state_dict_hooks(
            model, dtype=tok_embeddings.weight.dtype
        )

    return model
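
# Illustrative usage (not part of the original module): building a standalone
# self-attention block with LoRA on q_proj/v_proj and reporting which projections
# ended up as adapter layers versus plain linear layers. The helper name and
# dimensions are example assumptions, and the sketch assumes ``MultiHeadAttention``
# exposes its projection modules as attributes, as its constructor arguments suggest.
def _example_inspect_lora_attention() -> None:  # hypothetical helper
    attn = lora_llama2_self_attention(
        lora_modules=["q_proj", "v_proj"],
        embed_dim=512,
        num_heads=8,
        num_kv_heads=8,
        max_seq_len=1024,
        lora_rank=4,
        lora_alpha=8.0,
    )
    for name in ("q_proj", "k_proj", "v_proj", "output_proj"):
        proj = getattr(attn, name)
        kind = "LoRA adapter" if isinstance(proj, LoRALinear) else type(proj).__name__
        print(f"{name}: {kind}")
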
