
padded_collate_tiled_images_and_mask

torchtune.data.padded_collate_tiled_images_and_mask(batch: List[Dict[str, Any]], padding_idx: int = 0, ignore_idx: int = -100, pad_direction: str = 'right') → Dict[str, Tensor]

Pad a batch of text sequences, tiled image tensors, aspect ratios, and cross attention masks. This can be used for both training and inference.

batch is expected to be a list of sample dicts containing the following keys:
  • “tokens”: List[int] of length text_seq_len, varies across samples

  • “labels”: List[int] of length text_seq_len, varies across samples

  • “encoder_input”: Dict[str, List[torch.Tensor]]
    • “images”: List[torch.Tensor], each with shape (n_tiles, c, h, w)

    • “aspect_ratio”: List[torch.Tensor], each with shape (2,) holding (h_ratio, w_ratio)

  • “encoder_mask”: List[Tensor], each with shape (text_seq_len, image_seq_len), where image_seq_len = tokens_per_tile * n_tiles for the corresponding image

Shape notation:
  • c = channel dim

  • h = height dim

  • w = width dim

Note

For each element in the batch, len(images) == len(encoder_mask) == len(aspect_ratio).
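The invariant in the Note can be checked per sample before collation. The helper below is a hypothetical sketch (check_sample_lengths is not part of torchtune); it only inspects list lengths, so it works on the real tensor lists or on placeholders.

```python
from typing import Any, Dict


def check_sample_lengths(sample: Dict[str, Any]) -> int:
    """Hypothetical helper: return the number of images in one sample.

    Raises ValueError unless "images", "aspect_ratio" (both under
    "encoder_input") and "encoder_mask" all have the same length.
    """
    encoder_input = sample["encoder_input"]
    n_images = len(encoder_input["images"])
    n_ratios = len(encoder_input["aspect_ratio"])
    n_masks = len(sample["encoder_mask"])
    if not (n_images == n_ratios == n_masks):
        raise ValueError(
            f"expected equal lengths, got images={n_images}, "
            f"aspect_ratio={n_ratios}, encoder_mask={n_masks}"
        )
    return n_images
```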

This collator performs the following steps:
  1. Pad text sequences (tokens and labels) and encoder masks to the longest text sequence length in the batch

  2. Pad image tensors with zeros in the tile dimension, up to the largest number of tiles of any image in the batch

  3. Append all-zero padding images to each sample until it has as many images as the sample with the most images in the batch

  4. Set the aspect ratio of every added padding image to (1, 1)
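Steps 2 to 4 can be illustrated without tensors. In this pure-Python sketch, an image is a list of tiles and each tile is a single float standing in for a (c, h, w) block with c = h = w = 1; the function name pad_images_and_ratios is hypothetical and is not how torchtune implements it.

```python
from typing import List, Tuple

Tile = float  # stand-in for one (c, h, w) tile tensor, assuming c = h = w = 1
Image = List[Tile]  # one image is a list of tiles


def pad_images_and_ratios(
    images: List[List[Image]],
    ratios: List[List[Tuple[int, int]]],
) -> Tuple[List[List[Image]], List[List[Tuple[int, int]]]]:
    """Sketch of steps 2-4: pad tiles, pad images, pad aspect ratios."""
    max_tiles = max((len(img) for sample in images for img in sample), default=0)
    max_images = max((len(sample) for sample in images), default=0)
    padded_images, padded_ratios = [], []
    for sample_imgs, sample_ratios in zip(images, ratios):
        # Step 2: zero-pad each image in the tile dimension up to max_tiles.
        imgs = [img + [0.0] * (max_tiles - len(img)) for img in sample_imgs]
        # Step 3: append all-zero images until the sample has max_images images.
        n_missing = max_images - len(imgs)
        imgs += [[0.0] * max_tiles for _ in range(n_missing)]
        # Step 4: every added padding image gets aspect ratio (1, 1).
        rs = list(sample_ratios) + [(1, 1)] * n_missing
        padded_images.append(imgs)
        padded_ratios.append(rs)
    return padded_images, padded_ratios
```

With the tile counts from the example below ([2, 3] and [4]), every image ends up with four tiles and the second sample gains one all-zero image with aspect ratio (1, 1).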

Parameters:
  • batch (List[Dict[str, Any]]) – A list of sample dicts containing tokens, labels, encoder_input (images and aspect_ratio), and encoder_mask.

  • padding_idx (int) – Padding index for input token ids. Defaults to 0.

  • ignore_idx (int) – Padding index for labels. Defaults to -100.

  • pad_direction (str) – Whether to pad entries on the left or the right. If pad_direction="right", torch.nn.utils.rnn.pad_sequence() is used; if pad_direction="left", torchtune.data.left_pad_sequence() is used. Training typically pads on the right, while inference typically pads on the left. Defaults to “right”.
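The difference between the two directions is easiest to see on plain lists. This is a hedged sketch, not the torchtune code path: pad_token_lists is a hypothetical helper that appends padding on the right (analogous to pad_sequence with batch_first=True) or prepends it on the left (analogous to left_pad_sequence), and rejects any other direction with ValueError, as the Raises section describes.

```python
from typing import List


def pad_token_lists(
    seqs: List[List[int]], padding_value: int, pad_direction: str = "right"
) -> List[List[int]]:
    """Hypothetical helper: pad variable-length id lists to the longest one."""
    if pad_direction not in ("left", "right"):
        raise ValueError(
            f"pad_direction must be 'left' or 'right', got {pad_direction!r}"
        )
    max_len = max((len(s) for s in seqs), default=0)
    out = []
    for s in seqs:
        pad = [padding_value] * (max_len - len(s))
        # Right padding appends; left padding prepends.
        out.append(list(s) + pad if pad_direction == "right" else pad + list(s))
    return out
```

Tokens would be padded with padding_idx (default 0) and labels with ignore_idx (default -100), so that padded label positions are ignored by the loss.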

Returns:

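A dict of collated tensors. tokens, labels, and encoder_mask are top-level keys; images and aspect_ratio are nested under the “encoder_input” key, as in the example below.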
  • tokens: Tensor of shape (bsz, max_seq_len)

  • labels: Tensor of shape (bsz, max_seq_len)

  • images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)

  • encoder_mask: Tensor of shape (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images)

  • aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
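The last dimension of encoder_mask is tokens_per_tile * max_num_tiles * max_num_images. The exact padding code is not shown on this page, so the sketch below is an assumption about one way to reach that shape for right padding: each image's mask block is zero-padded to max_seq_len rows and tokens_per_tile * max_num_tiles columns, padding images contribute all-zero blocks, and the blocks are concatenated along the last dimension in image order. pad_sample_encoder_mask is a hypothetical name, and masks are plain nested lists of 0/1.

```python
from typing import List

Mask = List[List[int]]  # rows = text positions, cols = image tokens


def pad_sample_encoder_mask(
    image_masks: List[Mask],
    max_seq_len: int,
    tokens_per_tile: int,
    max_num_tiles: int,
    max_num_images: int,
) -> Mask:
    """Sketch (right padding only) of one sample's
    (max_seq_len, tokens_per_tile * max_num_tiles * max_num_images) mask."""
    block_width = tokens_per_tile * max_num_tiles
    blocks = []
    for m in image_masks:
        # Zero-pad each row to block_width columns, then add all-zero rows.
        rows = [list(r) + [0] * (block_width - len(r)) for r in m]
        rows += [[0] * block_width for _ in range(max_seq_len - len(rows))]
        blocks.append(rows)
    # Padding images are attended by no text position: all-zero blocks.
    for _ in range(max_num_images - len(image_masks)):
        blocks.append([[0] * block_width for _ in range(max_seq_len)])
    # Concatenate the per-image blocks along the last (image-token) dimension.
    return [sum((b[i] for b in blocks), []) for i in range(max_seq_len)]
```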

Return type:

Dict[str, Tensor]

Raises:

ValueError – if pad_direction is not one of “left” or “right”.

Example

>>> import torch
>>> from torchtune.data import padded_collate_tiled_images_and_mask
>>> image_id = 1
>>> tokens_per_tile = 5
>>> c, h, w = 1, 1, 1
>>> batch = [
...     {
...         "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
...         "encoder_input": {
...             # One image with two tiles, one image with three tiles
...             "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
...             "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
...         },
...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
...         "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
...     },
...     {
...         "tokens": [1, 4], "labels": [8, 9],
...         "encoder_input": {
...             # One image with four tiles
...             "images": [torch.ones(4, c, h, w)],
...             "aspect_ratio": [torch.tensor([2, 2])],
...         },
...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
...         "encoder_mask": [torch.ones(2, 5 * 4)],
...     },
... ]
>>> model_inputs = padded_collate_tiled_images_and_mask(batch=batch)
>>> print(model_inputs["tokens"])
tensor([[1, 2, 1, 3],
        [1, 4, 0, 0]])
>>> print(model_inputs["labels"])
tensor([[4, 5, 6, 7],
        [8, 9, -100, -100]])
>>> print(model_inputs["encoder_input"]["images"].shape)  # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape)  # (bsz, max_text_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
torch.Size([2, 4, 40])
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape)  # (bsz, max_num_images, 2)
torch.Size([2, 2, 2])
>>> print(model_inputs["encoder_input"]["images"][0, 0, ...])  # Image with two tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][0, 1, ...])  # Image with three tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 0, ...])  # Image with four tiles did not get padded
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 1, ...])  # Extra padding image was added to second sample
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
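The shapes printed in the example follow directly from the batch metadata: text lengths 4 and 2, tile counts [2, 3] and [4], tokens_per_tile = 5, and c = h = w = 1. The hypothetical helper below (not a torchtune function) computes the expected output shapes from that metadata alone, which can be handy for sanity-checking a batch before calling the collator.

```python
from typing import Dict, List, Tuple


def collated_shapes(
    seq_lens: List[int],
    tiles_per_image: List[List[int]],
    tokens_per_tile: int,
    c: int,
    h: int,
    w: int,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: predict collated output shapes from batch metadata."""
    bsz = len(seq_lens)
    max_seq_len = max(seq_lens)
    max_num_images = max(len(t) for t in tiles_per_image)
    max_num_tiles = max(n for t in tiles_per_image for n in t)
    return {
        "tokens": (bsz, max_seq_len),
        "labels": (bsz, max_seq_len),
        "images": (bsz, max_num_images, max_num_tiles, c, h, w),
        "encoder_mask": (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images),
        "aspect_ratio": (bsz, max_num_images, 2),
    }
```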
