padded_collate_tiled_images_and_mask
- torchtune.data.padded_collate_tiled_images_and_mask(batch: List[Dict[str, Any]], padding_idx: int = 0, ignore_idx: int = -100, pad_direction: str = 'right', pad_max_tiles: Optional[int] = None, pad_max_images: Optional[int] = None) -> Dict[str, Tensor]
Pad a batch of text sequences, tiled image tensors, aspect ratios, and cross attention masks. This can be used for both training and inference.
batch is expected to be a list of sample dicts containing the following:
- “tokens”: List[int] of length text_seq_len, varies across samples
- “labels”: List[int] of length text_seq_len, varies across samples
- “encoder_input”: Dict[str, List[torch.Tensor]]
    - “images”: List[torch.Tensor], each with shape (n_tiles, c, h, w)
    - “aspect_ratio”: List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
- “encoder_mask”: List[Tensor], each with shape (text_seq_len, image_seq_len), where image_seq_len = tokens_per_tile * n_tiles for the corresponding image
- Shape notation:
c = channel dim
h = height dim
w = width dim
Note
For each element in the batch, len(images) == len(encoder_mask) == len(aspect_ratio).
- This collater does the following:
Pad text sequence and encoder mask to the longest sequence length in the batch
Pad image tensors in the tile dimension with zeros to the largest number of tiles in the batch
Add empty images of zeros to samples up to max number of images in the batch
Pad aspect ratios with (1,1) for all added padding images
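The image-related steps above can be sketched as pure bookkeeping, without any tensors. This is a minimal illustration, not torchtune's implementation: `plan_image_padding` is a hypothetical helper, and it assumes `pad_max_tiles` and `pad_max_images` are both left as None so the batch maxima are used.

```python
from typing import List, Tuple


def plan_image_padding(
    tiles_per_image: List[List[int]],
    aspect_ratios: List[List[Tuple[int, int]]],
) -> Tuple[List[List[int]], List[List[Tuple[int, int]]], Tuple[int, int]]:
    """Hypothetical helper: report how much padding each sample would receive."""
    max_num_images = max(len(sample) for sample in tiles_per_image)
    max_num_tiles = max(n for sample in tiles_per_image for n in sample)

    pad_tiles, padded_ratios = [], []
    for tiles, ratios in zip(tiles_per_image, aspect_ratios):
        n_missing = max_num_images - len(tiles)
        # Zero tiles added to each real image, then one all-zero image
        # (max_num_tiles zero tiles) per missing image slot.
        pad_tiles.append(
            [max_num_tiles - n for n in tiles] + [max_num_tiles] * n_missing
        )
        # Every added padding image gets an aspect ratio of (1, 1).
        padded_ratios.append(list(ratios) + [(1, 1)] * n_missing)
    return pad_tiles, padded_ratios, (max_num_images, max_num_tiles)


# Same layout as the documented example: sample 0 has images with 2 and 3
# tiles, sample 1 has a single image with 4 tiles.
pad_tiles, padded_ratios, (max_num_images, max_num_tiles) = plan_image_padding(
    tiles_per_image=[[2, 3], [4]],
    aspect_ratios=[[(1, 2), (1, 3)], [(2, 2)]],
)
print(pad_tiles)      # [[2, 1], [0, 4]]
print(padded_ratios)  # [[(1, 2), (1, 3)], [(2, 2), (1, 1)]]
```

The second sample gains one full padding image of four zero tiles, which matches the all-zero image at index [1, 1] in the example output further down.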
- Parameters:
batch (List[Dict[str, Any]]) – A list of sample dicts containing tokens, labels, images, encoder_mask, and aspect_ratio.
padding_idx (int) – Padding index for input token ids. Defaults to 0.
ignore_idx (int) – Padding index for labels. Defaults to -100.
pad_direction (str) – Whether to pad entries from the left or from the right. If pad_direction="right", we use torch.nn.utils.rnn.pad_sequence(); otherwise, if pad_direction="left", we use torchtune.data.left_pad_sequence(). For training, we typically want to pad from the right. For inference, we typically want to pad from the left. Defaults to “right”.
pad_max_tiles (Optional[int]) – Maximum number of tiles to pad to. If None, will pad to the largest number of tiles in the batch. Defaults to None.
pad_max_images (Optional[int]) – Maximum number of images to pad to. If None, will pad to the largest number of images in the batch. Defaults to None.
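To make the effect of pad_direction concrete, here is a small sketch on plain Python lists. `pad_token_ids` is a hypothetical stand-in written for illustration; it does not call `pad_sequence` or `left_pad_sequence`, it only mimics where the padding ends up.

```python
from typing import List


def pad_token_ids(
    seqs: List[List[int]], padding_idx: int = 0, pad_direction: str = "right"
) -> List[List[int]]:
    """Hypothetical helper: pad every sequence to the longest one in the batch."""
    if pad_direction not in ("left", "right"):
        raise ValueError(f"pad_direction must be 'left' or 'right', got {pad_direction!r}")
    max_len = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        fill = [padding_idx] * (max_len - len(s))
        # Right padding appends the fill; left padding prepends it.
        padded.append(s + fill if pad_direction == "right" else fill + s)
    return padded


tokens = [[1, 2, 1, 3], [1, 4]]
right = pad_token_ids(tokens, pad_direction="right")
left = pad_token_ids(tokens, pad_direction="left")
print(right)  # [[1, 2, 1, 3], [1, 4, 0, 0]]
print(left)   # [[1, 2, 1, 3], [0, 0, 1, 4]]
```

With left padding, the last position of every row holds a real token, which is convenient when a model generates the next token from the final position during batched inference.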
- Returns:
- Collated tokens, labels, images, encoder_mask, and aspect_ratio tensors. As in the input, images and aspect_ratio are returned under the “encoder_input” key.
tokens: Tensor of shape (bsz, max_seq_len)
labels: Tensor of shape (bsz, max_seq_len)
images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
encoder_mask: Tensor of shape (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
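The shape arithmetic above can be checked by hand before running the collater. The sketch below uses a hypothetical helper, `expected_output_shapes`, and assumes pad_max_tiles and pad_max_images are None, so the batch maxima determine the padded sizes.

```python
from typing import Dict, List, Tuple


def expected_output_shapes(
    text_lens: List[int],
    tiles_per_image: List[List[int]],
    tokens_per_tile: int,
    c: int,
    h: int,
    w: int,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: derive the documented output shapes from batch sizes."""
    bsz = len(text_lens)
    max_seq_len = max(text_lens)
    max_num_images = max(len(sample) for sample in tiles_per_image)
    max_num_tiles = max(n for sample in tiles_per_image for n in sample)
    return {
        "tokens": (bsz, max_seq_len),
        "labels": (bsz, max_seq_len),
        "images": (bsz, max_num_images, max_num_tiles, c, h, w),
        "encoder_mask": (
            bsz,
            max_seq_len,
            tokens_per_tile * max_num_tiles * max_num_images,
        ),
        "aspect_ratio": (bsz, max_num_images, 2),
    }


# Sizes taken from the documented example batch.
shapes = expected_output_shapes(
    text_lens=[4, 2],
    tiles_per_image=[[2, 3], [4]],
    tokens_per_tile=5,
    c=1, h=1, w=1,
)
print(shapes["images"])        # (2, 2, 4, 1, 1, 1)
print(shapes["encoder_mask"])  # (2, 4, 40)
print(shapes["aspect_ratio"])  # (2, 2, 2)
```

These values agree with the torch.Size outputs printed in the example: the encoder mask's last dimension is 5 * 4 * 2 = 40.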
- Return type:
Dict[str, Tensor]
- Raises:
ValueError – if pad_direction is not one of “left” or “right”.
ValueError – if pad_max_tiles is set to a value less than the largest number of tiles in an image.
Example
>>> import torch
>>> from torchtune.data import padded_collate_tiled_images_and_mask
>>> image_id = 1
>>> tokens_per_tile = 5
>>> c, h, w = 1, 1, 1
>>> batch = [
...     {
...         "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
...         "encoder_input": {
...             # One image with two tiles, one image with three tiles
...             "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
...             "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
...         },
...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
...         "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
...     },
...     {
...         "tokens": [1, 4], "labels": [8, 9],
...         "encoder_input": {
...             # One image with four tiles
...             "images": [torch.ones(4, c, h, w)],
...             "aspect_ratio": [torch.tensor([2, 2])],
...         },
...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
...         "encoder_mask": [torch.ones(2, 5 * 4)],
...     },
... ]
>>> model_inputs = padded_collate_tiled_images_and_mask(batch=batch)
>>> print(model_inputs["tokens"])
tensor([[1, 2, 1, 3],
        [1, 4, 0, 0]])
>>> print(model_inputs["labels"])
tensor([[4, 5, 6, 7],
        [8, 9, -100, -100]])
>>> print(model_inputs["encoder_input"]["images"].shape)  # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape)  # (bsz, max_text_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
torch.Size([2, 4, 40])
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape)  # (bsz, max_num_images, 2)
torch.Size([2, 2, 2])
>>> print(model_inputs["encoder_input"]["images"][0, 0, ...])  # Image with two tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][0, 1, ...])  # Image with three tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 0, ...])  # Image with four tiles did not get padded
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 1, ...])  # Extra padding image was added to second sample
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])