Source code for torchtune.models.gemma._component_builders

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
from typing import List
from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks

from torchtune.modules import (
    MultiHeadAttention,
    FeedForward,
    FrozenNF4Linear,
    RotaryPositionalEmbeddings,
    TransformerSelfAttentionLayer,
)

from torchtune.models.gemma.rms_norm import GemmaRMSNorm
from torchtune.modules import TransformerDecoder, TiedLinear
from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

"""
Component builders for the Gemma 2B models and popular variants such as LoRA.

torchtune provides composable building blocks. Builder functions help
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``MultiHeadAttention``
can take either ``nn.Linear`` or ``LoRALinear`` for ``q_proj`` (illustrated in the sketch below).
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
"""


def gemma(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-6,
    rope_base: int = 10_000,
    norm_embeddings: bool = True,
) -> TransformerDecoder:
    """
    Build the decoder associated with the Gemma model. This includes:
    - Token embeddings
    - num_layers number of TransformerSelfAttentionLayer blocks
    - RMS Norm layer applied to the output of the transformer
    - Final projection into token space

    This does NOT currently include inference-time optimizations such as
    sliding-window attention.

    Args:
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        head_dim (int): dimension of each head.
        num_kv_heads (int): number of key and value heads.
        embed_dim (int): embedding dimension for self-attention.
        intermediate_dim (int): intermediate dimension for MLP.
        max_seq_len (int): maximum sequence length the model will be run with.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        norm_eps (float): epsilon in RMS norms. Default: 1e-6
        rope_base (int): base for the rotary positional embeddings. Default: 10_000
        norm_embeddings (bool): whether to apply layer norm before the self-attention
            and mlp layers. Default: True

    Returns:
        TransformerDecoder: Instantiation of Gemma model.
    """
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    self_att = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
    layer = TransformerSelfAttentionLayer(
        attn=self_att,
        mlp=mlp,
        sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
        mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
    )
    tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
    output_proj = TiedLinear(tok_embeddings)
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        output=output_proj,
        head_dim=head_dim,
        norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
    )
    return model
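

# Illustrative sketch (not part of the original torchtune module): calling the builder
# with deliberately tiny hyperparameters so the model is cheap to instantiate. The
# full-size configurations (e.g. Gemma 2B) live in the gemma model builders module.
def _example_tiny_gemma() -> TransformerDecoder:
    import torch

    model = gemma(
        vocab_size=1_000,
        num_layers=2,
        num_heads=4,
        head_dim=16,
        num_kv_heads=1,
        embed_dim=64,
        intermediate_dim=256,
        max_seq_len=128,
    )
    tokens = torch.randint(0, 1_000, (2, 16))  # [batch_size=2, seq_len=16]
    logits = model(tokens)  # expected shape: [2, 16, 1_000] logits over the vocabulary
    return model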
def gemma_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
    """
    Build the MLP layer associated with the Gemma model.

    Args:
        dim (int): input dimension to the MLP.
        hidden_dim (int): hidden dimension of the MLP.
        quantize_base (bool): whether to use NF4-quantized base linears instead of
            ``nn.Linear``. Default: False

    Returns:
        FeedForward: Instantiation of the Gemma MLP.
    """
    gate_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
    down_proj = (
        nn.Linear(hidden_dim, dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(hidden_dim, dim, bias=False)
    )
    up_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
    activation = nn.GELU(approximate="tanh")
    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation)
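

# Illustrative sketch (not part of the original torchtune module): the Gemma MLP is a
# gated feed-forward, roughly down_proj(gelu(gate_proj(x)) * up_proj(x)), so the output
# dimension matches ``dim``. Sizes below are small placeholders.
def _example_gemma_mlp() -> None:
    import torch

    mlp = gemma_mlp(dim=64, hidden_dim=256)
    x = torch.randn(2, 16, 64)   # [batch, seq, dim]
    out = mlp(x)                 # [batch, seq, dim]; hidden_dim only appears internally
    assert out.shape == x.shape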
def lora_gemma(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    *,
    # gemma args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-6,
    rope_base: int = 10_000,
    norm_embeddings: bool = True,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Return a version of Gemma with LoRA applied based on the passed in configuration.
    Note: LoRA on the output projection is not supported because it is tied to the
    token embeddings.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        head_dim (int): dimension of each head.
        num_kv_heads (int): number of key and value heads.
        embed_dim (int): embedding dimension for self-attention.
        intermediate_dim (int): intermediate dimension for MLP.
        max_seq_len (int): maximum sequence length the model will be run with.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        norm_eps (float): epsilon in RMS norms. Default: 1e-6
        rope_base (int): base for the rotary positional embeddings. Default: 10_000
        norm_embeddings (bool): whether to apply layer norm before the self-attention
            and mlp layers. Default: True
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation"
            (https://arxiv.org/abs/2402.09353). Default: False
        quantize_base (bool): Whether to quantize base model weights or not. Only applied
            to base weights within linear layers LoRA is applied to. The final output
            linear projection is not supported for quantization currently.

    Returns:
        TransformerDecoder: Instantiation of Gemma model with LoRA applied to a subset
        of the attention projections in each layer.
    """
    self_attn = lora_gemma_self_attention(
        lora_modules=lora_attn_modules,
        embed_dim=embed_dim,
        head_dim=head_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
        rope_base=rope_base,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )

    if apply_lora_to_mlp:
        mlp = lora_gemma_mlp(
            dim=embed_dim,
            hidden_dim=intermediate_dim,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_dora=use_dora,
            quantize_base=quantize_base,
        )
    else:
        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)

    layer = TransformerSelfAttentionLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
        mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
    )
    tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
    output_proj = TiedLinear(tok_embeddings)
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        output=output_proj,
        head_dim=head_dim,
        norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to higher precision and offload to
        # CPU on the fly so as to not increase peak memory.
        # TODO: figure out a better way to determine what precision the rest of the model is in.
        _register_reparametrize_state_dict_hooks(model, dtype=tok_embeddings.weight.dtype)

    return model
def lora_gemma_self_attention(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # MultiHeadAttention args
    embed_dim: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    rope_base: int = 10_000,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> MultiHeadAttention:
    """Build a Gemma self-attention block, applying LoRA/DoRA to the projections listed in
    ``lora_modules``; the remaining projections use ``nn.Linear`` (or ``FrozenNF4Linear``
    when ``quantize_base=True``)."""
    if not lora_modules:
        raise ValueError(
            f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
        )

    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    adapter_cls = DoRALinear if use_dora else LoRALinear

    q_proj = (
        adapter_cls(
            embed_dim,
            num_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "q_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_heads * head_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
        )
    )
    k_proj = (
        adapter_cls(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "k_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
        )
    )
    v_proj = (
        adapter_cls(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "v_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
        )
    )
    output_proj = (
        adapter_cls(
            num_heads * head_dim,
            embed_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "output_proj" in lora_modules
        else (
            nn.Linear(num_heads * head_dim, embed_dim, bias=False)
            if not quantize_base
            else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False)
        )
    )

    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    self_attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=output_proj,
        pos_embeddings=rope,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    return self_attn


def lora_gemma_mlp(
    *,
    dim: int,
    hidden_dim: int,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> FeedForward:
    """Build the Gemma gated MLP with LoRA/DoRA adapters on all three projections."""
    adapter_cls = DoRALinear if use_dora else LoRALinear
    gate_proj = adapter_cls(
        in_dim=dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    down_proj = adapter_cls(
        in_dim=hidden_dim,
        out_dim=dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    up_proj = adapter_cls(
        in_dim=dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    activation = nn.GELU(approximate="tanh")
    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation)
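

# Illustrative sketch (not part of the original torchtune module): using the
# self-attention helper directly, applying LoRA only to the query projection.
# Dimensions are small placeholders.
def _example_lora_attention_block() -> MultiHeadAttention:
    return lora_gemma_self_attention(
        lora_modules=["q_proj"],
        embed_dim=64,
        num_heads=4,
        head_dim=16,
        num_kv_heads=1,
        max_seq_len=128,
        lora_rank=8,
        lora_alpha=16.0,
    )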
