
Source code for torchtune.models.clip._component_builders

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Callable, List, Optional

from torch import nn

from torchtune.models.clip._position_embeddings import (
    TiledTokenPositionalEmbedding,
    TilePositionalEmbedding,
    TokenPositionalEmbedding,
)
from torchtune.models.clip._text_encoder import CLIPTextEncoder, QuickGELU
from torchtune.modules import (
    FeedForward,
    Fp32LayerNorm,
    FrozenNF4Linear,
    MultiHeadAttention,
    TransformerSelfAttentionLayer,
    VisionRotaryPositionalEmbeddings,
)
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer


def clip_vision_encoder(
    tile_size: int,
    patch_size: int,
    embed_dim: int,
    num_layers: int,
    num_heads: int,
    activation: Callable = nn.SiLU,
    cls_output_dim: int = 512,
    attn_bias: bool = True,
    use_rope: bool = False,
    out_indices: Optional[List[int]] = None,
    output_cls_projection: bool = False,
    max_num_tiles: int = 4,
    in_channels: int = 3,
    append_cls_token: bool = False,
    use_tile_pos_embed: bool = True,
) -> VisionTransformer:
    """
    Builds the vision encoder associated with the clip model. This includes:

    - TransformerEncoderLayer
    - positional embeddings
    - CLS projection (optional)

    For details, please check the documentation of
    :class:`torchtune.modules.vision_transformer.VisionTransformer`.

    Args:
        tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
            the size of the input image. In this case, the function will consider your image as a single tile.
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have a 10x10 grid of patches
            with shape (40, 40) each.
        embed_dim (int): The dimensionality of each patch embedding (token).
        num_layers (int): The number of transformer layers.
        num_heads (int): The number of attention heads in each transformer layer.
        activation (Callable): The activation function to use in the MLP layer.
        cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
        attn_bias (bool): Whether to use bias in the attention module. Default True.
        use_rope (bool): If True, include 2D RoPE in the attention of each transformer layer. Default: False
        out_indices (Optional[List[int]]): The indices of hidden layers to return.
            If provided, the intermediate results of the transformer layers are returned
            before they go through the next layer. For example, ``out_indices=[0,3]`` will
            return the tokens before they go through the first and fourth layers.
        output_cls_projection (bool): If True, only the CLS token projection will be outputted,
            instead of all tokens. Defaults to False.
        max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
            determine the size of the positional embeddings.
        in_channels (int): The number of image input channels.
        append_cls_token (bool): If True, adds the CLS token embedding to the end of the sequence
            in the vision transformer. Default is False, which adds the CLS token to the beginning
            of the sequence.
        use_tile_pos_embed (bool): If True, use pre-tile, post-tile, and tiled token positional
            embeddings when max_num_tiles > 1. If False, only use standard token positional embeddings.

    Returns:
        A `VisionTransformer` object.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}"
        )
    head_dim = embed_dim // num_heads

    cls_projection = (
        CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
        if output_cls_projection
        else None
    )
    rope = (
        VisionRotaryPositionalEmbeddings(
            patch_size=patch_size,
            tile_size=tile_size,
            dim=head_dim,
            base=10_000,
            append_cls_token=append_cls_token,
        )
        if use_rope
        else None
    )

    # transformer layer
    self_attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
        k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
        v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
        pos_embeddings=rope,
        attn_dropout=0.0,
        is_causal=False,
    )
    mlp = clip_mlp(
        in_dim=embed_dim,
        hidden_dim=4 * embed_dim,
        out_dim=embed_dim,
        activation=activation(),
    )
    transformer_layer = TransformerSelfAttentionLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
        sa_scale=None,
        mlp_scale=None,
    )

    # position embeddings
    if use_tile_pos_embed and max_num_tiles > 1:
        pre_tile_pos_embed = TilePositionalEmbedding(
            max_num_tiles=max_num_tiles, embed_dim=embed_dim
        )
        post_tile_pos_embed = TilePositionalEmbedding(
            max_num_tiles=max_num_tiles, embed_dim=embed_dim
        )
        token_pos_embedding = TiledTokenPositionalEmbedding(
            max_num_tiles=max_num_tiles,
            embed_dim=embed_dim,
            patch_size=patch_size,
            tile_size=tile_size,
        )
    else:
        pre_tile_pos_embed = None
        post_tile_pos_embed = None
        token_pos_embedding = TokenPositionalEmbedding(
            embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
        )

    return VisionTransformer(
        num_layers=num_layers,
        layer=transformer_layer,
        token_pos_embedding=token_pos_embedding,
        pre_tile_pos_embed=pre_tile_pos_embed,
        post_tile_pos_embed=post_tile_pos_embed,
        cls_projection=cls_projection,
        out_indices=out_indices,
        tile_size=tile_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        in_channels=in_channels,
        append_cls_token=append_cls_token,
    )


def clip_text_encoder(
    embed_dim: int,
    num_heads: int,
    num_layers: int,
    vocab_size: int = 49408,
    max_seq_len: int = 77,
    norm_eps: float = 1e-5,
):
    """
    Text encoder for CLIP.

    CLIP is a model that encodes text and images into a shared vector space.
    Blog post: https://openai.com/index/clip/
    Paper: https://arxiv.org/abs/2103.00020

    Args:
        embed_dim (int): embedding/model dimension size
        num_heads (int): number of attention heads
        num_layers (int): number of transformer layers
        vocab_size (int): size of the vocabulary, default 49408
        max_seq_len (int): context size, default 77
        norm_eps (float): small value added to denominator for numerical stability, default 1e-5

    Returns:
        CLIPTextEncoder
    """
    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,
        head_dim=embed_dim // num_heads,
        q_proj=nn.Linear(embed_dim, embed_dim),
        k_proj=nn.Linear(embed_dim, embed_dim),
        v_proj=nn.Linear(embed_dim, embed_dim),
        output_proj=nn.Linear(embed_dim, embed_dim),
    )
    mlp = clip_mlp(
        in_dim=embed_dim,
        out_dim=embed_dim,
        hidden_dim=embed_dim * 4,
        activation=QuickGELU(),
    )
    encoder_layer = TransformerSelfAttentionLayer(
        attn=attn,
        mlp=mlp,
        sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
        mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
    )
    final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)
    return CLIPTextEncoder(
        layer=encoder_layer,
        final_norm=final_norm,
        vocab_size=vocab_size,
        max_seq_len=max_seq_len,
        embed_dim=embed_dim,
        num_layers=num_layers,
    )


def clip_mlp(
    in_dim: int,
    out_dim: int,
    hidden_dim: int,
    activation: nn.Module,
    quantize_base: bool = False,
    **quantization_kwargs,
) -> FeedForward:
    """
    Build the MLP layer associated with the clip model.
    """
    gate_proj = (
        nn.Linear(in_dim, hidden_dim)
        if not quantize_base
        else FrozenNF4Linear(in_dim, hidden_dim, bias=True, **quantization_kwargs)
    )
    down_proj = (
        nn.Linear(hidden_dim, out_dim)
        if not quantize_base
        else FrozenNF4Linear(hidden_dim, out_dim, bias=True, **quantization_kwargs)
    )
    return FeedForward(
        gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
    )


# ------------------ LoRA CLIP ------------------


def lora_clip_vision_encoder(
    lora_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    *,
    # clip encoder parameters
    tile_size: int,
    patch_size: int,
    embed_dim: int,
    num_layers: int,
    num_heads: int,
    activation: Callable = nn.SiLU,
    cls_output_dim: int = 512,
    attn_bias: bool = False,
    out_indices: Optional[List[int]] = None,
    output_cls_projection: bool = False,
    max_num_tiles: int = 4,
    in_channels: int = 3,
    # LoRA parameters
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
    **quantization_kwargs,
) -> VisionTransformer:
    """
    Build a LoRA implementation of the CLIP vision encoder.

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
            the size of the input image. In this case, the function will consider your image as a single tile.
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have a 10x10 grid of patches
            with shape (40, 40) each.
        embed_dim (int): The dimensionality of each patch embedding (token).
        num_layers (int): The number of transformer layers.
        num_heads (int): The number of attention heads in each transformer layer.
        activation (Callable): The activation function to use in the MLP layer.
        cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
        attn_bias (bool): Whether to use bias in the attention module. Default False.
        out_indices (Optional[List[int]]): The indices of hidden layers to return.
            If provided, the intermediate results of the transformer layers are returned
            before they go through the next layer. For example, ``out_indices=[0,3]`` will
            return the tokens before they go through the first and fourth layers.
        output_cls_projection (bool): If True, only the CLS token projection will be outputted,
            instead of all tokens. Defaults to False.
        max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
            determine the size of the positional embeddings.
        in_channels (int): The number of image input channels.
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection is not
            supported for quantization currently.

    Returns:
        VisionTransformer: Instantiation of VisionTransformer model.
    """
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

    # TODO: add support for quantizing and LoRA for the final output projection
    cls_projection = (
        CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
        if output_cls_projection
        else None
    )

    # transformer layer
    self_attn = lora_clip_attention(
        lora_modules=lora_modules,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,
        head_dim=embed_dim // num_heads,
        attn_dropout=0.0,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
        attn_bias=attn_bias,
        **quantization_kwargs,
    )
    if apply_lora_to_mlp:
        mlp = lora_clip_mlp(
            in_dim=embed_dim,
            hidden_dim=4 * embed_dim,
            out_dim=embed_dim,
            activation=activation(),
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            quantize_base=quantize_base,
            lora_dropout=lora_dropout,
            use_dora=use_dora,
            **quantization_kwargs,
        )
    else:
        mlp = clip_mlp(
            in_dim=embed_dim,
            hidden_dim=4 * embed_dim,
            out_dim=embed_dim,
            activation=activation(),
            quantize_base=quantize_base,
            **quantization_kwargs,
        )
    transformer_layer = TransformerSelfAttentionLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
        mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
        sa_scale=None,
        mlp_scale=None,
    )

    # position embeddings
    if max_num_tiles == 1:
        pre_tile_pos_embed = None
        post_tile_pos_embed = None
        token_pos_embedding = TokenPositionalEmbedding(
            embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
        )
    else:
        pre_tile_pos_embed = TilePositionalEmbedding(
            max_num_tiles=max_num_tiles, embed_dim=embed_dim
        )
        post_tile_pos_embed = TilePositionalEmbedding(
            max_num_tiles=max_num_tiles, embed_dim=embed_dim
        )
        token_pos_embedding = TiledTokenPositionalEmbedding(
            max_num_tiles=max_num_tiles,
            embed_dim=embed_dim,
            patch_size=patch_size,
            tile_size=tile_size,
        )

    model = VisionTransformer(
        num_layers=num_layers,
        layer=transformer_layer,
        token_pos_embedding=token_pos_embedding,
        pre_tile_pos_embed=pre_tile_pos_embed,
        post_tile_pos_embed=post_tile_pos_embed,
        cls_projection=cls_projection,
        out_indices=out_indices,
        tile_size=tile_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        in_channels=in_channels,
    )

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
        # so as to not increase peak memory
        model._register_state_dict_hook(
            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
        )

    return model


def lora_clip_attention(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # MultiHeadAttention args
    embed_dim: int,
    head_dim: int,
    num_heads: int,
    num_kv_heads: int,
    attn_dropout: float = 0.0,
    attn_bias: bool = False,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
    **quantization_kwargs,
) -> MultiHeadAttention:
    """
    Return an instance of :func:`~torchtune.modules.MultiHeadAttention` with LoRA
    applied to a subset of its linear layers.

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj",
            "output_proj"}``.
        embed_dim (int): embedding dimension for self-attention
        head_dim (int): dimension of each head in the multihead attention. Usually
            computed as ``embed_dim // num_heads``.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        attn_bias (bool): whether to use bias in the non-LoRA attention linear layers.
            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
        quantize_base (bool): Whether to quantize base model parameters for linear layers
            LoRA is being applied to. Default is ``False``.

    Returns:
        MultiHeadAttention: instantiation of self-attention module with LoRA
        applied to a subset of Q, K, V, output projections.

    Raises:
        ValueError: If lora_modules arg is an empty list
    """
    if not lora_modules:
        raise ValueError(
            f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
        )

    adapter_cls = DoRALinear if use_dora else LoRALinear
    q_proj = (
        adapter_cls(
            embed_dim,
            num_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
            **quantization_kwargs,
        )
        if "q_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_heads * head_dim, bias=attn_bias)
            if not quantize_base
            else FrozenNF4Linear(
                embed_dim, num_heads * head_dim, bias=attn_bias, **quantization_kwargs
            )
        )
    )
    k_proj = (
        adapter_cls(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
            **quantization_kwargs,
        )
        if "k_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
            if not quantize_base
            else FrozenNF4Linear(
                embed_dim,
                num_kv_heads * head_dim,
                bias=attn_bias,
                **quantization_kwargs,
            )
        )
    )
    v_proj = (
        adapter_cls(
            embed_dim,
            num_kv_heads * head_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
            **quantization_kwargs,
        )
        if "v_proj" in lora_modules
        else (
            nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
            if not quantize_base
            else FrozenNF4Linear(
                embed_dim,
                num_kv_heads * head_dim,
                bias=attn_bias,
                **quantization_kwargs,
            )
        )
    )
    output_proj = (
        adapter_cls(
            embed_dim,
            embed_dim,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
            quantize_base=quantize_base,
            **quantization_kwargs,
        )
        if "output_proj" in lora_modules
        else (
            nn.Linear(embed_dim, embed_dim, bias=attn_bias)
            if not quantize_base
            else FrozenNF4Linear(
                embed_dim, embed_dim, bias=attn_bias, **quantization_kwargs
            )
        )
    )
    self_attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=output_proj,
        pos_embeddings=None,
        attn_dropout=attn_dropout,
    )
    return self_attn


def lora_clip_mlp(
    *,
    in_dim: int,
    out_dim: int,
    hidden_dim: int,
    activation: nn.Module,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
    **quantization_kwargs,
) -> FeedForward:
    """
    Build the MLP layer with LoRA applied to the gate and down projections.
    """
    adapter_cls = DoRALinear if use_dora else LoRALinear
    gate_proj = adapter_cls(
        in_dim=in_dim,
        out_dim=hidden_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
        use_bias=True,
        **quantization_kwargs,
    )
    down_proj = adapter_cls(
        in_dim=hidden_dim,
        out_dim=out_dim,
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
        use_bias=True,
        **quantization_kwargs,
    )
    return FeedForward(
        gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
    )
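Usage sketch. The builders above are plain functions that return standard ``nn.Module`` instances, so they can be constructed directly. The snippet below is a minimal illustration, not part of the torchtune source: it builds a single-tile vision encoder and its LoRA counterpart and compares parameter counts. The hyperparameters (224-pixel tiles, 14-pixel patches, 1024-dim embeddings, 24 layers, 16 heads) are illustrative values in the range of a ViT-L/14 tower, not the settings of any particular released checkpoint; in practice these functions are usually invoked through higher-level model builders with the correct values for a given checkpoint.

import torch

from torchtune.models.clip._component_builders import (
    clip_vision_encoder,
    lora_clip_vision_encoder,
)

# Illustrative hyperparameters only (see note above).
common = dict(
    tile_size=224,
    patch_size=14,
    embed_dim=1024,
    num_layers=24,
    num_heads=16,
    max_num_tiles=1,  # treat each input image as a single tile
)

# Full-precision encoder: every weight is a regular nn.Linear parameter.
encoder = clip_vision_encoder(**common)

# LoRA variant: rank-8 adapters on q_proj/v_proj only; freezing the base
# weights is normally handled by the training recipe, not by this builder.
lora_encoder = lora_clip_vision_encoder(
    lora_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    lora_rank=8,
    lora_alpha=16,
    **common,
)


def count_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# The LoRA model is slightly larger than the base model because the low-rank
# A/B matrices are added alongside the still-present base projections.
print(f"base encoder params: {count_params(encoder):,}")
print(f"LoRA encoder params: {count_params(lora_encoder):,}")

Because ``max_num_tiles=1``, both encoders use the plain ``TokenPositionalEmbedding`` path rather than the tiled pre/post-tile embeddings, which matches the branch taken in the builders above when tiling is disabled.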
