VisionCrossAttentionMask
- class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int)
Computes the cross-attention mask for text + image inputs. Text tokens that participate in cross-attention with an image token will show True in the mask and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper (https://arxiv.org/pdf/2204.14198):
Text tokens immediately following an image token, up until the next image token, attend to that image
Consecutive image tokens attend to the same subsequent text tokens
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
    <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.
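The two rules illustrated in the diagram above can be sketched in plain Python. This is a minimal illustration, not torchtune's implementation: the helper name `image_attention_intervals` and the token IDs are hypothetical, and intervals are expressed as half-open `[start, end)` ranges of text positions.

```python
def image_attention_intervals(tokens, image_token_id):
    """Return one half-open (start, end) text interval per image token.

    Hypothetical helper for illustration only; not part of torchtune.
    """
    positions = [i for i, tok in enumerate(tokens) if tok == image_token_id]
    intervals = []
    for k, start in enumerate(positions):
        # Walk past image tokens that directly follow this one, so that
        # consecutive images share the same end position.
        j = k
        while j + 1 < len(positions) and positions[j + 1] == positions[j] + 1:
            j += 1
        # The interval ends at the next image token after the run,
        # or at the end of the sequence if there is none.
        end = positions[j + 1] if j + 1 < len(positions) else len(tokens)
        intervals.append((start, end))
    return intervals


# Assumed token IDs: 0 stands in for the image special token.
# Sequence: <img1> <img2> These are two dogs. <img3> This is a cat.
example_tokens = [0, 0, 11, 12, 13, 14, 0, 15, 16, 17, 18]
example_intervals = image_attention_intervals(example_tokens, image_token_id=0)
print(example_intervals)  # [(0, 6), (1, 6), (6, 11)]
```

Each interval matches the filled cells of the corresponding row in the diagram: img1 covers positions 0 to 5, img2 covers 1 to 5, and img3 covers 6 to 10.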
The resulting mask is constructed per image and has shape (text_seq_len, image_seq_len), where True indicates that a token output by the image encoder attends to the corresponding token in the text sequence during cross-attention. A list of these masks is returned, with length equal to the number of images in the sample.
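As a rough sketch of the returned structure, the following builds one boolean mask per image from a list of text intervals, using nested Python lists in place of tensors. The helper name `build_masks`, the half-open interval convention, and the small `image_seq_lens` values are assumptions made for illustration; they are not taken from torchtune.

```python
def build_masks(intervals, text_seq_len, image_seq_lens):
    """Build one (text_seq_len x image_seq_len) boolean mask per image.

    Hypothetical helper for illustration only; not part of torchtune.
    """
    masks = []
    for (start, end), image_seq_len in zip(intervals, image_seq_lens):
        # A row is entirely True when its text position lies in [start, end),
        # and entirely False otherwise.
        mask = [[start <= row < end] * image_seq_len for row in range(text_seq_len)]
        masks.append(mask)
    return masks


# Intervals for: <img1> <img2> These are two dogs. <img3> This is a cat.
example_masks = build_masks(
    intervals=[(0, 6), (1, 6), (6, 11)],
    text_seq_len=11,
    image_seq_lens=[4, 4, 4],  # assumed tiny image sequence lengths
)
print(len(example_masks))        # 3, one mask per image
print(len(example_masks[0]))     # 11 rows, one per text token
print(len(example_masks[0][0]))  # 4 columns, one per image token
```

In this sketch every column of a row carries the same value, since the text interval applies to all tokens of a given image.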
- Parameters:
tile_size (int) – The size of the image tiles from the image transform
patch_size (int) – The size of each patch. Used to divide the tiles into patches. E.g. for patch_size = 40, a tile of shape (400, 400) will have a 10x10 grid of patches with shape (40, 40) each.
image_token_id (int) – Token ID of the image special token.
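The patch_size example in the parameter list can be checked with a few lines of arithmetic. This is only a sketch: the function name `patch_grid` is hypothetical, and it assumes square tiles whose side is an exact multiple of patch_size.

```python
def patch_grid(tile_size, patch_size):
    """Return (patches per side, patches per tile) for a square tile.

    Hypothetical helper for illustration; assumes tile_size is an exact
    multiple of patch_size.
    """
    per_side = tile_size // patch_size
    return per_side, per_side * per_side


grid = patch_grid(tile_size=400, patch_size=40)
print(grid)  # (10, 100): a 10x10 grid, 100 patches per tile
```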