VisionTransformer

class torchtune.modules.VisionTransformer(patch_size: int, tile_size: int, num_layers: int, embed_dim: int, layer: Module, token_pos_embedding: Module, pre_tile_pos_embed: Optional[Module] = None, post_tile_pos_embed: Optional[Module] = None, cls_projection: Optional[Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3)[source]

Implementation of the ViT architecture (https://arxiv.org/abs/2010.11929), with support for tile-cropped images, outputting of hidden layers and optional CLS projection.

ViT is a transformer architecture that takes in images and outputs N embedded tokens that represent this image. Each image is divided into patches by a convolution. These patches are flattened and subsequently treated as tokens by the transformer.

To further enhance the performance of ViT and avoid downscaling images, we support tile-cropped images, which are images divided into tiles during the preprocessing stage. For example, with tile_size=400, instead of downscaling an 800x400 image to fit 400x400, we may crop it into two 400x400 tiles. For details on preprocessing, please refer to torchtune.models.clip._transforms.CLIPImageTransform.

Each of these tiles is further broken down into patches by a convolution operation. For example, if your patch_size=40, then each (400, 400) tile will become a grid of 10x10 patches, and your whole image will have num_tiles * n_tokens = num_tiles * (10x10 patches + 1 CLS token) = num_tiles * 101 tokens.
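The token arithmetic above can be written out as a small helper. This is a minimal sketch: the function name tokens_per_image is hypothetical, and the divisibility check is an assumption of the sketch rather than documented torchtune behavior.

```python
def tokens_per_image(num_tiles: int, tile_size: int, patch_size: int) -> int:
    """Count the tokens produced for one image, including one CLS token per tile."""
    # Assumption of this sketch: the tile is evenly divisible into patches.
    if tile_size % patch_size != 0:
        raise ValueError("tile_size must be a multiple of patch_size")
    grid = tile_size // patch_size      # patches per side, e.g. 400 // 40 = 10
    n_tokens = grid * grid + 1          # 10x10 patches + 1 CLS token = 101
    return num_tiles * n_tokens


# An 800x400 image cropped into two 400x400 tiles, with patch_size=40.
print(tokens_per_image(num_tiles=2, tile_size=400, patch_size=40))  # 202
```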

Before the transformer layers, a CLS token is added to each tile as the first token. In transformers, a token called CLS is a special token that is added to the beginning of each sequence. This token can be used to represent the whole input, instead of using a pooling operation, for example.
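As an illustration of the two steps described so far (patching by convolution, then prepending a CLS token to each tile), here is a minimal PyTorch sketch. It is not torchtune's implementation: the variable names and the zero-initialized cls_token parameter are assumptions made for the example.

```python
import torch
from torch import nn

num_tiles, in_channels, tile_size, patch_size, embed_dim = 2, 3, 400, 40, 32

# With kernel_size == stride == patch_size, the convolution cuts each tile into
# non-overlapping patches and projects every patch to embed_dim values.
conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)

tiles = torch.rand(num_tiles, in_channels, tile_size, tile_size)
x = conv(tiles)                     # (num_tiles, embed_dim, 10, 10)
x = x.flatten(2).transpose(1, 2)    # (num_tiles, 100, embed_dim)

# Hypothetical learnable CLS token, prepended as the first token of every tile.
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
x = torch.cat([cls_token.expand(num_tiles, -1, -1), x], dim=1)

print(x.shape)  # torch.Size([2, 101, 32])
```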

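To help the model “see” the whole image, we use positional embeddings. If your image was tile-cropped, then you need to use tile positional embeddings: a tile-aware token_pos_embedding together with pre_tile_pos_embed and post_tile_pos_embed.

Otherwise, pre_tile_pos_embed and post_tile_pos_embed should be None and all you need is a simple token positional embedding passed as token_pos_embedding.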

All images will be considered as a stack of tiles, even if your image was not tile-cropped. In such cases, your image would be composed of a single tile.

In summary:

  1. An image is broken down into tiles during preprocessing.

  2. In the ViT, the tiles will be broken down into patches.

  3. The patches will be flattened and transformed. We call them tokens, because that’s how the transformer sees them.

Image: shape (8x8)

|  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |
|  9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
| 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 |
| 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 |
| 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
| 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 |
| 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |

Tiles: shape (4,4,4) # (num_tiles, tile_size, tile_size)

|  1 |  2 |  3 |  4 |    |  5 |  6 |  7 |  8 |
|  9 | 10 | 11 | 12 |    | 13 | 14 | 15 | 16 |
| 17 | 18 | 19 | 20 |    | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 |    | 29 | 30 | 31 | 32 |

| 33 | 34 | 35 | 36 |    | 37 | 38 | 39 | 40 |
| 41 | 42 | 43 | 44 |    | 45 | 46 | 47 | 48 |
| 49 | 50 | 51 | 52 |    | 53 | 54 | 55 | 56 |
| 57 | 58 | 59 | 60 |    | 61 | 62 | 63 | 64 |

Patches: shape (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size)

|  1 |  2 |    |  3 |  4 |    |  5 |  6 |    |  7 |  8 |
|  9 | 10 |    | 11 | 12 |    | 13 | 14 |    | 15 | 16 |

| 17 | 18 |    | 19 | 20 |    | 21 | 22 |    | 23 | 24 |
| 25 | 26 |    | 27 | 28 |    | 29 | 30 |    | 31 | 32 |

| 33 | 34 |    | 35 | 36 |    | 37 | 38 |    | 39 | 40 |
| 41 | 42 |    | 43 | 44 |    | 45 | 46 |    | 47 | 48 |

| 49 | 50 |    | 51 | 52 |    | 53 | 54 |    | 55 | 56 |
| 57 | 58 |    | 59 | 60 |    | 61 | 62 |    | 63 | 64 |

Tokens: shape (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim)

|  1 |  2 |  9 |  10 |    |  3 |  4 |  11 |  12 |    |  17 |  18 |  25 |  26 |    | 19 | 20 |  27 |  28 |
| ... continuation of data ...
| ... continuation of data ...
| 37 | 38 | 45 |  46 |    | 39 |  40 | 47 |  48 |    | 53 | 54 |  61 |  62 |    | 55 | 56 |  63 |  64 |
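The image, tile, patch and token diagrams above can be reproduced with plain Python lists. This is only a sketch of the data layout: the helper split_blocks is hypothetical, and flattening each patch stands in for the learned convolutional projection (no CLS token is added here).

```python
def split_blocks(grid, size):
    """Split a square 2D list into non-overlapping size x size blocks, in row-major order."""
    n = len(grid)
    return [
        [row[c:c + size] for row in grid[r:r + size]]
        for r in range(0, n, size)
        for c in range(0, n, size)
    ]


# The 8x8 image from the diagram, holding the values 1..64.
image = [[r * 8 + c + 1 for c in range(8)] for r in range(8)]

tiles = split_blocks(image, 4)                        # 4 tiles of 4x4
patches = [split_blocks(tile, 2) for tile in tiles]   # 4 tiles x 4 patches of 2x2
tokens = [[[v for row in patch for v in row] for patch in tile] for tile in patches]

print(tokens[0][0])  # [1, 2, 9, 10]
print(tokens[3][3])  # [55, 56, 63, 64]
```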

For the positional embeddings:

Same for every tile, different for every token.

|  1 |  2 |  3 |  4 |    |  1 |  2 |  3 |  4 |
|  9 | 10 | 11 | 12 |    |  9 | 10 | 11 | 12 |
| 17 | 18 | 19 | 20 |    | 17 | 18 | 19 | 20 |
| 25 | 26 | 27 | 28 |    | 25 | 26 | 27 | 28 |

|  1 |  2 |  3 |  4 |    |  1 |  2 |  3 |  4 |
|  9 | 10 | 11 | 12 |    |  9 | 10 | 11 | 12 |
| 17 | 18 | 19 | 20 |    | 17 | 18 | 19 | 20 |
| 25 | 26 | 27 | 28 |    | 25 | 26 | 27 | 28 |

Different for every tile, different for every token.

|  1 |  2 |    |  3 |  4 |    |  5 |  6 |    |  7 |  8 |
|  9 | 10 |    | 11 | 12 |    | 13 | 14 |    | 15 | 16 |

| 17 | 18 |    | 19 | 20 |    | 21 | 22 |    | 23 | 24 |
| 25 | 26 |    | 27 | 28 |    | 29 | 30 |    | 31 | 32 |

| 33 | 34 |    | 35 | 36 |    | 37 | 38 |    | 39 | 40 |
| 41 | 42 |    | 43 | 44 |    | 45 | 46 |    | 47 | 48 |

| 49 | 50 |    | 51 | 52 |    | 53 | 54 |    | 55 | 56 |
| 57 | 58 |    | 59 | 60 |    | 61 | 62 |    | 63 | 64 |

Different for every tile, same for every token within a tile.

|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  2 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  2 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  2 |
|  1 |  1 |  1 |  1 |    |  2 |  2 |  2 |  2 |

|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
|  3 |  3 |  3 |  3 |    |  4 |  4 |  4 |  4 |
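
The three sharing patterns shown above can be summarized by asking which embedding vector each (tile, token) position would look up. The sketch below only builds index tables to show the sharing; the variable names are hypothetical and it says nothing about how torchtune stores or applies the embeddings.

```python
num_tiles = 4          # the 8x8 image above, cut into four 4x4 tiles
tokens_per_tile = 16   # one id per position inside a tile

# Same for every tile, different for every token.
token_pos = [[tok for tok in range(tokens_per_tile)] for _ in range(num_tiles)]

# Different for every tile, different for every token.
tiled_token_pos = [
    [tile * tokens_per_tile + tok for tok in range(tokens_per_tile)]
    for tile in range(num_tiles)
]

# Different for every tile, same for every token within a tile.
tile_pos = [[tile] * tokens_per_tile for tile in range(num_tiles)]

print(token_pos[0][:4], token_pos[3][:4])              # [0, 1, 2, 3] [0, 1, 2, 3]
print(tiled_token_pos[0][:4], tiled_token_pos[3][:4])  # [0, 1, 2, 3] [48, 49, 50, 51]
print(tile_pos[0][:4], tile_pos[3][:4])                # [0, 0, 0, 0] [3, 3, 3, 3]
```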

Parameters:
  • patch_size (int) – The size of each patch. Used to divide the tiles into patches. E.g. for patch_size=40, a tile of shape (400, 400) will have a 10x10 grid of patches with shape (40, 40) each.

  • tile_size (int) – The size of your image tiles, if the image was tile-cropped in advance. Otherwise, the size of the input image. In this case, the function will consider your image as a single tile.

  • num_layers (int) – The number of transformer layers.

  • embed_dim (int) – The dimensionality of each patch embedding (token).

  • layer (nn.Module) – The transformer layer module.

  • token_pos_embedding (nn.Module) – The token positional embedding module.

  • pre_tile_pos_embed (Optional[nn.Module]) – The pre-tile positional embedding module. It should be None if your image was not tile-cropped in advance.

  • post_tile_pos_embed (Optional[nn.Module]) – The post-tile positional embedding module. It should be None if your image was not tile-cropped in advance.

  • cls_projection (Optional[nn.Module]) – The CLS projection module. It should take an input tensor of shape (bsz * n_tiles, n_tokens, embed_dim) and output a tensor of shape (bsz * n_tiles, cls_output_dim). If provided, only the CLS token projection will be outputted, instead of all tokens.

  • out_indices (Optional[List[int]]) – The indices of hidden layers to return. If provided, it will return the intermediate results of the transformer layers before they go through a next layer. For example, out_indices=[0,3] will return the tokens before they go through the first and fourth layers.

  • in_channels (int) – The number of image input channels.
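
To make the out_indices semantics concrete, the sketch below collects the state before each selected layer runs. It is an illustration of the behavior described for out_indices, not torchtune's code; run_layers and the toy layers are hypothetical.

```python
def run_layers(x, layers, out_indices=None):
    """Apply layers in order, saving the input of every layer listed in out_indices."""
    out_indices = out_indices or []
    hidden_states = []
    for i, layer in enumerate(layers):
        if i in out_indices:
            hidden_states.append(x)  # state *before* layer i is applied
        x = layer(x)
    return x, hidden_states


# Four toy "layers" that each add 1, so a value equals the number of layers applied so far.
layers = [lambda v: v + 1 for _ in range(4)]
x, hidden_states = run_layers(0, layers, out_indices=[0, 3])
print(x, hidden_states)  # 4 [0, 3]
```

With out_indices=[0, 3], the saved values are the inputs to the first and fourth layers, which matches the description above.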

Raises:
  • ValueError – If tile_size is not greater than 0.

  • ValueError – If patch_size is not greater than 0.

  • ValueError – If len(out_indices) is greater than num_layers.

forward(images: Tensor, aspect_ratio: Optional[Tensor] = None) → Tuple[Tensor, List[Tensor]][source]

Processes images and returns the tokens and hidden states.

Multiple images per sample: we add a dimension n_imgs to the input. This is useful when a single sample contains multiple images, for example:

  • sample 1: “<image> what animal is this?”

  • sample 2: “I like <image> more than <image>”

In this case, sample 1 has one image, and sample 2 has two images. max_n_imgs = max(2,1) = 2. So your input should have shape (bsz=2, n_imgs=2, num_tiles, n_channels, tile_size, tile_size).

Notice that to batch it, you will have to pad n_imgs to max_n_imgs and num_tiles to max_num_tiles.
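
The padded batch shape can be worked out from the number of tiles in each image of each sample. This is a minimal sketch of the shape bookkeeping only; padded_batch_shape and the tile counts used below are hypothetical.

```python
def padded_batch_shape(tiles_per_image, n_channels, tile_size):
    """tiles_per_image[i][j] is the number of tiles of image j in sample i."""
    bsz = len(tiles_per_image)
    max_n_imgs = max(len(sample) for sample in tiles_per_image)
    max_num_tiles = max(n for sample in tiles_per_image for n in sample)
    return (bsz, max_n_imgs, max_num_tiles, n_channels, tile_size, tile_size)


# sample 1: one image with 2 tiles; sample 2: two images with 1 and 4 tiles.
print(padded_batch_shape([[2], [1, 4]], n_channels=3, tile_size=400))
# (2, 2, 4, 3, 400, 400)
```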

Parameters:
  • images (torch.Tensor) – torch.Tensor with shape (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size).

  • aspect_ratio (Optional[torch.Tensor]) – torch.Tensor with shape (bsz, n_imgs, 2). If all images have a single tile, i.e. they were not tile-cropped, it should be None. Used to calculate the positional embeddings for the tiles.

Returns:

A tuple: (x, hidden_states),

where x is a torch.Tensor of shape (bsz, n_imgs, n_tiles, n_tokens, embed_dim) and hidden_states is a list of len(out_indices) torch.Tensor objects, each with shape (bsz, n_imgs, n_tiles, n_tokens, embed_dim).

Return type:

Tuple[torch.Tensor, List[torch.Tensor]]

Raises:

ValueError – If aspect_ratio is None, but n_tiles > 1 in the batch.

Examples

>>> import torch
>>> from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop
>>> from torchtune.modules import VisionTransformer
>>>
>>> num_channels = 3
>>> image_size = (800,400)
>>> tile_size = 400
>>> patch_size=40
>>> patch_grid_size = tile_size // patch_size
>>>
>>> # for details about preprocessing, please check
>>> # torchtune.models.clip._transforms.CLIPImageTransform
>>>
>>> # create a random image
>>> image = torch.rand(num_channels, image_size[0], image_size[1])
>>>
>>> # (num_tiles, nch, h, w) -> (2, 3, 400, 400)
>>> tile_cropped_image = tile_crop(image, tile_size)
>>> aspect_ratio = torch.tensor([2,1])
>>>
>>> # make it a batch of 1 sample: (1, num_tiles, nch, h, w) and (1, 2)
>>> batch_image = tile_cropped_image.unsqueeze(0)
>>> batch_aspect_ratio = aspect_ratio.unsqueeze(0)
>>>
>>> # add the n_imgs dimension (1 image per sample): (1, 1, num_tiles, nch, h, w) and (1, 1, 2)
>>> batch_image = batch_image.unsqueeze(1)
>>> batch_aspect_ratio = batch_aspect_ratio.unsqueeze(1)
>>>
>>> # For a detailed example, please check
>>> # torchtune.models.clip.clip_vision_encoder
>>> # (the remaining constructor arguments are elided here)
>>> # model = VisionTransformer(
... #           out_indices = [1,2,3,4,5],
... #           patch_size = patch_size,
... #           tile_size = tile_size,
... #           embed_dim = 32,
... #           num_layers = 6,
... #           in_channels = num_channels,
... #           ...)
>>>
>>> x, hidden_states = model(images = batch_image, aspect_ratio = batch_aspect_ratio)
>>>
>>> # (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim)
>>> print(x.shape)
torch.Size([1, 1, 2, 101, 32])
>>>
>>> # list with tensors of shape (bsz, n_imgs, num_tiles, num_patches_per_tile + CLS token, embed_dim)
>>> print(len(hidden_states))
5
