Shortcuts

Source code for torchtune.models.llama3_2_vision._component_builders

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from functools import partial
from typing import List, Optional

from torch import nn
from torchtune.models.clip._component_builders import (
    clip_mlp,
    clip_vision_encoder,
    lora_clip_attention,
    lora_clip_mlp,
    lora_clip_vision_encoder,
)

from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp
from torchtune.models.llama3_1._component_builders import (
    llama3_mlp,
    lora_llama3_attention,
    lora_llama3_mlp,
)
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.models.llama3_2_vision._encoder import (
    Llama3VisionEncoder,
    Llama3VisionProjectionHead,
)
from torchtune.modules import (
    Fp32LayerNorm,
    MultiHeadAttention,
    RMSNorm,
    TanhGate,
    TransformerCrossAttentionLayer,
    TransformerDecoder,
    TransformerSelfAttentionLayer,
)

from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer

from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear


"""
Component builders for the Llama 3.2 Vision model and its constituent models.
torchtune provides composable building blocks. Builder functions help
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``MultiHeadAttention``
can take either ``nn.Linear`` or ``LoRALinear`` for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
"""


def llama3_2_vision_encoder(
    # clip encoder parameters
    *,
    patch_size: int,
    num_heads: int,
    clip_embed_dim: int,
    clip_num_layers: int,
    clip_hidden_states: Optional[List[int]],
    # projection parameters
    num_layers_projection: int,
    decoder_embed_dim: int,
    # image parameters
    tile_size: int,
    max_num_tiles: int = 4,
    in_channels: int = 3,
) -> Llama3VisionEncoder:
    """
    Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional
    projection head fusion module. This includes:
    - Spatial positional encodings
    - CLIP model backbone
    - Projection head on top of CLIP
    - Final projection into token embedding dimension

    Args:
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
            with shape (40, 40) each.
        num_heads (int): The number of attention heads in each transformer layer. The same
            value is used for the CLIP backbone and for the projection head.
        clip_embed_dim (int): The dimensionality of each patch embedding in CLIP.
        clip_num_layers (int): The number of transformer layers.
        clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return
            to the encoder projection head. It will return the intermediate results
            of the vision transformer layers which will be concatenated with the CLIP output
            and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will
            return the embeddings before they go through the first and fourth layers.
            ``None`` is treated as "no hidden states" when sizing the projection head.
        num_layers_projection (int): The number of transformer layers in the projection head.
        decoder_embed_dim (int): The dimensionality of the final output embeddings for the decoder.
        tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
            the size of the input image. In this case, the function will consider your image as a single tile.
        max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
            determine the size of the positional embeddings. Default: 4
        in_channels (int): The number of image input channels. Default: 3

    Returns:
        Llama3VisionEncoder: Instantiation of Llama 3.2 vision encoder.
    """
    # clip encoder
    clip = clip_vision_encoder(
        tile_size=tile_size,
        patch_size=patch_size,
        embed_dim=clip_embed_dim,
        num_layers=clip_num_layers,
        num_heads=num_heads,
        activation=nn.GELU,
        out_indices=clip_hidden_states,
        max_num_tiles=max_num_tiles,
        in_channels=in_channels,
        attn_bias=False,
        output_cls_projection=False,
    )

    # Projection head
    projection_head = llama3_2_vision_projection_head(
        num_layers=num_layers_projection,
        num_heads=num_heads,
        decoder_embed_dim=decoder_embed_dim,
        clip_embed_dim=clip_embed_dim,
        # ``clip_hidden_states`` may be None, in which case no hidden states are concatenated
        num_hidden_inputs=len(clip_hidden_states or []),
    )

    return Llama3VisionEncoder(clip=clip, projection_head=projection_head)
def llama3_2_vision_decoder(
    *,
    vocab_size: int,
    num_layers: int,
    fusion_interval: int,
    num_special_tokens: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    encoder_max_seq_len: int,
    rope_base: float = 500000.0,
    intermediate_dim: Optional[int] = None,
) -> TransformerDecoder:
    """
    Build the decoder associated with the Llama3 model with additional fused
    cross attention layers. This includes:
    - Token embeddings
    - num_layers number of CausalSelfAttention blocks
    - Fused cross attention layers every fusion_interval number of layers
    - RMS Norm layer applied to the output of the transformer
    - Final projection into token space

    Args:
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        fusion_interval (int): interval number of layers between fusion layers.
        num_special_tokens (int): number of special tokens added for the fusion model.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
            A falsy value falls back to `num_heads`.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        encoder_max_seq_len (int): maximum sequence length the encoder will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        rope_base (float): base passed to ``Llama3ScaledRoPE`` for the rotary positional
            embeddings of the self-attention layers. Default: 500000.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using ``scale_hidden_dim_for_mlp``.

    Returns:
        TransformerDecoder: Instantiation of Llama 3.2 vision decoder.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim)
    layers = []

    # A single RoPE module is created once and handed to every self-attention layer
    rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    # Layers are counted from 1 so that a fusion layer lands on every
    # ``fusion_interval``-th layer (idx == fusion_interval, 2 * fusion_interval, ...)
    for idx in range(1, num_layers + 1):

        # Self attention layers for text decoder
        self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
            output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
            pos_embeddings=rope,
            max_seq_len=max_seq_len,
            attn_dropout=0.0,
        )
        mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
        decoder_layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=RMSNorm(dim=embed_dim, eps=1e-5),
            mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5),
        )

        # cross attention layers, mixing text and vision,
        # placed every `fusion_interval` layers
        if idx % fusion_interval == 0:
            # No positional embeddings, non-causal, and sized by the encoder sequence length
            xattn = MultiHeadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
                k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
                q_norm=RMSNorm(dim=head_dim, eps=1e-05),
                k_norm=RMSNorm(dim=head_dim, eps=1e-05),
                pos_embeddings=None,
                max_seq_len=encoder_max_seq_len,
                is_causal=False,
                attn_dropout=0.0,
            )
            # The cross-attention layer gets its own MLP instance
            xattn_mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
            xattn_layer = TransformerCrossAttentionLayer(
                attn=xattn,
                mlp=xattn_mlp,
                ca_norm=RMSNorm(dim=embed_dim),
                mlp_norm=RMSNorm(dim=embed_dim),
                ca_scale=TanhGate(),
                mlp_scale=TanhGate(),
            )
            fusion_layer = FusionLayer(layer=decoder_layer, fusion_layer=xattn_layer)
            layers.append(fusion_layer)
        else:
            layers.append(decoder_layer)

    tok_embeddings = FusionEmbedding(vocab_size, num_special_tokens, embed_dim)
    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)

    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=1e-05),
        output=output_proj,
    )
def llama3_2_vision_projection_head(
    *,
    num_layers: int,
    num_heads: int,
    decoder_embed_dim: int,
    clip_embed_dim: int,
    num_hidden_inputs: int,
) -> Llama3VisionProjectionHead:
    """
    Build the Llama 3.2 Vision Projection Head that maps the output of the CLIP encoder
    to the decoder cross attention input.

    Args:
        num_layers (int): number of layers in the projection head.
        num_heads (int): number of heads in the projection head. The same value is used
            for the key and value heads.
        decoder_embed_dim (int): embedding dimension for the decoder.
        clip_embed_dim (int): embedding dimension for the CLIP encoder.
        num_hidden_inputs (int): number of hidden inputs to the projection head.

    Returns:
        Llama3VisionProjectionHead: Instantiation of Llama 3.2 vision projection head.
    """
    # The MLP hidden size is a fixed multiple of the CLIP embedding size
    mlp_ratio = 4
    hidden_dim = int(mlp_ratio * clip_embed_dim)
    head_dim = clip_embed_dim // num_heads
    num_kv_heads = num_heads

    layers = []
    for _ in range(num_layers):
        self_attn = MultiHeadAttention(
            embed_dim=clip_embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_proj=nn.Linear(clip_embed_dim, num_heads * head_dim, bias=False),
            k_proj=nn.Linear(clip_embed_dim, num_kv_heads * head_dim, bias=False),
            v_proj=nn.Linear(clip_embed_dim, num_kv_heads * head_dim, bias=False),
            output_proj=nn.Linear(clip_embed_dim, clip_embed_dim, bias=False),
            pos_embeddings=None,
            attn_dropout=0.0,
            is_causal=False,
        )

        mlp = clip_mlp(
            in_dim=clip_embed_dim,
            hidden_dim=hidden_dim,
            out_dim=clip_embed_dim,
            activation=nn.GELU(),
        )

        layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5),
            mlp_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5),
            sa_scale=TanhGate(),
            mlp_scale=TanhGate(),
        )
        layers.append(layer)

    # we concatenate clip embeddings and hidden layers output
    # and project it to decoder_embed_dim, which will be used for the
    # cross encoding
    proj_in = clip_embed_dim * (num_hidden_inputs + 1)
    return Llama3VisionProjectionHead(
        layers=layers,
        output=nn.Linear(proj_in, decoder_embed_dim),
        num_hidden_inputs=num_hidden_inputs,
    )
# ------------------ LoRA Llama 3.2 Vision ------------------


class LoRATrainable(Enum):
    """String-valued options naming a training mode: ``full``, ``lora`` or ``frozen``."""

    FULL = "full"
    LORA = "lora"
    FROZEN = "frozen"
def lora_llama3_2_vision_encoder(
    encoder_lora: bool,
    fusion_lora: bool,
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # clip encoder parameters
    patch_size: int,
    num_heads: int,
    clip_embed_dim: int,
    clip_num_layers: int,
    clip_hidden_states: Optional[List[int]],
    # projection parameters
    num_layers_projection: int,
    decoder_embed_dim: int,
    # image parameters
    tile_size: int,
    max_num_tiles: int = 4,
    in_channels: int = 3,
    # LoRA parameters
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
    **quantization_kwargs,
) -> Llama3VisionEncoder:
    """
    Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional
    projection head fusion module. This includes:
    - Spatial positional encodings
    - CLIP model backbone
    - Projection head on top of CLIP
    - Final projection into token embedding dimension

    Args:
        encoder_lora (bool): whether to apply LoRA to the CLIP encoder. When False the
            plain ``clip_vision_encoder`` builder is used and no LoRA options reach it.
        fusion_lora (bool): whether to apply LoRA to the projection head. When False the
            plain ``llama3_2_vision_projection_head`` builder is used.
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the projection head's output
            projection. Only forwarded (and therefore only effective) when ``fusion_lora``
            is True. Default: False
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
            with shape (40, 40) each.
        num_heads (int): The number of attention heads in each transformer layer.
        clip_embed_dim (int): The dimensionality of each patch embedding in CLIP.
        clip_num_layers (int): The number of transformer layers.
        clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return
            to the encoder projection head. It will return the intermediate results
            of the vision transformer layers which will be concatenated with the CLIP output
            and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will
            return the embeddings before they go through the first and fourth layers.
            ``None`` is treated as "no hidden states" when sizing the projection head.
        num_layers_projection (int): The number of transformer layers in the projection head.
        decoder_embed_dim (int): The dimensionality of the final output embeddings for the decoder.
        tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
            the size of the input image. In this case, the function will consider your image as a single tile.
        max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
            determine the size of the positional embeddings. Default: 4
        in_channels (int): The number of image input channels. Default: 3
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection is not
            supported for quantization currently. When True, a state-dict post hook is also
            registered on the returned encoder. Default: False
        **quantization_kwargs: extra keyword arguments forwarded unchanged, together with the
            LoRA options, to the LoRA CLIP encoder and LoRA projection head builders.

    Returns:
        Llama3VisionEncoder: Instantiation of Llama 3.2 vision encoder.
    """
    # Options shared by the LoRA variants of both sub-builders
    lora_options = {
        "lora_modules": lora_attn_modules,
        "apply_lora_to_mlp": apply_lora_to_mlp,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "use_dora": use_dora,
        "quantize_base": quantize_base,
        **quantization_kwargs,
    }

    # clip encoder
    clip_options = {
        "tile_size": tile_size,
        "patch_size": patch_size,
        "embed_dim": clip_embed_dim,
        "num_layers": clip_num_layers,
        "num_heads": num_heads,
        "activation": nn.GELU,
        "out_indices": clip_hidden_states,
        "max_num_tiles": max_num_tiles,
        "in_channels": in_channels,
        "attn_bias": False,
        "output_cls_projection": False,
    }
    if encoder_lora:
        clip = lora_clip_vision_encoder(**clip_options, **lora_options)
    else:
        clip = clip_vision_encoder(**clip_options)

    # Projection
    projection_options = {
        "num_layers": num_layers_projection,
        "num_heads": num_heads,
        "decoder_embed_dim": decoder_embed_dim,
        "clip_embed_dim": clip_embed_dim,
        # ``clip_hidden_states`` may be None, in which case no hidden states are concatenated
        "num_hidden_inputs": len(clip_hidden_states or []),
    }
    if fusion_lora:
        projection_head = lora_llama3_2_vision_projection_head(
            apply_lora_to_output=apply_lora_to_output,
            **projection_options,
            **lora_options,
        )
    else:
        projection_head = llama3_2_vision_projection_head(**projection_options)

    encoder = Llama3VisionEncoder(clip=clip, projection_head=projection_head)

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
        # so as to not increase peak memory
        encoder._register_state_dict_hook(
            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
        )

    return encoder
def lora_llama3_2_vision_decoder(
    decoder_lora: bool,
    fusion_lora: bool,
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # decoder params
    vocab_size: int,
    num_layers: int,
    fusion_interval: int,
    num_special_tokens: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    encoder_max_seq_len: int,
    rope_base: float = 500000.0,
    intermediate_dim: Optional[int] = None,
    # LoRA parameters
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Build the decoder associated with the Llama3 model with additional fused
    cross attention layers. This includes:
    - Token embeddings
    - num_layers number of CausalSelfAttention blocks
    - Fused cross attention layers every fusion_interval number of layers
    - RMS Norm layer applied to the output of the transformer
    - Final projection into token space

    Args:
        decoder_lora (bool): whether to apply LoRA to the language decoder
            (the self-attention layers, and their MLPs / the output projection
            when the corresponding ``apply_lora_to_*`` flag is also set).
        fusion_lora (bool): whether to apply LoRA to the cross-attention fusion layers
            (and their MLPs when ``apply_lora_to_mlp`` is also set).
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Takes effect per layer type only together with ``decoder_lora`` (self-attention
            layers) or ``fusion_lora`` (cross-attention layers). Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Only takes effect when ``decoder_lora`` is also True. Default: False
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        fusion_interval (int): interval number of layers between fusion layers.
        num_special_tokens (int): number of special tokens added for the fusion model.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
            A falsy value falls back to `num_heads`.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        encoder_max_seq_len (int): maximum sequence length the encoder will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        rope_base (float): base passed to ``Llama3ScaledRoPE`` for the rotary positional
            embeddings of the self-attention layers. Default: 500000.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using ``scale_hidden_dim_for_mlp``.
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection is not
            supported for quantization currently. When True, a state-dict post hook is also
            registered on the returned model. Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama 3.2 vision decoder.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim)
    layers = []

    # A single RoPE module is created once and handed to every self-attention layer
    rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
    # Layers are counted from 1 so that a fusion layer lands on every
    # ``fusion_interval``-th layer (idx == fusion_interval, 2 * fusion_interval, ...)
    for idx in range(1, num_layers + 1):

        # Self attention layers for text decoder
        if decoder_lora:
            self_attn = lora_llama3_attention(
                lora_modules=lora_attn_modules,
                pos_embeddings=rope,
                head_dim=head_dim,
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                max_seq_len=max_seq_len,
                attn_dropout=0.0,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                use_dora=use_dora,
                quantize_base=quantize_base,
            )
        else:
            self_attn = MultiHeadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
                k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
                pos_embeddings=rope,
                max_seq_len=max_seq_len,
                attn_dropout=0.0,
            )

        if apply_lora_to_mlp and decoder_lora:
            mlp = lora_llama3_mlp(
                dim=embed_dim,
                hidden_dim=hidden_dim,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                quantize_base=quantize_base,
                lora_dropout=lora_dropout,
                use_dora=use_dora,
            )
        else:
            mlp = llama3_mlp(
                dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
            )
        decoder_layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=RMSNorm(dim=embed_dim, eps=1e-5),
            mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5),
        )

        # cross attention layers, mixing text and vision,
        # placed every `fusion_interval` layers
        if idx % fusion_interval == 0:
            # No positional embeddings, non-causal, and sized by the encoder sequence length
            if fusion_lora:
                xattn = lora_llama3_attention(
                    lora_modules=lora_attn_modules,
                    pos_embeddings=None,
                    head_dim=head_dim,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
                    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
                    max_seq_len=encoder_max_seq_len,
                    is_causal=False,
                    attn_dropout=0.0,
                    lora_rank=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    use_dora=use_dora,
                    quantize_base=quantize_base,
                )
            else:
                xattn = MultiHeadAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim=head_dim,
                    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
                    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
                    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
                    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
                    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
                    pos_embeddings=None,
                    max_seq_len=encoder_max_seq_len,
                    is_causal=False,
                    attn_dropout=0.0,
                )

            # The cross-attention layer gets its own MLP instance
            if apply_lora_to_mlp and fusion_lora:
                xattn_mlp = lora_llama3_mlp(
                    dim=embed_dim,
                    hidden_dim=hidden_dim,
                    lora_rank=lora_rank,
                    lora_alpha=lora_alpha,
                    quantize_base=quantize_base,
                    lora_dropout=lora_dropout,
                    use_dora=use_dora,
                )
            else:
                xattn_mlp = llama3_mlp(
                    dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
                )
            xattn_layer = TransformerCrossAttentionLayer(
                attn=xattn,
                mlp=xattn_mlp,
                ca_norm=RMSNorm(dim=embed_dim),
                mlp_norm=RMSNorm(dim=embed_dim),
                ca_scale=TanhGate(),
                mlp_scale=TanhGate(),
            )
            fusion_layer = FusionLayer(layer=decoder_layer, fusion_layer=xattn_layer)
            layers.append(fusion_layer)
        else:
            layers.append(decoder_layer)

    tok_embeddings = FusionEmbedding(vocab_size, num_special_tokens, embed_dim)

    # TODO: quantize_base is not applied to final output_proj currently.
    adapter_cls = DoRALinear if use_dora else LoRALinear
    output_proj = (
        adapter_cls(
            embed_dim,
            vocab_size,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
        )
        if apply_lora_to_output and decoder_lora
        else nn.Linear(embed_dim, vocab_size, bias=False)
    )

    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=1e-05),
        output=output_proj,
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
        # so as to not increase peak memory
        model._register_state_dict_hook(
            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
        )

    return model
def lora_llama3_2_vision_projection_head(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # projection head parameters
    num_layers: int,
    num_heads: int,
    decoder_embed_dim: int,
    clip_embed_dim: int,
    num_hidden_inputs: int,
    # LoRA args
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
    **quantization_kwargs,
) -> Llama3VisionProjectionHead:
    """
    Build the Llama 3.2 Vision Projection Head with LoRA applied to a subset of the layers.

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        num_layers (int): number of layers in the projection head.
        num_heads (int): number of heads in the projection head. The same value is used
            for the key and value heads.
        decoder_embed_dim (int): embedding dimension for the decoder.
        clip_embed_dim (int): embedding dimension for the CLIP encoder.
        num_hidden_inputs (int): number of hidden inputs to the projection head.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
        quantize_base (bool): Whether to quantize base model parameters for linear layers
            LoRA is being applied to. Not applied to the final output projection.
            Default is ``False``.
        **quantization_kwargs: extra keyword arguments forwarded unchanged to the attention
            and MLP builders of every layer.

    Returns:
        Llama3VisionProjectionHead: Instantiation of Llama 3.2 vision projection head.
    """
    # The MLP hidden size is a fixed multiple of the CLIP embedding size
    mlp_ratio = 4
    hidden_dim = int(mlp_ratio * clip_embed_dim)
    head_dim = clip_embed_dim // num_heads
    num_kv_heads = num_heads

    layers = []
    for _ in range(num_layers):
        self_attn = lora_clip_attention(
            lora_modules=lora_modules,
            embed_dim=clip_embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            attn_dropout=0.0,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_dora=use_dora,
            quantize_base=quantize_base,
            **quantization_kwargs,
        )

        if apply_lora_to_mlp:
            mlp = lora_clip_mlp(
                in_dim=clip_embed_dim,
                hidden_dim=hidden_dim,
                out_dim=clip_embed_dim,
                activation=nn.GELU(),
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                quantize_base=quantize_base,
                lora_dropout=lora_dropout,
                use_dora=use_dora,
                **quantization_kwargs,
            )
        else:
            mlp = clip_mlp(
                in_dim=clip_embed_dim,
                hidden_dim=hidden_dim,
                out_dim=clip_embed_dim,
                activation=nn.GELU(),
                quantize_base=quantize_base,
                **quantization_kwargs,
            )

        layer = TransformerSelfAttentionLayer(
            attn=self_attn,
            mlp=mlp,
            sa_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5),
            mlp_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5),
            sa_scale=TanhGate(),
            mlp_scale=TanhGate(),
        )
        layers.append(layer)

    # we concatenate clip embeddings and hidden layers output
    # and project it to decoder_embed_dim, which will be used for the
    # cross encoding
    # TODO: quantize_base is not applied to final output_proj currently.
    proj_in = clip_embed_dim * (num_hidden_inputs + 1)
    adapter_cls = DoRALinear if use_dora else LoRALinear
    output_proj = (
        adapter_cls(
            proj_in,
            decoder_embed_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            use_bias=True,
        )
        if apply_lora_to_output
        else nn.Linear(proj_in, decoder_embed_dim)
    )
    return Llama3VisionProjectionHead(
        layers=layers,
        output=output_proj,
        num_hidden_inputs=num_hidden_inputs,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources