Multimodal Transforms
Multimodal model transforms apply model-specific data transforms to each modality and prepare Message
objects to be input into the model. torchtune currently supports text + image model transforms.
These are intended to be drop-in replacements for tokenizers in multimodal datasets and support the standard
encode, decode, and tokenize_messages methods.
# torchtune.models.llama3_2_vision.Llama3VisionTransform
class Llama3VisionTransform(ModelTokenizer, Transform):
    def __init__(...):
        # Text transform - standard tokenization
        self.tokenizer = llama3_tokenizer(...)
        # Image transforms - tile-based resizing and the cross-attention mask
        self.transform_image = CLIPImageTransform(...)
        self.xattn_mask = VisionCrossAttentionMask(...)
from torchtune.models.llama3_2_vision import Llama3VisionTransform
from torchtune.data import Message
from PIL import Image

sample = {
    "messages": [
        Message(
            role="user",
            content=[
                {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                {"type": "text", "content": "What is common in these two images?"},
            ],
        ),
        Message(
            role="assistant",
            content="A robot is in both images.",
        ),
    ],
}
transform = Llama3VisionTransform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    tile_size=224,
    patch_size=14,
)
tokenized_dict = transform(sample)

print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])
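Because the transform implements the standard tokenizer methods, you can also call encode, decode, and tokenize_messages directly. A minimal sketch reusing the transform and sample from above (the encode keyword arguments shown are an assumption and may differ across torchtune versions):

# Text-only round trip through the same encode/decode API a model tokenizer exposes
# (add_bos/add_eos defaults are an assumption)
tokens = transform.encode("hello world", add_bos=True, add_eos=True)
print(transform.decode(tokens))

# tokenize_messages returns token IDs and a per-token loss mask
tokens, mask = transform.tokenize_messages(sample["messages"])
print(len(tokens), len(mask))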
Using model transforms
You can pass a model transform into any multimodal dataset builder just as you would a model tokenizer.
from torchtune.datasets.multimodal import the_cauldron_dataset
from torchtune.models.llama3_2_vision import Llama3VisionTransform

transform = Llama3VisionTransform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    tile_size=224,
    patch_size=14,
)
ds = the_cauldron_dataset(
    model_transform=transform,
    subset="ai2d",
)
tokenized_dict = ds[0]
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# <|image|>Question: What do respiration and combustion give out
# Choices:
# A. Oxygen
# B. Carbon dioxide
# C. Nitrogen
# D. Heat
# Answer with the letter.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# Answer: B<|eot_id|>
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])
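torchtune datasets are standard map-style PyTorch datasets, so you can iterate over samples directly. A minimal sketch (batched loading additionally needs a multimodal-aware collate function to pad token sequences and stack image tiles, which is not shown here):

# Each sample is the dictionary produced by the model transform,
# with "tokens", "mask", and "encoder_input" keys
for i in range(3):
    sample = ds[i]
    print(len(sample["tokens"]), len(sample["encoder_input"]["images"]))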
Creating model transforms
Model transforms are expected to process both text and images in the sample dictionary.
Both should be contained in the "messages" field of the sample.
The following methods are required on the model transform:

- tokenize_messages
- __call__
from typing import Any, List, Mapping, Tuple

from torchtune.data import Message
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform

class MyMultimodalTransform(ModelTokenizer, Transform):
    def __init__(...):
        self.tokenizer = my_tokenizer_builder(...)
        self.transform_image = MyImageTransform(...)

    def tokenize_messages(
        self,
        messages: List[Message],
        add_eos: bool = True,
    ) -> Tuple[List[int], List[bool]]:
        # Any other custom logic here
        ...
        return self.tokenizer.tokenize_messages(
            messages=messages,
            add_eos=add_eos,
        )

    def __call__(
        self, sample: Mapping[str, Any], inference: bool = False
    ) -> Mapping[str, Any]:
        # Expected input parameters for vision encoder
        encoder_input = {"images": [], "aspect_ratio": []}
        messages = sample["messages"]

        # Transform all images in sample
        for message in messages:
            for image in message.get_media():
                out = self.transform_image({"image": image}, inference=inference)
                encoder_input["images"].append(out["image"])
                encoder_input["aspect_ratio"].append(out["aspect_ratio"])
        sample["encoder_input"] = encoder_input

        # Transform all text - returns same dictionary with additional keys "tokens" and "mask"
        sample = self.tokenizer(sample, inference=inference)
        return sample
transform = MyMultimodalTransform(...)
sample = {
    "messages": [
        Message(
            role="user",
            content=[
                {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                {"type": "text", "content": "What is common in these two images?"},
            ],
        ),
        Message(
            role="assistant",
            content="A robot is in both images.",
        ),
    ],
}
tokenized_dict = transform(sample)
print(tokenized_dict)
# {'encoder_input': {'images': ..., 'aspect_ratio': ...}, 'tokens': ..., 'mask': ...}
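Once defined, the custom transform can be passed to any multimodal dataset builder exactly like Llama3VisionTransform above. A minimal sketch (MyMultimodalTransform is the hypothetical transform from this section):

from torchtune.datasets.multimodal import the_cauldron_dataset

# The custom transform drops in anywhere a model transform is expected
ds = the_cauldron_dataset(
    model_transform=MyMultimodalTransform(...),
    subset="ai2d",
)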
Example model transforms
- Llama 3.2 Vision