Source code for torchtune.models.mistral._component_builders
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import List

from torch import nn

from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
from torchtune.modules import (
    CausalSelfAttention,
    FeedForward,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoder,
    TransformerDecoderLayer,
)
from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear

"""
Component builders for the Mistral 7B models and popular variants such as LoRA.

torchtune provides composable building blocks. Builder functions help
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``CausalSelfAttention``
  can take either nn.Linear or nn.LoRALinear for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
  the building blocks simple.
"""
def mistral(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: int = 10_000,
) -> TransformerDecoder:
    """
    Build the decoder associated with the mistral model. This includes:
    - Token embeddings
    - num_layers number of TransformerDecoderLayer blocks
    - RMS Norm layer applied to the output of the transformer
    - Final projection into token space

    This does NOT currently include inference-time optimizations such as
    sliding-window attention.

    Args:
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. If specified, the user should ensure
            ``num_heads % num_kv_heads == 0``. Default value is ``None``, in which case this is
            the same as MHA.
        embed_dim (int): embedding dimension for self-attention.
        intermediate_dim (int): intermediate dimension for MLP.
        max_seq_len (int): maximum sequence length the model will be run with.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0
        norm_eps (float): epsilon in RMS norms.
        rope_base (int): base for the rotary positional embeddings. Default: 10_000

    Returns:
        TransformerDecoder: Instantiation of mistral model.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )
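
# Usage sketch (illustrative, not part of the original module): constructing a
# Mistral-7B-sized decoder with the ``mistral`` builder and running a forward
# pass. The hyperparameter values below are assumptions based on the published
# Mistral 7B configuration, shown for demonstration only.
#
#   >>> import torch
#   >>> model = mistral(
#   ...     vocab_size=32_000,
#   ...     num_layers=32,
#   ...     num_heads=32,
#   ...     num_kv_heads=8,
#   ...     embed_dim=4096,
#   ...     intermediate_dim=14336,
#   ...     max_seq_len=32768,
#   ... )
#   >>> tokens = torch.randint(0, 32_000, (1, 16))
#   >>> logits = model(tokens)
#   >>> logits.shape
#   torch.Size([1, 16, 32000])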
def mistral_mlp(dim: int, hidden_dim: int) -> FeedForward:
    """
    Build the MLP layer associated with the Mistral model.
    """
    gate_proj = nn.Linear(dim, hidden_dim, bias=False)
    down_proj = nn.Linear(hidden_dim, dim, bias=False)
    up_proj = nn.Linear(dim, hidden_dim, bias=False)
    return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
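
# Usage sketch (illustrative, not part of the original module): the gated MLP
# built here preserves the input dimension, projecting ``dim -> hidden_dim -> dim``.
# The dimensions below are assumed Mistral-7B-like values.
#
#   >>> import torch
#   >>> mlp = mistral_mlp(dim=4096, hidden_dim=14336)
#   >>> x = torch.randn(1, 4, 4096)
#   >>> mlp(x).shape
#   torch.Size([1, 4, 4096])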
def lora_mistral(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # mistral args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    intermediate_dim: int,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: int = 10_000,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Return a version of Mistral (an instance of :func:`~torchtune.modules.TransformerDecoder`)
    with LoRA applied based on the passed in configuration.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. If specified, the user should ensure
            ``num_heads % num_kv_heads == 0``. Default value is ``None``, in which case this is
            the same as MHA.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with.
        intermediate_dim (int): intermediate dimension for MLP.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0
        norm_eps (float): epsilon in RMS norms.
        rope_base (int): base for the rotary positional embeddings. Default: 10_000
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection
            is not supported for quantization currently.

    Returns:
        TransformerDecoder: Instantiation of Mistral model with LoRA applied to
        a subset of the attention projections in each layer.
    """
    self_attn = lora_mistral_self_attention(
        lora_modules=lora_attn_modules,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
        rope_base=rope_base,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )

    if apply_lora_to_mlp:
        mlp = lora_mistral_mlp(
            dim=embed_dim,
            hidden_dim=intermediate_dim,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            quantize_base=quantize_base,
        )
    else:
        mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim)

    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    # TODO: quantize_base is not applied to final output_proj currently.
    output_proj = (
        LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha)
        if apply_lora_to_output
        else nn.Linear(embed_dim, vocab_size, bias=False)
    )
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=(embed_dim // num_heads),
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
        # so as to not increase peak memory
        model._register_state_dict_hook(
            partial(
                reparametrize_as_dtype_state_dict_post_hook,
                # TODO this is clowny, figure out a better way to get what precision the rest
                # of the model is in
                dtype=tok_embeddings.weight.dtype,
                offload_to_cpu=True,
            )
        )

    return model
def lora_mistral_self_attention(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # CausalSelfAttention args
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    rope_base: int = 10_000,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
) -> CausalSelfAttention:
    """
    Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA
    applied to a subset of its linear layers.

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj",
            "output_proj"}``.
        embed_dim (int): embedding dimension for self-attention.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. If specified, the user should ensure
            ``num_heads % num_kv_heads == 0``. Default value is ``None``, in which case this is
            the same as MHA.
        max_seq_len (int): maximum sequence length the model will be run with.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0
        rope_base (int): base for the rotary positional embeddings. Default: 10_000
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        quantize_base (bool): Whether to quantize base model parameters for linear layers
            LoRA is being applied to. Default is ``False``.

    Returns:
        CausalSelfAttention: instantiation of self-attention module with LoRA
        applied to a subset of Q, K, V, output projections.

    Raises:
        ValueError: If lora_modules arg is an empty list
    """
    if not lora_modules:
        raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules")

    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads

    q_proj = (
        LoRALinear(
            embed_dim,
            num_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "q_proj" in lora_modules
        else nn.Linear(embed_dim, num_heads * head_dim, bias=False)
    )
    k_proj = (
        LoRALinear(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "k_proj" in lora_modules
        else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
    )
    v_proj = (
        LoRALinear(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "v_proj" in lora_modules
        else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
    )
    output_proj = (
        LoRALinear(
            embed_dim,
            embed_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
        )
        if "output_proj" in lora_modules
        else nn.Linear(embed_dim, embed_dim, bias=False)
    )
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=output_proj,
        pos_embeddings=rope,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    return self_attn


def lora_mistral_mlp(
    *,
    dim: int,
    hidden_dim: int,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
) -> FeedForward:
    gate_proj = LoRALinear(
        in_dim=dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    down_proj = LoRALinear(
        in_dim=hidden_dim,
        out_dim=dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    up_proj = LoRALinear(
        in_dim=dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    return FeedForward(
        gate_proj=gate_proj,
        down_proj=down_proj,
        up_proj=up_proj,
    )
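
# Usage sketch (illustrative, not part of the original module): the attention
# builder swaps in ``LoRALinear`` only for the projections named in
# ``lora_modules`` and keeps plain ``nn.Linear`` for the rest. Dimensions below
# are assumed Mistral-7B-like values.
#
#   >>> attn = lora_mistral_self_attention(
#   ...     lora_modules=["q_proj", "v_proj"],
#   ...     embed_dim=4096,
#   ...     num_heads=32,
#   ...     num_kv_heads=8,
#   ...     max_seq_len=32768,
#   ...     lora_rank=8,
#   ...     lora_alpha=16,
#   ... )
#   >>> type(attn.q_proj).__name__, type(attn.k_proj).__name__
#   ('LoRALinear', 'Linear')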
def mistral_classifier(
    num_classes: int,
    *,
    # base mistral args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: int = 10_000,
) -> TransformerDecoder:
    """
    Build a base mistral model with an added classification layer.
    See :func:`~torchtune.models.mistral.mistral` for details on the base mistral model.

    Args:
        num_classes (int): number of classes for the classification layer.
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. If specified, the user should ensure
            ``num_heads % num_kv_heads == 0``. Default value is ``None``, in which case this is
            the same as MHA.
        embed_dim (int): embedding dimension for self-attention.
        intermediate_dim (int): intermediate dimension for MLP.
        max_seq_len (int): maximum sequence length the model will be run with.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0
        norm_eps (float): epsilon in RMS norms.
        rope_base (int): base for the rotary positional embeddings. Default: 10_000

    Returns:
        TransformerDecoder: Instantiation of mistral classification model.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, num_classes, bias=False)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )
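
# Usage sketch (illustrative, not part of the original module): with
# ``num_classes=1`` this builder yields a reward-model-style scorer that emits
# one logit per token position. Hyperparameters are assumed Mistral-7B-like
# values for demonstration.
#
#   >>> import torch
#   >>> reward_model = mistral_classifier(
#   ...     num_classes=1,
#   ...     vocab_size=32_000,
#   ...     num_layers=32,
#   ...     num_heads=32,
#   ...     num_kv_heads=8,
#   ...     embed_dim=4096,
#   ...     intermediate_dim=14336,
#   ...     max_seq_len=32768,
#   ... )
#   >>> scores = reward_model(torch.randint(0, 32_000, (2, 8)))
#   >>> scores.shape
#   torch.Size([2, 8, 1])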
def lora_mistral_classifier(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # mistral classifier args
    num_classes: int,
    # mistral args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    intermediate_dim: int,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: int = 10_000,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Return a version of the Mistral classifier (an instance of
    :func:`~torchtune.modules.TransformerDecoder`) with LoRA applied to some of
    the linear layers in its self-attention modules.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        num_classes (int): number of classes for the classification layer.
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. If specified, the user should ensure
            ``num_heads % num_kv_heads == 0``. Default value is ``None``, in which case this is
            the same as MHA.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with.
        intermediate_dim (int): intermediate dimension for MLP.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0
        norm_eps (float): epsilon in RMS norms.
        rope_base (int): base for the rotary positional embeddings. Default: 10_000
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection
            is not supported for quantization currently.

    Returns:
        TransformerDecoder: Instantiation of Mistral classifier model with LoRA applied to
        a subset of the attention projections in each layer.
    """
    self_attn = lora_mistral_self_attention(
        lora_modules=lora_attn_modules,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
        rope_base=rope_base,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )

    if apply_lora_to_mlp:
        mlp = lora_mistral_mlp(
            dim=embed_dim,
            hidden_dim=intermediate_dim,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            quantize_base=quantize_base,
        )
    else:
        mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim)

    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    # TODO: quantize_base is not applied to final output_proj currently.
    output_proj = (
        LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
        if apply_lora_to_output
        else nn.Linear(embed_dim, num_classes, bias=False)
    )
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=(embed_dim // num_heads),
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
        # so as to not increase peak memory
        model._register_state_dict_hook(
            partial(
                reparametrize_as_dtype_state_dict_post_hook,
                # TODO this is clowny, figure out a better way to get what precision the rest
                # of the model is in
                dtype=tok_embeddings.weight.dtype,
                offload_to_cpu=True,
            )
        )

    return model