Multimodal Datasets

Multimodal datasets include more than one data modality, e.g. text + image, and can be used to train transformer-based models. torchtune currently only supports multimodal text+image chat datasets for Vision-Language Models (VLMs).

The primary entry point for fine-tuning with multimodal datasets in torchtune is the multimodal_chat_dataset() builder. This lets you specify a local or Hugging Face dataset that follows the multimodal chat data format directly from the config and train your VLM on it.

Example multimodal dataset

Here is an example of a multimodal chat dataset for a visual question-answering task. The text contains a placeholder, "<image>", that marks where the image tokens go. In the example below, this placeholder is replaced by the image special token <|image|>.

# data/my_data.json
[
    {
        "dialogue": [
            {
                "from": "human",
                "value": "<image>What time is it on the clock?"
            },
            {
                "from": "gpt",
                "value": "It is 10:00 AM."
            }
        ],
        "image_path": "images/clock.jpg"
    },
    ...
]
# In code
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

model_transform = llama3_2_vision_transform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    prompt_template="torchtune.data.QuestionAnswerTemplate",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="json",
    data_files="data/my_data.json",
    # Keys are the expected column names, values are the column names in your data
    column_map={
        "conversations": "dialogue",
        "image": "image_path",
    },
    image_dir="/home/user/dataset/",  # /home/user/dataset/images/clock.jpg
    image_tag="<image>",
    split="train",
)
tokenized_dict = ds[0]
print(model_transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nQuestion: <|image|>What time is it on the clock?\n\nAnswer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nIt is 10:00 AM.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 560, 560])
# In config - model_transform takes the place of the tokenizer
model_transform:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  prompt_template: torchtune.data.QuestionAnswerTemplate
  max_seq_len: 8192
  image_size: 560

dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: json
  data_files: data/my_data.json
  split: train
  column_map:
    conversations: dialogue
    image: image_path
  image_dir: /home/user/dataset/
  image_tag: "<image>"

Multimodal dataset format

Multimodal datasets are currently expected to follow the "sharegpt" chat format, where the image paths are in one column and the user-assistant conversations are in another column.

|  conversations                     | image        |
|------------------------------------|--------------|
| [{"from": "human", "value": "Q1"}, | images/1.jpg |
|  {"from": "gpt", "value": "A1"}]   |              |

As an example, you can see the schema of the ShareGPT4V dataset.

Currently, multimodal_chat_dataset() only supports a single image path per conversation sample.
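To make the format concrete, here is a minimal plain-Python sketch of how one "sharegpt"-style sample maps onto user/assistant turns. The helper sharegpt_to_turns and the ROLE_MAP dictionary are hypothetical and are not torchtune's implementation. The role mapping ("human" to "user", "gpt" to "assistant") is an assumption based on the usual sharegpt convention.

```python
# Hypothetical sketch, not torchtune code: turn one sharegpt-style sample
# into (role, text) pairs plus its single image path.
ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}  # assumed mapping


def sharegpt_to_turns(sample):
    """Return a list of (role, text) tuples and the sample's image path."""
    image_path = sample["image"]
    if not isinstance(image_path, str):
        # Mirrors the documented limitation: one image path per sample
        raise ValueError("expected exactly one image path per sample")
    turns = [(ROLE_MAP[turn["from"]], turn["value"]) for turn in sample["conversations"]]
    return turns, image_path


sample = {
    "conversations": [
        {"from": "human", "value": "<image>Q1"},
        {"from": "gpt", "value": "A1"},
    ],
    "image": "images/1.jpg",
}
turns, image_path = sharegpt_to_turns(sample)
print(turns)       # [('user', '<image>Q1'), ('assistant', 'A1')]
print(image_path)  # images/1.jpg
```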

Loading multimodal datasets from Hugging Face

To load a dataset from the Hugging Face Hub, pass the dataset repo name to source, which is then forwarded to Hugging Face’s load_dataset. For most datasets, you will also need to specify the split and/or the subset via name.

# In code
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

model_transform = llama3_2_vision_transform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="Lin-Chen/ShareGPT4V",
    split="train",
    name="ShareGPT4V",
    image_dir="/home/user/dataset/",
    image_tag="<image>",
)
# In config
model_transform:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  max_seq_len: 8192
  image_size: 560

# The model transform is passed into the dataset in the recipe
dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: Lin-Chen/ShareGPT4V
  split: train
  name: ShareGPT4V
  image_dir: /home/user/dataset/
  image_tag: "<image>"

This will use the default column names “conversations” and “image”. To change the column names, use the column_map argument (see Renaming columns).
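The direction of column_map is easy to get backwards. The sketch below assumes that keys are the column names the builder expects ("conversations", "image") and values are the column names actually used in your dataset. rename_columns is a hypothetical helper written only to illustrate this mapping. It is not a torchtune function.

```python
# Hypothetical illustration of column_map semantics (assumed: keys are the
# expected column names, values are the names used in your dataset).
def rename_columns(sample, column_map=None):
    expected = {"conversations": "conversations", "image": "image"}  # defaults
    expected.update(column_map or {})
    # Look up each expected column under its actual name in the raw sample
    return {key: sample[actual] for key, actual in expected.items()}


raw = {
    "dialogue": [{"from": "human", "value": "<image>What time is it on the clock?"}],
    "image_path": "images/clock.jpg",
}
sample = rename_columns(raw, {"conversations": "dialogue", "image": "image_path"})
print(sorted(sample))   # ['conversations', 'image']
print(sample["image"])  # images/clock.jpg
```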

Loading local and remote multimodal datasets

To load a local file, or a remote file over HTTPS, that follows the multimodal chat format, specify the source, data_files and split arguments. See Hugging Face’s load_dataset documentation for more details on loading local or remote files, and see Example multimodal dataset above.
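Before pointing data_files at a local file, it can help to check that every record has the expected layout. This is a minimal standard-library sketch that writes samples to a JSON file and validates them on read. check_sample is a hypothetical helper that only checks the structure described in this page: a "conversations" list of "from"/"value" turns and a single "image" path. It does not reproduce torchtune's own validation.

```python
import json
import tempfile
from pathlib import Path


def check_sample(sample):
    """Hypothetical check: raise ValueError if a record does not match the layout."""
    turns = sample.get("conversations")
    if not isinstance(turns, list) or not turns:
        raise ValueError("'conversations' must be a non-empty list")
    for turn in turns:
        if "from" not in turn or "value" not in turn:
            raise ValueError(f"turn is missing 'from' or 'value': {turn}")
    if not isinstance(sample.get("image"), str):
        raise ValueError("'image' must be a single image path")


samples = [
    {
        "conversations": [
            {"from": "human", "value": "<image>What time is it on the clock?"},
            {"from": "gpt", "value": "It is 10:00 AM."},
        ],
        "image": "images/clock.jpg",
    },
]

# Write the records to a local JSON file (a temp dir stands in for data/)
data_file = Path(tempfile.mkdtemp()) / "my_data.json"
data_file.write_text(json.dumps(samples, indent=2))

# Read the file back and validate every record before training
loaded = json.loads(data_file.read_text())
for record in loaded:
    check_sample(record)
print(f"{len(loaded)} valid sample(s) in {data_file}")
```

The resulting path would then be passed as data_files together with source="json".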

Loading images

In many cases, your dataset will contain paths to the images instead of the raw images themselves. multimodal_chat_dataset() will automatically handle this for you, but if you are writing a custom message transform for a custom multimodal dataset (see Custom message transforms), you can use the load_image() utility directly.

from torchtune.data import load_image
from pathlib import Path

sample = {
    "conversations": [
        {
            "from": "human",
            "value": "What time is it on the clock?",
        },
        {
            "from": "gpt",
            "value": "It is 10:00 AM.",
        },
    ],
    "image": "images/clock.jpg",
}
image_dir = "/home/user/dataset/"
pil_image = load_image(Path(image_dir) / Path(sample["image"]))
print(pil_image)
# <PIL.Image.Image>

Then, you can add the PIL image directly to the content of the related message. Only PIL images are supported as image content in Message, not image paths or URLs.

from torchtune.data import Message

user_message = None
for msg in sample["conversations"]:
    if msg["from"] == "human":
        user_message = Message(
            role="user",
            content=[
                {"type": "image", "content": pil_image},
                {"type": "text", "content": msg["value"]},
            ]
        )
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>]
print(user_message.text_content)
# What time is it on the clock?

If the image paths in your dataset are relative paths, you can use the image_dir parameter in multimodal_chat_dataset() to prepend the full path where your images are downloaded locally.
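The sketch below shows how a relative path combines with an image_dir prefix. It assumes the prefix is joined with pathlib's / operator. torchtune's exact implementation may differ, and resolve_image_path is a hypothetical helper. One consequence of pathlib semantics is that an absolute image path replaces the prefix rather than being appended to it.

```python
from pathlib import Path


def resolve_image_path(image_dir, image_path):
    """Hypothetical helper: prefix a relative image path with image_dir."""
    if image_dir is None:
        return Path(image_path)
    # pathlib drops the left operand when the right operand is absolute
    return Path(image_dir) / image_path


print(resolve_image_path("/home/user/dataset/", "images/clock.jpg").as_posix())
# /home/user/dataset/images/clock.jpg
print(resolve_image_path("/home/user/dataset/", "/data/clock.jpg").as_posix())
# /data/clock.jpg
```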

Interleaving images in text

torchtune supports adding multiple images at any location in the text, as long as your model supports it.

from PIL import Image
from torchtune.data import Message

image_dog = Image.new(mode="RGB", size=(4, 4))
image_cat = Image.new(mode="RGB", size=(4, 4))
image_bird = Image.new(mode="RGB", size=(4, 4))

user_message = Message(
    role="user",
    content=[
        {"type": "image", "content": image_dog},
        {"type": "text", "content": "This is an image of a dog. "},
        {"type": "image", "content": image_cat},
        {"type": "text", "content": "This is an image of a cat. "},
        {"type": "image", "content": image_bird},
        {"type": "text", "content": "This is a bird, the best pet of the three."},
    ]
)
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>, <PIL.Image.Image>, <PIL.Image.Image>]
print(user_message.text_content)
# This is an image of a dog. This is an image of a cat. This is a bird, the best pet of the three.

Your dataset may contain image placeholder tags that indicate where in the text each image should be referenced. As an example, see ShareGPT4V (https://huggingface.co/datasets/Lin-Chen/ShareGPT4V), which uses "<image>". You can create interleaved message content like the example above with the utility format_content_with_images(), which replaces the image placeholder tags with the passed-in images.

from PIL import Image
from torchtune.data import Message, format_content_with_images

image_dog = Image.new(mode="RGB", size=(4, 4))
image_cat = Image.new(mode="RGB", size=(4, 4))
image_bird = Image.new(mode="RGB", size=(4, 4))

text = "[img]This is an image of a dog. [img]This is an image of a cat. [img]This is a bird, the best pet of the three."
user_message = Message(
    role="user",
    content=format_content_with_images(
        content=text,
        image_tag="[img]",
        images=[image_dog, image_cat, image_bird],
    ),
)
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>, <PIL.Image.Image>, <PIL.Image.Image>]
print(user_message.text_content)
# This is an image of a dog. This is an image of a cat. This is a bird, the best pet of the three.

This is handled automatically for you in multimodal_chat_dataset() when you pass in image_tag.
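To illustrate the idea behind placeholder replacement, here is a hypothetical plain-Python reimplementation. It is not torchtune's format_content_with_images. It splits the text on the tag and interleaves one image per tag, using strings as stand-ins for PIL images. The error raised on a tag/image count mismatch is a choice made in this sketch.

```python
# Hypothetical sketch of placeholder replacement; strings stand in for images.
def interleave_images(content, image_tag, images):
    pieces = content.split(image_tag)
    if len(pieces) - 1 != len(images):
        raise ValueError(
            f"found {len(pieces) - 1} '{image_tag}' tags but {len(images)} images"
        )
    result = []
    for i, text in enumerate(pieces):
        if text:  # skip empty text, e.g. when the content starts with a tag
            result.append({"type": "text", "content": text})
        if i < len(images):  # an image follows every piece except the last
            result.append({"type": "image", "content": images[i]})
    return result


content = interleave_images(
    "[img]This is a dog. [img]This is a cat.",
    image_tag="[img]",
    images=["<dog>", "<cat>"],
)
print([part["type"] for part in content])
# ['image', 'text', 'image', 'text']
```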

Built-in multimodal datasets
