
DeepFusionModel

class torchtune.modules.model_fusion.DeepFusionModel(decoder: TransformerDecoder, encoder: Module, *, decoder_trainable: bool = False, encoder_trainable: bool = False, fusion_trainable: bool = True)

DeepFusion is a type of fused model architecture where a pretrained encoder is combined with a pretrained decoder (LLM). This is a popular architecture for multimodal models, with a full overview available in The Evolution of Multimodal Model Architectures.

This module has the same methods and forward signature as TransformerDecoder and can be used interchangeably wherever TransformerDecoder is used. It combines the encoder and the decoder into a single module for checkpointing and finetuning. The encoder and decoder are expected to already be defined with any extra learnable fusion_params, i.e. learnable parameters that help adapt the pretrained encoder to the pretrained decoder.
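To make the role of fusion parameters concrete, here is a minimal plain-PyTorch sketch of the data flow, assuming a cross-attention style of fusion. ToyCrossAttnDecoder, the linear "encoder", and all sizes are hypothetical stand-ins, not torchtune classes; in the real model the encoder output is consumed by FusionLayer / TransformerCrossAttentionLayer modules instead. The projection head plays the part of a fusion parameter: it maps encoder embeddings of width d_e to the decoder width d.

```python
import torch
from torch import nn

# Illustrative sizes: batch, token seq len, encoder seq len, token dim, encoder dim, vocab.
b, s, s_e, d, d_e, v = 2, 5, 7, 16, 12, 100


class ToyCrossAttnDecoder(nn.Module):
    """Hypothetical decoder whose tokens cross-attend to encoder embeddings."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(v, d)
        self.cross_attn = nn.MultiheadAttention(d, num_heads=4, batch_first=True)
        self.output = nn.Linear(d, v)

    def forward(self, tokens, encoder_embed):
        h = self.embed(tokens)                                      # [b, s, d]
        attn, _ = self.cross_attn(h, encoder_embed, encoder_embed)  # [b, s, d]
        return self.output(h + attn)                                # [b, s, v]


encoder = nn.Linear(3, d_e)          # stands in for a pretrained encoder (d_e outputs)
projection_head = nn.Linear(d_e, d)  # learnable fusion adapter: d_e -> d
decoder = ToyCrossAttnDecoder()      # stands in for a pretrained decoder

tokens = torch.randint(0, v, (b, s))
encoder_features = torch.randn(b, s_e, 3)                   # fake encoder inputs
encoder_embed = projection_head(encoder(encoder_features))  # [b, s_e, d]
logits = decoder(tokens, encoder_embed)
print(tuple(encoder_embed.shape), tuple(logits.shape))  # (2, 7, 16) (2, 5, 100)
```

Only the projection head has to learn a new mapping here; the encoder and decoder could stay frozen, which is the setup the default trainable flags describe.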

Example

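>>> from torch import nn
>>> from torchtune.modules import FeedForward, TransformerDecoder
>>> from torchtune.modules import TransformerCrossAttentionLayer, TransformerSelfAttentionLayer
>>> from torchtune.modules.model_fusion import DeepFusionModel, FusionEmbedding, FusionLayer
>>> from torchtune.modules.model_fusion import register_fusion_module
>>>
>>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers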
>>> embed = FusionEmbedding(...)
>>> layer = FusionLayer(
...     layer=TransformerSelfAttentionLayer(...),
...     fusion_layer=TransformerCrossAttentionLayer(...),
... )
>>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...)
>>>
>>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head
>>> projection_head = FeedForward(...)
>>> register_fusion_module(projection_head)
>>> encoder = nn.Sequential(clip_vit_224(), projection_head)
>>>
>>> # DeepFusionModel combines the encoder and decoder
>>> model = DeepFusionModel(decoder, encoder)
>>>
>>> # Load full fused checkpoints (e.g. a Flamingo checkpoint)
>>> model.load_state_dict(...)
>>>
>>> # Or load pretrained individual models (fusion_params are not loaded)
>>> model.encoder.load_state_dict(..., strict=False)
>>> model.decoder.load_state_dict(..., strict=False)
>>>
>>> # Forward pass
>>> output = model(
...     tokens,
...     mask=mask,
...     encoder_input=encoder_input,
...     encoder_mask=encoder_mask,
...     input_pos=input_pos,
... )
Parameters:
  • decoder (TransformerDecoder) – decoder module

  • encoder (nn.Module) – encoder module

  • decoder_trainable (bool) – whether to train or freeze the decoder. Default is False.

  • encoder_trainable (bool) – whether to train or freeze the encoder. Default is False.

  • fusion_trainable (bool) – whether to train the fusion parameters. Default is True.
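One way to picture how the three flags interact is the sketch below, assuming fusion parameters can be identified by name. In torchtune they are marked with register_fusion_module; the name-based fusion_names set and the apply_trainable_flags helper here are hypothetical substitutes. Because fusion parameters often live inside the encoder or decoder, the fusion flag is applied per parameter so that it is not overridden by its parent's flag.

```python
from torch import nn

backbone = nn.Linear(4, 4)          # stands in for the pretrained encoder
projection_head = nn.Linear(4, 8)   # fusion params living *inside* the encoder
encoder = nn.Sequential(backbone, projection_head)
decoder = nn.Linear(8, 8)           # stands in for the pretrained decoder
model = nn.ModuleDict({"encoder": encoder, "decoder": decoder})

# Hypothetical: the full names of the parameters that count as fusion params.
fusion_names = {f"encoder.1.{n}" for n, _ in projection_head.named_parameters()}


def apply_trainable_flags(model, fusion_names, *, decoder_trainable=False,
                          encoder_trainable=False, fusion_trainable=True):
    """Set requires_grad per parameter; the fusion flag wins over its parent's flag."""
    for name, p in model.named_parameters():
        if name in fusion_names:
            p.requires_grad_(fusion_trainable)
        elif name.startswith("encoder."):
            p.requires_grad_(encoder_trainable)
        else:
            p.requires_grad_(decoder_trainable)


apply_trainable_flags(model, fusion_names)  # the documented defaults
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['encoder.1.weight', 'encoder.1.bias']
```

With the defaults, only the projection head stays trainable, even though it sits inside the frozen encoder.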

caches_are_enabled() → bool

Checks whether the key-value caches are enabled. Once KV-caches have been set up, the relevant attention modules are “enabled” and every forward pass updates the caches. This behaviour can be turned off without altering the state of the KV-caches by “disabling” them with disable_kv_cache(), after which caches_are_enabled() returns False.

caches_are_setup() → bool

Checks whether the key-value caches are set up, i.e. setup_caches has been called and the relevant attention modules in the model have created their KVCache.
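The difference between "set up" and "enabled" can be modelled as two independent flags. The sketch below is a hypothetical toy, not torchtune code: ToyAttention, ToyModel, and toy_disable_kv_cache are illustrative names, and modelling disable_kv_cache() as a context manager that restores the enabled state on exit is an assumption.

```python
from contextlib import contextmanager


class ToyAttention:
    """Hypothetical attention module that may own a KV-cache."""

    def __init__(self):
        self.kv_cache = None        # created by setup_caches
        self.cache_enabled = False  # whether forward passes update the cache


class ToyModel:
    def __init__(self, n_layers=2):
        self.layers = [ToyAttention() for _ in range(n_layers)]

    def setup_caches(self, max_seq_len):
        for layer in self.layers:
            layer.kv_cache = [None] * max_seq_len  # allocate the cache
            layer.cache_enabled = True             # setup also enables it

    def caches_are_setup(self):
        return all(layer.kv_cache is not None for layer in self.layers)

    def caches_are_enabled(self):
        return all(layer.cache_enabled for layer in self.layers)


@contextmanager
def toy_disable_kv_cache(model):
    """Temporarily stop cache updates without touching the cache contents."""
    for layer in model.layers:
        layer.cache_enabled = False
    try:
        yield model
    finally:
        for layer in model.layers:
            layer.cache_enabled = True


model = ToyModel()
print(model.caches_are_setup(), model.caches_are_enabled())  # False False
model.setup_caches(max_seq_len=8)
print(model.caches_are_setup(), model.caches_are_enabled())  # True True
with toy_disable_kv_cache(model):
    print(model.caches_are_setup(), model.caches_are_enabled())  # True False
print(model.caches_are_enabled())  # True
```

Disabling leaves the caches allocated, so caches_are_setup() stays True while caches_are_enabled() temporarily reports False.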

forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Dict] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Union[Tensor, List[Tensor]]
Parameters:
  • tokens (torch.Tensor) – input tensor with shape [b x s]

  • mask (Optional[torch.Tensor]) – Optional boolean tensor which contains the attention mask with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None.

  • encoder_input (Optional[Dict]) – Optional dictionary of inputs for the encoder, passed to it as keyword arguments. Default is None.

  • encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None.

  • input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this indicates the position of each token relative to its own sample when samples are packed, shape [b x s]. During inference, this indicates the position of the current token. If None, the index of each token is assumed to be its position id. Default is None.
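The two masks described above can be built directly from their definitions. This is a minimal sketch with illustrative sizes; the encoder_mask layout (which tokens may see which encoder embeddings) is a hypothetical choice made only for demonstration.

```python
import torch

b, s, s_e = 2, 4, 3  # illustrative batch size, token seq len, encoder seq len

# Causal self-attention mask, [b x s x s]: True at (i, j) means token i attends to token j.
mask = torch.tril(torch.ones(s, s, dtype=torch.bool)).expand(b, s, s)

# Encoder mask, [b x s x s_e]: True at (i, j) means token i may attend to encoder
# embedding j. Hypothetical layout: in sample 0 every token sees every embedding;
# in sample 1 only the last two tokens see any.
encoder_mask = torch.zeros(b, s, s_e, dtype=torch.bool)
encoder_mask[0] = True
encoder_mask[1, 2:] = True

print(mask[0].int())
print(encoder_mask[1].int())
# These would be passed by keyword, e.g. model(tokens, mask=mask, encoder_mask=encoder_mask, ...)
```

Passing mask=None instead falls back to the causal mask described above.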

Note: At the very first step of inference, when the model is given a prompt, input_pos contains the positions of all of the tokens in the prompt (e.g. torch.arange(prompt_length)), because the KV values must be computed for every prompt position.
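The sketch below shows what input_pos might look like in the three situations described above. The sample lengths and prompt length are illustrative, and the [b x 1] shape used for the per-token decoding steps is an assumption about a typical generation loop rather than something this page specifies.

```python
import torch

# Training with packing: two samples of lengths 3 and 2 share one row (b=1),
# so positions restart at 0 where the second sample begins.
sample_lengths = [3, 2]
packed_input_pos = torch.cat([torch.arange(n) for n in sample_lengths]).unsqueeze(0)
print(packed_input_pos.tolist())  # [[0, 1, 2, 0, 1]]

# Inference, first step (prefill): the whole prompt is processed at once,
# so input_pos covers every prompt position.
prompt_length = 6
prefill_pos = torch.arange(prompt_length).unsqueeze(0)  # [b x s] with b=1
print(prefill_pos.tolist())  # [[0, 1, 2, 3, 4, 5]]

# Later steps: one new token per step, placed at the next position (assumed [b x 1]).
decode_pos = [torch.tensor([[prompt_length + step]]) for step in range(3)]
print([p.item() for p in decode_pos])  # [6, 7, 8]
```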

Returns:

The output tensor with shape [b x s x v], or, when output_hidden_states is set, a list of the layer output tensors it selects, with the final output tensor appended at the end.

Return type:

Union[Tensor, List[Tensor]]

Notation used for tensor shapes:
  • b: batch size

  • s: token sequence length

  • s_e: encoder sequence length

  • v: vocab size

  • d: token embed dim

  • d_e: encoder embed dim

  • m_s: max seq len

reset_caches()

Resets the KV-cache buffers on the relevant attention modules to zero and resets the cache positions to zero, without deleting or reallocating the cache tensors.
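"Without reallocating" means the existing storage is zeroed in place. The sketch below illustrates that distinction on a bare tensor; the [b, n_heads, max_seq_len, head_dim] layout and the integer cache_pos are hypothetical simplifications, not torchtune's KVCache internals.

```python
import torch

# Hypothetical stand-ins for one layer's cache buffers (assumed layout).
k_cache = torch.randn(1, 2, 8, 4)  # [b, n_heads, max_seq_len, head_dim]
v_cache = torch.randn(1, 2, 8, 4)
cache_pos = 5                      # next write position

ptr_before = k_cache.data_ptr()

# Reset in place: zero the existing storage instead of allocating new tensors.
k_cache.zero_()
v_cache.zero_()
cache_pos = 0

print(k_cache.data_ptr() == ptr_before)       # True: same memory
print(k_cache.abs().sum().item(), cache_pos)  # 0.0 0
```

By contrast, k_cache = torch.zeros_like(k_cache) would allocate a new tensor, which is what an in-place reset avoids between generations.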

set_num_output_chunks(num_output_chunks: int) → None

Sets the number of chunks the output is split into; used to save memory in combination with CEWithChunkedOutputLoss. Call this in the recipe before the first forward pass.
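The memory saving comes from never materialising the full upcast logits at once. Here is a hedged sketch of the idea behind chunked-output losses such as CEWithChunkedOutputLoss, assuming the output is split along the sequence dimension; it is not that class's implementation. It checks that summing per-chunk cross-entropy and dividing by the token count matches a single full cross-entropy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, v, num_chunks = 2, 8, 50, 4  # illustrative sizes
logits = torch.randn(b, s, v, dtype=torch.bfloat16)
labels = torch.randint(0, v, (b, s))

# Reference: one cross-entropy over the full [b x s x v] logits, upcast to float32.
full = F.cross_entropy(logits.float().reshape(-1, v), labels.reshape(-1))

# Chunked: split along the sequence dim and upcast one chunk at a time, so only
# a [b x s/num_chunks x v] float32 copy exists at any moment.
logit_chunks = logits.chunk(num_chunks, dim=1)
label_chunks = labels.chunk(num_chunks, dim=1)
total = sum(
    F.cross_entropy(lc.float().reshape(-1, v), yc.reshape(-1), reduction="sum")
    for lc, yc in zip(logit_chunks, label_chunks)
)
chunked = total / labels.numel()

print(torch.allclose(full, chunked, atol=1e-5))  # True
```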

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: int = None, decoder_max_seq_len: int = None)

Sets up key-value attention caches for inference for self.decoder. For each layer in self.decoder.layers:
  • torchtune.modules.TransformerSelfAttentionLayer will use decoder_max_seq_len.

  • torchtune.modules.TransformerCrossAttentionLayer will use encoder_max_seq_len.

  • torchtune.modules.model_fusion.FusionLayer will use both decoder_max_seq_len and encoder_max_seq_len.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (int) – maximum encoder cache sequence length.

  • decoder_max_seq_len (int) – maximum decoder cache sequence length.
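The per-layer rule above amounts to a dispatch on layer type. The sketch below only plans which cache length each layer would receive; SelfAttnLayer, CrossAttnLayer, FusionLayerStub, CacheSpec, and plan_caches are hypothetical names, dtype is a plain string in place of torch.dtype, and the sequence lengths are arbitrary illustrative values.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for the three layer kinds listed above.
class SelfAttnLayer: ...
class CrossAttnLayer: ...
class FusionLayerStub: ...  # holds both a self-attention and a cross-attention layer


@dataclass
class CacheSpec:
    batch_size: int
    dtype: str
    self_attn_len: Optional[int] = None   # sized by decoder_max_seq_len
    cross_attn_len: Optional[int] = None  # sized by encoder_max_seq_len


def plan_caches(layers, batch_size, dtype, *, encoder_max_seq_len, decoder_max_seq_len):
    """Return the cache length(s) each layer would be given."""
    plans = []
    for layer in layers:
        spec = CacheSpec(batch_size, dtype)
        if isinstance(layer, (SelfAttnLayer, FusionLayerStub)):
            spec.self_attn_len = decoder_max_seq_len
        if isinstance(layer, (CrossAttnLayer, FusionLayerStub)):
            spec.cross_attn_len = encoder_max_seq_len
        plans.append(spec)
    return plans


layers = [SelfAttnLayer(), FusionLayerStub(), CrossAttnLayer()]
plans = plan_caches(layers, 1, "bfloat16", encoder_max_seq_len=1600, decoder_max_seq_len=2048)
for p in plans:
    print(p.self_attn_len, p.cross_attn_len)
# 2048 None
# 2048 1600
# None 1600
```

On a real model the corresponding call would be along the lines of model.setup_caches(batch_size, dtype, encoder_max_seq_len=..., decoder_max_seq_len=...), with both lengths needed when the decoder contains fusion or cross-attention layers.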
