VisionCrossAttentionMask

class torchtune.modules.transforms.VisionCrossAttentionMask(tile_size: int, patch_size: int, image_token_id: int, max_num_tiles: Optional[int] = None)

Computes the cross-attention mask for interleaved text + image inputs. Mask entries are True where a text token participates in cross-attention with an image token. The mask follows the interleaved structure laid out in Fig. 7 of the Flamingo paper (https://arxiv.org/pdf/2204.14198):

  1. Text tokens immediately following an image token attend to that image, up until the next image token.

  2. When image tokens are consecutive, each of them attends to the same subsequent text tokens.

     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
     ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
     └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
    <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.
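The two rules illustrated in the diagram can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the function name image_attention_intervals and the example token IDs are hypothetical, and each interval is expressed as a half-open range [start, end) over text positions.

```python
from typing import List, Tuple


def image_attention_intervals(
    tokens: List[int], image_token_id: int
) -> List[Tuple[int, int]]:
    """Return one half-open interval [start, end) per image token.

    Hypothetical helper for illustration only. An image's interval starts at
    its own position and ends at the next non-adjacent image token (or at the
    end of the sequence). Consecutive image tokens share the same end.
    """
    image_positions = [i for i, t in enumerate(tokens) if t == image_token_id]
    intervals = []
    end = len(tokens)  # the last image attends through the end of the text
    # Walk backwards so consecutive image tokens can inherit the same end.
    for idx in reversed(range(len(image_positions))):
        pos = image_positions[idx]
        if idx < len(image_positions) - 1:
            next_pos = image_positions[idx + 1]
            if next_pos != pos + 1:
                # Text lies between this image and the next one (rule 1).
                end = next_pos
            # Otherwise the images are adjacent: keep the same end (rule 2).
        intervals.append((pos, end))
    intervals.reverse()
    return intervals


# Token IDs are made up; 1 stands in for the image special token.
# Layout: <img1> <img2> These are two dogs. <img3> This is a cat.
example_tokens = [1, 1, 10, 11, 12, 13, 1, 14, 15, 16, 17]
example_intervals = image_attention_intervals(example_tokens, image_token_id=1)
```

For the sequence in the diagram this yields the intervals (0, 6), (1, 6) and (6, 11), matching the filled cells in the img1, img2 and img3 rows.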

The resulting mask is constructed per image and has shape (text_seq_len, image_seq_len), where True indicates that the token output by the image encoder attends to the token in the text sequence in cross-attention. A list of these masks is returned, with length equal to the number of images in the sample.
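As a rough sketch of the shapes involved, the snippet below builds one boolean mask per image as nested Python lists. It assumes the attended text range of each image is already known as a half-open interval [start, end), read off the diagram above. The function name build_cross_attention_masks and the image sequence length of 4 are hypothetical; the actual transform may represent masks differently (for example as tensors).

```python
from typing import List, Tuple


def build_cross_attention_masks(
    intervals: List[Tuple[int, int]],
    text_seq_len: int,
    image_seq_lens: List[int],
) -> List[List[List[bool]]]:
    """Build one (text_seq_len, image_seq_len) boolean mask per image.

    Hypothetical helper for illustration only. Row r of mask i is all True
    when text position r falls inside image i's interval, else all False.
    """
    masks = []
    for (start, end), image_seq_len in zip(intervals, image_seq_lens):
        mask = [
            [start <= row < end] * image_seq_len
            for row in range(text_seq_len)
        ]
        masks.append(mask)
    return masks


# Intervals for img1, img2, img3 in the diagram (11 text positions in total).
diagram_intervals = [(0, 6), (1, 6), (6, 11)]
masks = build_cross_attention_masks(
    diagram_intervals, text_seq_len=11, image_seq_lens=[4, 4, 4]
)
```

The returned list has one entry per image, and each entry has 11 rows (text positions) and 4 columns (image encoder tokens) under these assumed sizes.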

Parameters:
  • tile_size (int) – The size of the image tiles from the image transform

  • patch_size (int) – The size of each patch. Used to divide the tiles into patches. E.g. for patch_size = 40, a tile of shape (400, 400) will have a 10x10 grid of patches, each of shape (40, 40).

  • image_token_id (int) – Token ID of the image special token.

  • max_num_tiles (Optional[int]) – Maximum number of tiles in an image, used to pad the mask during inference. Defaults to None.
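The patch arithmetic from the patch_size example, and one way padding to max_num_tiles could work, can be sketched as follows. Both helper names are hypothetical. The sketch assumes that each tile contributes exactly one encoder token per patch and that padding appends False columns; an actual encoder may emit additional tokens per tile (for example a class token), which this sketch ignores.

```python
from typing import List


def patches_per_tile(tile_size: int, patch_size: int) -> int:
    """Number of patches in a square tile (hypothetical helper)."""
    if tile_size % patch_size != 0:
        raise ValueError("tile_size must be divisible by patch_size")
    grid = tile_size // patch_size  # patches along one side of the tile
    return grid * grid


def pad_mask_columns(mask: List[List[bool]], padded_len: int) -> List[List[bool]]:
    """Right-pad every row of a mask with False up to padded_len columns.

    Hypothetical helper: assumes padding positions are never attended to.
    """
    if any(len(row) > padded_len for row in mask):
        raise ValueError("mask is wider than padded_len")
    return [row + [False] * (padded_len - len(row)) for row in mask]


# The example from the parameter description: 400 / 40 = 10 patches per side.
per_tile = patches_per_tile(tile_size=400, patch_size=40)

# An image with 2 tiles, padded as if it had max_num_tiles = 4.
two_tile_mask = [[True] * (2 * per_tile) for _ in range(3)]
padded_mask = pad_mask_columns(two_tile_mask, padded_len=4 * per_tile)
```

Padding every image's mask to the same width in this way gives all masks a fixed image_seq_len regardless of how many tiles an image actually has.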
