# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import Callable, List, Optional
from torch import nn
from torchtune.models.clip._position_embeddings import (
TiledTokenPositionalEmbedding,
TilePositionalEmbedding,
TokenPositionalEmbedding,
)
from torchtune.models.clip._text_encoder import CLIPTextEncoder, QuickGELU
from torchtune.modules import (
FeedForward,
Fp32LayerNorm,
FrozenNF4Linear,
MultiHeadAttention,
TransformerSelfAttentionLayer,
VisionRotaryPositionalEmbeddings,
)
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer
def clip_vision_encoder(
tile_size: int,
patch_size: int,
embed_dim: int,
num_layers: int,
num_heads: int,
activation: Callable = nn.SiLU,
cls_output_dim: int = 512,
attn_bias: bool = True,
use_rope: bool = False,
out_indices: Optional[List[int]] = None,
output_cls_projection: bool = False,
max_num_tiles: int = 4,
in_channels: int = 3,
append_cls_token: bool = False,
use_tile_pos_embed: bool = True,
) -> VisionTransformer:
"""
Builds the vision encoder associated with the CLIP model. This includes:
- TransformerSelfAttentionLayer
- positional embeddings
- CLS projection (optional)
For details, please check the documentation of
:class:`torchtune.modules.vision_transformer.VisionTransformer`.
Args:
tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
the size of the input image. In this case, the function will consider your image as a single tile.
patch_size (int): The size of each patch. Used to divide the tiles into patches.
E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
with shape (40, 40) each.
embed_dim (int): The dimensionality of each patch embedding (token).
num_layers (int): The number of transformer layers.
num_heads (int): The number of attention heads in each transformer layer.
activation (Callable): The activation function to use in the MLP layer.
cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
attn_bias (bool): Whether to use a bias term in the attention projections. Default: True.
use_rope (bool): If True, apply 2D rotary positional embeddings (RoPE) in the attention of each transformer layer. Default: False
out_indices (Optional[List[int]]): The indices of hidden layers to return.
If provided, the intermediate outputs of these transformer layers are returned
before being passed to the next layer. For example, ``out_indices=[0,3]`` will
return the tokens before they go through the first and fourth layers.
output_cls_projection (bool): If True, only the CLS token projection will be outputted,
instead of all tokens. Defaults to False.
max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
determine the size of the positional embeddings.
in_channels (int): The number of image input channels.
append_cls_token (bool): If True, adds CLS token embedding to the end of the sequence in the vision transformer.
Default is False, which adds CLS token to the beginning of the sequence.
use_tile_pos_embed (bool): If True, use pre-tile, post-tile, and tiled token positional embeddings, if max_num_tiles > 1.
If False, only use standard token positional embeddings.
Returns:
A `VisionTransformer` object.
Raises:
ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
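
Example:
    A minimal construction sketch; the hyperparameter values below are illustrative
    only, not a pretrained CLIP configuration:

    >>> encoder = clip_vision_encoder(
    ...     tile_size=224,
    ...     patch_size=14,
    ...     embed_dim=512,
    ...     num_layers=4,
    ...     num_heads=8,
    ... )
    >>> # encoder is a VisionTransformer; see that class for the expected input format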
"""
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}"
)
head_dim = embed_dim // num_heads
cls_projection = (
CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
if output_cls_projection
else None
)
rope = (
VisionRotaryPositionalEmbeddings(
patch_size=patch_size,
tile_size=tile_size,
dim=head_dim,
base=10_000,
append_cls_token=append_cls_token,
)
if use_rope
else None
)
# transformer layer
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
pos_embeddings=rope,
attn_dropout=0.0,
is_causal=False,
)
mlp = clip_mlp(
in_dim=embed_dim,
hidden_dim=4 * embed_dim,
out_dim=embed_dim,
activation=activation(),
)
transformer_layer = TransformerSelfAttentionLayer(
attn=self_attn,
mlp=mlp,
sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
sa_scale=None,
mlp_scale=None,
)
# position embeddings
if use_tile_pos_embed and max_num_tiles > 1:
pre_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
post_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
token_pos_embedding = TiledTokenPositionalEmbedding(
max_num_tiles=max_num_tiles,
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size,
)
else:
pre_tile_pos_embed = None
post_tile_pos_embed = None
token_pos_embedding = TokenPositionalEmbedding(
embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
)
return VisionTransformer(
num_layers=num_layers,
layer=transformer_layer,
token_pos_embedding=token_pos_embedding,
pre_tile_pos_embed=pre_tile_pos_embed,
post_tile_pos_embed=post_tile_pos_embed,
cls_projection=cls_projection,
out_indices=out_indices,
tile_size=tile_size,
patch_size=patch_size,
embed_dim=embed_dim,
in_channels=in_channels,
append_cls_token=append_cls_token,
)
def clip_text_encoder(
embed_dim: int,
num_heads: int,
num_layers: int,
vocab_size: int = 49408,
max_seq_len: int = 77,
norm_eps: float = 1e-5,
) -> CLIPTextEncoder:
"""
Text encoder for CLIP.
CLIP is a model that encodes text and images into a shared vector space.
Blog post: https://openai.com/index/clip/
Paper: https://arxiv.org/abs/2103.00020
Args:
embed_dim (int): embedding/model dimension size
num_heads (int): number of attention heads
num_layers (int): number of transformer layers
vocab_size (int): size of the vocabulary, default 49408
max_seq_len (int): context size, default 77
norm_eps (float): small value added to denominator for numerical stability, default 1e-5
Returns:
CLIPTextEncoder: Instantiation of the CLIP text encoder.
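
Example:
    A minimal construction sketch; the sizes below are illustrative only:

    >>> text_encoder = clip_text_encoder(
    ...     embed_dim=512,
    ...     num_heads=8,
    ...     num_layers=12,
    ... )
    >>> # text_encoder is a CLIPTextEncoder using the default vocab_size and max_seq_len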
"""
attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
q_proj=nn.Linear(embed_dim, embed_dim),
k_proj=nn.Linear(embed_dim, embed_dim),
v_proj=nn.Linear(embed_dim, embed_dim),
output_proj=nn.Linear(embed_dim, embed_dim),
)
mlp = clip_mlp(
in_dim=embed_dim,
out_dim=embed_dim,
hidden_dim=embed_dim * 4,
activation=QuickGELU(),
)
encoder_layer = TransformerSelfAttentionLayer(
attn=attn,
mlp=mlp,
sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
)
final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)
return CLIPTextEncoder(
layer=encoder_layer,
final_norm=final_norm,
vocab_size=vocab_size,
max_seq_len=max_seq_len,
embed_dim=embed_dim,
num_layers=num_layers,
)
def clip_mlp(
in_dim: int,
out_dim: int,
hidden_dim: int,
activation: nn.Module,
quantize_base: bool = False,
**quantization_kwargs,
) -> FeedForward:
"""
Build the MLP layer associated with the CLIP model.
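
Example:
    An illustrative standalone build with a 4x hidden expansion; the dimensions are
    arbitrary:

    >>> mlp = clip_mlp(
    ...     in_dim=512,
    ...     out_dim=512,
    ...     hidden_dim=2048,
    ...     activation=nn.GELU(),
    ... )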
"""
gate_proj = (
nn.Linear(in_dim, hidden_dim)
if not quantize_base
else FrozenNF4Linear(in_dim, hidden_dim, bias=True, **quantization_kwargs)
)
down_proj = (
nn.Linear(hidden_dim, out_dim)
if not quantize_base
else FrozenNF4Linear(hidden_dim, out_dim, bias=True, **quantization_kwargs)
)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
)
# ------------------ LoRA CLIP ------------------
def lora_clip_vision_encoder(
lora_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
*,
# clip encoder parameters
tile_size: int,
patch_size: int,
embed_dim: int,
num_layers: int,
num_heads: int,
activation: Callable = nn.SiLU,
cls_output_dim: int = 512,
attn_bias: bool = False,
out_indices: Optional[List[int]] = None,
output_cls_projection: bool = False,
max_num_tiles: int = 4,
in_channels: int = 3,
# LoRA parameters
lora_rank: int = 8,
lora_alpha: float = 16,
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> VisionTransformer:
"""
Build a LoRA implementation of the CLIP vision encoder.
Args:
lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to in each self-attention block. Options are
``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
Default: False
tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
the size of the input image. In this case, the function will consider your image as a single tile.
patch_size (int): The size of each patch. Used to divide the tiles into patches.
E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
with shape (40, 40) each.
embed_dim (int): The dimensionality of each patch embedding (token).
num_layers (int): The number of transformer layers.
num_heads (int): The number of attention heads in each transformer layer.
activation (Callable): The activation function to use in the MLP layer.
cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
attn_bias (bool): Whether to use a bias term in the attention projections. Default: False.
out_indices (Optional[List[int]]): The indices of hidden layers to return.
If provided, the intermediate outputs of these transformer layers are returned
before being passed to the next layer. For example, ``out_indices=[0,3]`` will
return the tokens before they go through the first and fourth layers.
output_cls_projection (bool): If True, only the CLS token projection will be outputted,
instead of all tokens. Defaults to False.
max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
determine the size of the positional embeddings.
in_channels (int): The number of image input channels.
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
weights within linear layers LoRA is applied to. Quantization of the final output linear
projection is not currently supported.
Returns:
VisionTransformer: Instantiation of VisionTransformer model.
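
Example:
    An illustrative sketch that applies rank-8 LoRA to the query and value projections
    of every attention layer; all sizes below are arbitrary:

    >>> lora_encoder = lora_clip_vision_encoder(
    ...     lora_modules=["q_proj", "v_proj"],
    ...     tile_size=224,
    ...     patch_size=14,
    ...     embed_dim=512,
    ...     num_layers=4,
    ...     num_heads=8,
    ...     lora_rank=8,
    ...     lora_alpha=16,
    ... )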
"""
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
# TODO: add support for quantizing and LoRA for the final output projection
cls_projection = (
CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
if output_cls_projection
else None
)
# transformer layer
self_attn = lora_clip_attention(
lora_modules=lora_modules,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
attn_dropout=0.0,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
attn_bias=attn_bias,
**quantization_kwargs,
)
if apply_lora_to_mlp:
mlp = lora_clip_mlp(
in_dim=embed_dim,
hidden_dim=4 * embed_dim,
out_dim=embed_dim,
activation=activation(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
quantize_base=quantize_base,
lora_dropout=lora_dropout,
use_dora=use_dora,
**quantization_kwargs,
)
else:
mlp = clip_mlp(
in_dim=embed_dim,
hidden_dim=4 * embed_dim,
out_dim=embed_dim,
activation=activation(),
quantize_base=quantize_base,
**quantization_kwargs,
)
transformer_layer = TransformerSelfAttentionLayer(
attn=self_attn,
mlp=mlp,
sa_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
mlp_norm=Fp32LayerNorm(embed_dim, eps=1e-5),
sa_scale=None,
mlp_scale=None,
)
# position embeddings
if max_num_tiles == 1:
pre_tile_pos_embed = None
post_tile_pos_embed = None
token_pos_embedding = TokenPositionalEmbedding(
embed_dim=embed_dim, patch_size=patch_size, tile_size=tile_size
)
else:
pre_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
post_tile_pos_embed = TilePositionalEmbedding(
max_num_tiles=max_num_tiles, embed_dim=embed_dim
)
token_pos_embedding = TiledTokenPositionalEmbedding(
max_num_tiles=max_num_tiles,
embed_dim=embed_dim,
patch_size=patch_size,
tile_size=tile_size,
)
model = VisionTransformer(
num_layers=num_layers,
layer=transformer_layer,
token_pos_embedding=token_pos_embedding,
pre_tile_pos_embed=pre_tile_pos_embed,
post_tile_pos_embed=post_tile_pos_embed,
cls_projection=cls_projection,
out_indices=out_indices,
tile_size=tile_size,
patch_size=patch_size,
embed_dim=embed_dim,
in_channels=in_channels,
)
if quantize_base:
# For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
# so as to not increase peak memory
model._register_state_dict_hook(
partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
)
return model
def lora_clip_attention(
lora_modules: List[LORA_ATTN_MODULES],
*,
# MultiHeadAttention args
embed_dim: int,
head_dim: int,
num_heads: int,
num_kv_heads: int,
attn_dropout: float = 0.0,
attn_bias: bool = False,
# LoRA args
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> MultiHeadAttention:
"""
Return an instance of :class:`~torchtune.modules.MultiHeadAttention` with LoRA
applied to a subset of its linear layers.
Args:
lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj",
"output_proj"}``.
embed_dim (int): embedding dimension for self-attention
head_dim (int): dimension of each head in the multihead attention. Usually
computed as ``embed_dim // num_heads``.
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. User should ensure
`num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
attn_bias (bool): whether to use a bias term in the attention projections. Default: False
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``.
quantize_base (bool): Whether to quantize base model parameters for linear layers
LoRA is being applied to. Default is ``False``.
Returns:
MultiHeadAttention: instantiation of self-attention module with LoRA
applied to a subset of Q, K, V, output projections.
Raises:
ValueError: If lora_modules arg is an empty list
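
Example:
    An illustrative sketch with LoRA applied to the query projection only; the
    dimensions are arbitrary:

    >>> attn = lora_clip_attention(
    ...     lora_modules=["q_proj"],
    ...     embed_dim=512,
    ...     head_dim=64,
    ...     num_heads=8,
    ...     num_kv_heads=8,
    ...     lora_rank=8,
    ...     lora_alpha=16,
    ... )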
"""
if not lora_modules:
raise ValueError(
f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
)
adapter_cls = DoRALinear if use_dora else LoRALinear
q_proj = (
adapter_cls(
embed_dim,
num_heads * head_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "q_proj" in lora_modules
else (
nn.Linear(embed_dim, num_heads * head_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(
embed_dim, num_heads * head_dim, bias=attn_bias, **quantization_kwargs
)
)
)
k_proj = (
adapter_cls(
embed_dim,
num_kv_heads * head_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "k_proj" in lora_modules
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(
embed_dim,
num_kv_heads * head_dim,
bias=attn_bias,
**quantization_kwargs,
)
)
)
v_proj = (
adapter_cls(
embed_dim,
num_kv_heads * head_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "v_proj" in lora_modules
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(
embed_dim,
num_kv_heads * head_dim,
bias=attn_bias,
**quantization_kwargs,
)
)
)
output_proj = (
adapter_cls(
embed_dim,
embed_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "output_proj" in lora_modules
else (
nn.Linear(embed_dim, embed_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(
embed_dim, embed_dim, bias=attn_bias, **quantization_kwargs
)
)
)
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=q_proj,
k_proj=k_proj,
v_proj=v_proj,
output_proj=output_proj,
pos_embeddings=None,
attn_dropout=attn_dropout,
)
return self_attn
def lora_clip_mlp(
*,
in_dim: int,
out_dim: int,
hidden_dim: int,
activation: nn.Module,
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> FeedForward:
"""
Build the MLP layer with LoRA applied to the gate and down projections.
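
Example:
    An illustrative sketch; the dimensions and LoRA settings below are arbitrary:

    >>> mlp = lora_clip_mlp(
    ...     in_dim=512,
    ...     out_dim=512,
    ...     hidden_dim=2048,
    ...     activation=nn.GELU(),
    ...     lora_rank=8,
    ...     lora_alpha=16,
    ... )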
"""
adapter_cls = DoRALinear if use_dora else LoRALinear
gate_proj = adapter_cls(
in_dim=in_dim,
out_dim=hidden_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
use_bias=True,
**quantization_kwargs,
)
down_proj = adapter_cls(
in_dim=hidden_dim,
out_dim=out_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
use_bias=True,
**quantization_kwargs,
)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
)