Shortcuts

Source code for torchtune.models.qwen2._model_builders

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional

from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType

from torchtune.models.qwen2._component_builders import lora_qwen2, qwen2
from torchtune.models.qwen2._tokenizer import QWEN2_SPECIAL_TOKENS, Qwen2Tokenizer
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json

"""
Model builders build specific instantiations using component builders. For example
the qwen2_7b model builder uses the qwen2 component builder to create the
qwen2 7B model.
"""


def qwen2_7b() -> TransformerDecoder:
    """
    Build a Qwen2 decoder configured with the default 7B hyperparameters
    published at https://huggingface.co/Qwen/Qwen2-7B-Instruct

    Returns:
        TransformerDecoder: Instantiation of Qwen2 7B model
    """
    # Architecture of the 7B checkpoint. Unlike the 0.5B / 1.5B builders,
    # ``tie_word_embeddings`` is not passed, so the component builder's
    # default is used.
    hparams = dict(
        vocab_size=152064,
        num_layers=28,
        num_heads=28,
        num_kv_heads=4,
        embed_dim=3584,
        intermediate_dim=18944,
        max_seq_len=32768,
        attn_dropout=0.0,
        norm_eps=1e-06,
        rope_base=1000000.0,
    )
    return qwen2(**hparams)
def qwen2_0_5b() -> TransformerDecoder:
    """
    Build a Qwen2 decoder configured with the default 0.5B hyperparameters
    published at https://huggingface.co/Qwen/Qwen2-0.5B-Instruct

    Returns:
        TransformerDecoder: Instantiation of Qwen2 0.5B model

    Note:
        Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
        and returns an instance of `TransformerDecoder`.
    """
    hparams = dict(
        vocab_size=151936,
        num_layers=24,
        num_heads=14,
        num_kv_heads=2,
        embed_dim=896,
        intermediate_dim=4864,
        max_seq_len=32768,
        attn_dropout=0.0,
        norm_eps=1e-06,
        rope_base=1000000.0,
    )
    # The small Qwen2 variants always enable tied word embeddings (see Note).
    return qwen2(tie_word_embeddings=True, **hparams)
def qwen2_1_5b() -> TransformerDecoder:
    """
    Build a Qwen2 decoder configured with the default 1.5B hyperparameters
    published at https://huggingface.co/Qwen/Qwen2-1.5B-Instruct

    Returns:
        TransformerDecoder: Instantiation of Qwen2 1.5B model

    Note:
        Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
        and returns an instance of `TransformerDecoder`.
    """
    hparams = dict(
        vocab_size=151936,
        num_layers=28,
        num_heads=12,
        num_kv_heads=2,
        embed_dim=1536,
        intermediate_dim=8960,
        max_seq_len=32768,
        attn_dropout=0.0,
        norm_eps=1e-06,
        rope_base=1000000.0,
    )
    # The small Qwen2 variants always enable tied word embeddings (see Note).
    return qwen2(tie_word_embeddings=True, **hparams)
def qwen2_tokenizer(
    path: str,
    merges_file: Optional[str] = None,
    special_tokens_path: Optional[str] = None,
    max_seq_len: Optional[int] = None,
    prompt_template: Optional[_TemplateType] = None,
    truncation_type: str = "right",
    **kwargs,
) -> Qwen2Tokenizer:
    """
    Tokenizer for Qwen2.

    Args:
        path (str): path to the vocab.json file.
        merges_file (Optional[str]): path to the merges.txt file. Defaults to None, which is
            forwarded as-is to :class:`Qwen2Tokenizer`.
            NOTE(review): the tokenizer presumably needs a real merges file; the None
            default is kept only for backward compatibility — confirm.
        special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face
            model files that contains all registered special tokens, or a local json file
            structured similarly. Default is None to use the canonical Qwen2 special tokens.
        max_seq_len (Optional[int]): A max sequence length to truncate tokens to.
            Default: None
        prompt_template (Optional[_TemplateType]): optional specified prompt template.
            If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
            class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
            prepend/append tags. Default is None.
        truncation_type (str): type of truncation to apply, either "left" or "right".
            Default is "right".
        **kwargs: additional keyword arguments forwarded unchanged to :class:`Qwen2Tokenizer`.

    Returns:
        Qwen2Tokenizer: Instantiation of the Qwen2 tokenizer
    """
    # Fall back to the built-in Qwen2 special tokens unless a tokenizer.json is supplied.
    special_tokens = (
        parse_hf_tokenizer_json(special_tokens_path)
        if special_tokens_path is not None
        else QWEN2_SPECIAL_TOKENS
    )
    # Resolve a dotpath / dict template spec into a template object only when one is given.
    template = (
        _get_prompt_template(prompt_template) if prompt_template is not None else None
    )
    return Qwen2Tokenizer(
        path=path,
        merges_file=merges_file,
        special_tokens=special_tokens,
        max_seq_len=max_seq_len,
        prompt_template=template,
        truncation_type=truncation_type,
        **kwargs,
    )
def lora_qwen2_7b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Qwen2 7B model with LoRA enabled.

    The Qwen2 defaults are the same as in :func:`~torchtune.models.qwen2.qwen2_7b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        use_dora (bool): forwarded to ``lora_qwen2``; presumably switches the adapters to
            DoRA (weight-decomposed LoRA, https://arxiv.org/abs/2402.09353). Default: False
        quantize_base (bool): Whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Qwen2 7B model with LoRA applied
    """
    return lora_qwen2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        # Base architecture: identical to qwen2_7b.
        vocab_size=152064,
        num_layers=28,
        num_heads=28,
        num_kv_heads=4,
        embed_dim=3584,
        intermediate_dim=18944,
        max_seq_len=32768,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        # LoRA adapter configuration.
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )
def lora_qwen2_0_5b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Qwen2 0.5B model with LoRA enabled.

    The Qwen2 defaults are the same as in :func:`~torchtune.models.qwen2.qwen2_0_5b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        use_dora (bool): forwarded to ``lora_qwen2``; presumably switches the adapters to
            DoRA (weight-decomposed LoRA, https://arxiv.org/abs/2402.09353). Default: False
        quantize_base (bool): Whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied

    Note:
        Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
        and returns an instance of `TransformerDecoder`. This builder therefore has no
        ``apply_lora_to_output`` parameter: it always passes ``apply_lora_to_output=False``
        together with ``tie_word_embeddings=True``.
    """
    return lora_qwen2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        # Fixed to False alongside tie_word_embeddings=True below (see Note).
        apply_lora_to_output=False,
        # Base architecture: identical to qwen2_0_5b.
        vocab_size=151936,
        num_layers=24,
        num_heads=14,
        num_kv_heads=2,
        embed_dim=896,
        intermediate_dim=4864,
        max_seq_len=32768,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        tie_word_embeddings=True,
        # LoRA adapter configuration.
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )
def lora_qwen2_1_5b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Qwen2 1.5B model with LoRA enabled.

    The Qwen2 defaults are the same as in :func:`~torchtune.models.qwen2.qwen2_1_5b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        use_dora (bool): forwarded to ``lora_qwen2``; presumably switches the adapters to
            DoRA (weight-decomposed LoRA, https://arxiv.org/abs/2402.09353). Default: False
        quantize_base (bool): Whether to quantize base model weights. Default: False

    Returns:
        TransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied

    Note:
        Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
        and returns an instance of `TransformerDecoder`. This builder therefore has no
        ``apply_lora_to_output`` parameter: it always passes ``apply_lora_to_output=False``
        together with ``tie_word_embeddings=True``.
    """
    return lora_qwen2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        # Fixed to False alongside tie_word_embeddings=True below (see Note).
        apply_lora_to_output=False,
        # Base architecture: identical to qwen2_1_5b.
        vocab_size=151936,
        num_layers=28,
        num_heads=12,
        num_kv_heads=2,
        embed_dim=1536,
        intermediate_dim=8960,
        max_seq_len=32768,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        tie_word_embeddings=True,
        # LoRA adapter configuration.
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources