Llama3VisionProjectionHead¶
- class torchtune.models.llama3_2_vision.Llama3VisionProjectionHead(layers: Module, output: Module, num_hidden_inputs: int = 0)[source]¶
Projection transformer that adapts the output of a pretrained, frozen encoder (CLIP) to a pretrained decoder model. For example, nn.Sequential(CLIP(), Llama3VisionProjectionHead()). A construction sketch follows the parameter list below.
- Parameters:
layers (nn.Module) – Transformer Decoder layers
output (nn.Module) – Output linear layer. Its input dim is (num_hidden_inputs + 1) * encoder_dim and its output dim is decoder_dim.
num_hidden_inputs (int) – Number of expected hidden state inputs
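The following is a minimal construction sketch, not a torchtune builder: the stand-in _PassThroughLayer, the dims, and num_hidden_inputs are hypothetical placeholders. In practice, layers would be real transformer layers built alongside the rest of the model, and the dims would match the CLIP encoder and Llama decoder.

```python
import torch.nn as nn
from torchtune.models.llama3_2_vision import Llama3VisionProjectionHead

# Hypothetical dims: encoder_dim stands in for the CLIP embed dim,
# decoder_dim for the Llama decoder embed dim.
encoder_dim, decoder_dim = 1280, 4096
num_hidden_inputs = 4  # number of encoder hidden states to concatenate

class _PassThroughLayer(nn.Module):
    """Stand-in for a transformer layer: maps [b, s, d] -> [b, s, d]."""
    def forward(self, x, **kwargs):
        return x

head = Llama3VisionProjectionHead(
    layers=nn.ModuleList([_PassThroughLayer() for _ in range(2)]),
    # Per the docs, the output linear layer has input dim
    # (num_hidden_inputs + 1) * encoder_dim and output dim decoder_dim.
    output=nn.Linear((num_hidden_inputs + 1) * encoder_dim, decoder_dim),
    num_hidden_inputs=num_hidden_inputs,
)
```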
- forward(x: Tensor, hidden_states: Optional[List[Tensor]] = None) Tensor [source]¶
- Parameters:
x (torch.Tensor) – input tensor with shape [b x i x t x e x d]
hidden_states (Optional[List[torch.Tensor]]) – list of hidden states from the encoder. Each hidden state has the same shape as x.
- Returns:
- output tensor of a sequence of embeddings [b x s x d],
where the sequence length is num_imgs * num_tiles * num_embeds (see the usage sketch after the shape notation below)
- Return type:
Tensor
- Notation used for tensor shapes:
b: batch size
i: number of images
t: number of tiles (where a single image is broken into multiple tiles)
e: number of embeds per tile (e.g. CLS embed + patch embeds, etc.)
s: sequence length, computed as i * t * e
d: embed dim
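A usage sketch for forward under the notation above. The stand-in layer, dims, and shape values are hypothetical; hidden_states carries num_hidden_inputs tensors shaped like x, as the parameter docs describe.

```python
import torch
import torch.nn as nn
from torchtune.models.llama3_2_vision import Llama3VisionProjectionHead

class _PassThroughLayer(nn.Module):
    """Stand-in transformer layer: maps [b, s, d] -> [b, s, d]."""
    def forward(self, x, **kwargs):
        return x

b, i, t, e = 2, 1, 4, 1601            # hypothetical batch / image / tile / embed counts
encoder_dim, decoder_dim = 1280, 4096  # hypothetical encoder / decoder embed dims
num_hidden_inputs = 4

head = Llama3VisionProjectionHead(
    layers=nn.ModuleList([_PassThroughLayer()]),
    output=nn.Linear((num_hidden_inputs + 1) * encoder_dim, decoder_dim),
    num_hidden_inputs=num_hidden_inputs,
)

x = torch.randn(b, i, t, e, encoder_dim)                        # [b x i x t x e x d]
hidden_states = [torch.randn_like(x) for _ in range(num_hidden_inputs)]

out = head(x, hidden_states=hidden_states)
# Per the docs, out has shape [b x s x d] with s = i * t * e and d = decoder_dim,
# i.e. torch.Size([2, 6404, 4096]) for the values above.
print(out.shape)
```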