# Source code for torchtune.models.llama3_2._model_builders
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from functools import partial
from torchtune.models.llama3_2._component_builders import llama3_2, lora_llama3_2
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
"""
Model builders build specific instantiations using component builders. For example,
the llama3_2_1b model builder uses the llama3_2 component builder to create the
Llama3.2 1B model.
"""
def llama3_2_1b() -> TransformerDecoder:
    """
    Build a Llama3.2 1B model populated with its default 1b hyperparameters.

    All architecture values are fixed here and forwarded unchanged to the
    ``llama3_2`` component builder.

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 1B model
    """
    # Fixed architecture hyperparameters for the 1B variant.
    config = dict(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    return llama3_2(**config)
def llama3_2_3b() -> TransformerDecoder:
    """
    Build a Llama3.2 3B model populated with its default 3b hyperparameters.

    All architecture values are fixed here and forwarded unchanged to the
    ``llama3_2`` component builder.

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 3B model
    """
    # Fixed architecture hyperparameters for the 3B variant.
    config = dict(
        vocab_size=128_256,
        num_layers=28,
        num_heads=24,
        num_kv_heads=8,
        embed_dim=3072,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    return llama3_2(**config)
def lora_llama3_2_1b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Build a Llama3.2 1B model with LoRA adapters enabled.

    The base-model hyperparameters are the same as in
    :func:`~torchtune.models.llama3_2.llama3_2_1b`, while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): which linear layers in each
            self-attention block receive LoRA. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each
            transformer layer. Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final
            output projection. Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation.
            Default: 0.0
        use_dora (bool): decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation"
            (https://arxiv.org/abs/2402.09353). Default: False
        quantize_base (bool): whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 1B model with LoRA applied
    """
    # Base-model architecture; these match the values used by ``llama3_2_1b``.
    architecture = dict(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    # Adapter hyperparameters, forwarded unchanged from the caller.
    adapter = dict(
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )
    return lora_llama3_2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        **architecture,
        **adapter,
    )
def lora_llama3_2_3b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Build a Llama3.2 3B model with LoRA adapters enabled.

    The base-model hyperparameters are the same as in
    :func:`~torchtune.models.llama3_2.llama3_2_3b`, while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): which linear layers in each
            self-attention block receive LoRA. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each
            transformer layer. Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final
            output projection. Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation.
            Default: 0.0
        use_dora (bool): decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation"
            (https://arxiv.org/abs/2402.09353). Default: False
        quantize_base (bool): whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied
    """
    # Base-model architecture; these match the values used by ``llama3_2_3b``.
    architecture = dict(
        vocab_size=128_256,
        num_layers=28,
        num_heads=24,
        num_kv_heads=8,
        embed_dim=3072,
        max_seq_len=131072,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    # Adapter hyperparameters, forwarded unchanged from the caller.
    adapter = dict(
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )
    return lora_llama3_2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        **architecture,
        **adapter,
    )
# QLoRA variant of the 1B LoRA builder: same arguments as lora_llama3_2_1b,
# with quantize_base preset to True via functools.partial.
qlora_llama3_2_1b = partial(
    lora_llama3_2_1b,
    quantize_base=True,
)
qlora_llama3_2_1b.__doc__ = (
    "\n"
    "Builder for creating a Llama3.2 1B model with QLoRA enabled. Base model weights in linear layers\n"
    "that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.\n"
    "Please see `lora_llama3_2_1b` for full API arguments.\n"
)
# QLoRA variant of the 3B LoRA builder: same arguments as lora_llama3_2_3b,
# with quantize_base preset to True via functools.partial.
qlora_llama3_2_3b = partial(
    lora_llama3_2_3b,
    quantize_base=True,
)
qlora_llama3_2_3b.__doc__ = (
    "\n"
    "Builder for creating a Llama3.2 3B model with QLoRA enabled. Base model weights in linear layers\n"
    "that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.\n"
    "Please see `lora_llama3_2_3b` for full API arguments.\n"
)