Shortcuts

Source code for torchtune.models.llama3_2_vision._transform

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate

from torchtune.models.clip import CLIPImageTransform
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform, VisionCrossAttentionMask


class Llama3VisionTransform(ModelTokenizer, Transform):
    """
    This transform combines the transforms for the different modalities of Llama 3.2 Vision. It
    is made up of the following transforms:

    - :class:`torchtune.models.llama3.Llama3Tokenizer`
    - :class:`torchtune.models.clip.CLIPImageTransform`
    - :class:`torchtune.modules.transforms.VisionCrossAttentionMask`

    This transform can be used as a drop-in replacement for tokenizers in recipes and generation
    but handles additional transformations from the ``__call__`` method.

    Args:
        path (str): Path to pretrained tiktoken tokenizer file.
        tile_size (int): Size of the tiles to divide the image into.
        patch_size (int): Size of the patches used in the CLIP vision transformer model. This is
            used to calculate the number of image embeddings per image.
        max_num_tiles (int): Maximum number of tiles to break an image into. This transform
            always passes ``possible_resolutions=None`` to the image transform, so this value is
            what is used to generate the possible resolutions, e.g.
            [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
            Default 4.
        special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and
            their registered token IDs. If left as None, this will be set to the canonical
            Llama3 special tokens.
        max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of
            messages, after which the input will be truncated. Default is None.
        image_mean (Optional[Tuple[float, float, float]]): Mean values of each channel, used
            for normalization.
        image_std (Optional[Tuple[float, float, float]]): Standard deviations for each channel,
            used for normalization.
        prompt_template (Optional[PromptTemplate]): template used to format the messages based
            on their role. This is used to add structured text around the actual messages.
            The structured text is used in three scenarios:

            - Task-specific templates to gear models for a particular task that it will expect
              after training
            - Model-specific templates that are required whenever the model is prompted, such
              as the [INST] tags in Llama2 and in Mistral
            - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`

            The extra text will still get tokenized as normal text, not as special tokens.
            Default is None.

    Examples:
        >>> model_transform = Llama3VisionTransform("/path/to/tokenizer.model", tile_size=224, patch_size=14)
        >>> # Images are read from the messages themselves via ``Message.get_media()``.
        >>> transformed_data = model_transform({"messages": messages})
        >>> print(transformed_data["tokens"])
        [1, 31587, 29644, 102, 2]
        >>> print(transformed_data["encoder_input"]["images"][0].shape)
        torch.Size([4, 3, 224, 224])
    """

    def __init__(
        self,
        path: str,
        *,
        tile_size: int,
        patch_size: int,
        max_num_tiles: int = 4,
        special_tokens: Optional[Dict[str, int]] = None,
        max_seq_len: Optional[int] = None,
        image_mean: Optional[Tuple[float, float, float]] = None,
        image_std: Optional[Tuple[float, float, float]] = None,
        prompt_template: Optional[PromptTemplate] = None,
    ):
        # NOTE(review): the dict-typed ``special_tokens`` is forwarded under the keyword
        # ``special_tokens_path`` -- confirm that ``llama3_tokenizer`` accepts a mapping here
        # rather than a path to a file.
        self.tokenizer = llama3_tokenizer(
            path,
            special_tokens_path=special_tokens,
            max_seq_len=max_seq_len,
            prompt_template=prompt_template,
        )
        self.transform_image = CLIPImageTransform(
            image_mean=image_mean,
            image_std=image_std,
            tile_size=tile_size,
            possible_resolutions=None,
            max_num_tiles=max_num_tiles,
            resample="bilinear",
            resize_to_max_canvas=False,
        )
        # The cross-attention mask is keyed on the tokenizer's image token id.
        self.xattn_mask = VisionCrossAttentionMask(
            tile_size=tile_size,
            patch_size=patch_size,
            image_token_id=self.tokenizer.image_id,
        )

        # Attributes mirrored from the wrapped tokenizer / constructor arguments so this
        # object can be used where a tokenizer is expected.
        self.stop_tokens = self.tokenizer.stop_tokens
        self.max_seq_len = max_seq_len
        self.max_num_tiles = max_num_tiles
        # Per-image sequence length: every tile contributes its patches plus one extra
        # position (presumably a class/CLS embedding -- TODO confirm).
        self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1)
        self.prompt_template = prompt_template
        self.pad_id = self.tokenizer.pad_id

    @property
    def base_vocab_size(self) -> int:
        """Return ``base_vocab_size`` of the wrapped tokenizer."""
        return self.tokenizer.base_vocab_size

    @property
    def vocab_size(self) -> int:
        """Return ``vocab_size`` of the wrapped tokenizer."""
        return self.tokenizer.vocab_size

    def encode(
        self,
        text: str,
        add_bos: bool = True,
        add_eos: bool = True,
    ) -> List[int]:
        """
        Encode a string into a list of token ids by delegating to the wrapped tokenizer.

        Args:
            text (str): The string to encode.
            add_bos (bool): Forwarded to the wrapped tokenizer's ``encode``. Default is True.
            add_eos (bool): Forwarded to the wrapped tokenizer's ``encode``. Default is True.

        Returns:
            List[int]: The list of token ids.
        """
        return self.tokenizer.encode(text=text, add_bos=add_bos, add_eos=add_eos)

    def decode(
        self,
        token_ids: List[int],
        truncate_at_eos: bool = True,
        skip_special_tokens: bool = True,
    ) -> str:
        """
        Decode a list of token ids into a string.

        Args:
            token_ids (List[int]): The list of token ids.
            truncate_at_eos (bool): Whether to truncate the string at the end of
                sequence token. Default is True.
            skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
                Default is True.

        Returns:
            str: The decoded string.
        """
        return self.tokenizer.decode(
            token_ids,
            truncate_at_eos=truncate_at_eos,
            skip_special_tokens=skip_special_tokens,
        )

    def tokenize_message(
        self,
        message: Message,
        tokenize_header: bool = True,
        tokenize_end: bool = True,
    ) -> List[int]:
        """
        Tokenize a message into a list of token ids.

        Args:
            message (Message): The message to tokenize.
            tokenize_header (bool): Whether to prepend a tokenized header to the message.
            tokenize_end (bool): Whether to append eot or eom id at the end of the message.

        Returns:
            List[int]: The list of token ids.
        """
        return self.tokenizer.tokenize_message(
            message=message,
            tokenize_header=tokenize_header,
            tokenize_end=tokenize_end,
        )

    def tokenize_messages(
        self,
        messages: List[Message],
        add_eos: bool = True,
    ) -> Tuple[List[int], List[bool]]:
        """
        Tokenize a list of messages into a list of token ids and masks.

        Args:
            messages (List[Message]): The list of messages to tokenize.
            add_eos (bool): Whether to add the tokenizer's eos_id. Default True.

        Returns:
            Tuple[List[int], List[bool]]: The list of token ids and the list of masks.
        """
        return self.tokenizer.tokenize_messages(
            messages=messages,
            add_eos=add_eos,
        )

    def __call__(
        self, sample: Mapping[str, Any], inference: bool = False
    ) -> Mapping[str, Any]:
        """
        Apply image decoding, transformations and tokenization to messages in the sample.

        The images are taken from each message's ``get_media()``; the incoming ``sample`` is
        updated in place with an ``"encoder_input"`` entry before it is handed to the
        tokenizer and then to the cross-attention mask transform.

        Args:
            sample (Mapping[str, Any]): A sample with a "messages" field.
            inference (bool): Whether to run in inference mode. Default is False.

        Returns:
            Mapping[str, Any]: The transformed sample with the following fields:
                - tokens: List[int] of tokenized messages
                - mask: List[bool] of masks for the tokenized messages
                - encoder_input: Dict[str, Any] of transformed images
                - encoder_mask: List[bool] of masks for the transformed images
        """
        encoder_input = {"images": [], "aspect_ratio": []}
        messages = sample["messages"]
        # Transform every image of every message, preserving message/image order.
        for message in messages:
            for image in message.get_media():
                out = self.transform_image({"image": image}, inference=inference)
                encoder_input["images"].append(out["image"])
                encoder_input["aspect_ratio"].append(out["aspect_ratio"])

        sample["encoder_input"] = encoder_input
        # Order matters: tokenization runs before the cross-attention mask transform.
        sample = self.tokenizer(sample, inference=inference)
        sample = self.xattn_mask(sample, inference=inference)
        return sample

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources