FusionLayer

class torchtune.modules.model_fusion.FusionLayer(layer: Module, fusion_layer: Module, fusion_first: bool = True)[source]

Fusion layer as introduced in Flamingo: a Visual Language Model for Few-Shot Learning.

Deep Fusion model architectures combine pretrained encoder models with pretrained language models by infusing the encoder outputs into the middle layers of the LLM. This allows the language model to interpret the encoder outputs as text and “understand” any modality for which you can train an encoder. To enable the language model to adapt to the encoder outputs, the FusionLayer fuses a new learnable layer to an existing decoder (language model) layer. This additional layer can take the encoder embeddings and learn to combine them with the token embeddings from the decoder. The module supports fusing the new layer either before or after the original; in Flamingo, the new layer is fused before the original.

The original layer is wrapped in FusionLayer so that it keeps its original state_dict keys and the pretrained checkpoint isn’t broken. The new layer’s parameters are exposed through fusion_params() so you can control separately whether they are trainable (see the sketch after the example below).

Example

>>> # Original decoder style transformer
>>> layer = TransformerSelfAttentionLayer(...)
>>> model = TransformerDecoder(layers=layer, num_layers=32, ...)
>>>
>>> # Fuse a cross attention layer to each self attention layer to adapt for the encoder
>>> fusion_layer = TransformerCrossAttentionLayer(...)
>>> fused_layer = FusionLayer(layer, fusion_layer)
>>> model = TransformerDecoder(layers=fused_layer, num_layers=32, ...)
>>>
>>> # Original decoder state_dict still works
>>> model.load_state_dict(..., strict=False)
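Because the new layer’s parameters are exposed through fusion_params(), you can freeze the pretrained decoder weights and train only the fused layers. A minimal sketch, assuming the fused_layer built in the example above:

>>> # Keep the pretrained layer frozen; train only the new fusion parameters
>>> fusion_names = set(fused_layer.fusion_params())
>>> for name, param in fused_layer.named_parameters():
...     param.requires_grad = name in fusion_names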
Parameters:
  • layer (nn.Module) – original decoder layer

  • fusion_layer (nn.Module) – new fusion layer

  • fusion_first (bool) – whether to insert the fusion layer before (True) or after (False) the original decoder layer. Default: True

caches_are_enabled() bool[source]

Checks if the key value caches on self.layer are enabled. See torchtune.modules.TransformerDecoder.caches_are_enabled().

caches_are_setup() bool[source]

Check if the key value caches are set up on self.layer. See torchtune.modules.TransformerDecoder.caches_are_setup().

forward(x: Tensor, **kwargs: Dict) Tensor[source]

Parameters:
  • x (torch.Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]

  • **kwargs (Dict) – all additional layer args

Returns:

output tensor with same shape as input: [batch_size x seq_length x embed_dim]

Return type:

Tensor
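For illustration, a hedged sketch of a call; the shapes, the 4096 embed_dim, and the encoder_input keyword (forwarded through **kwargs to the cross-attention fusion layer) are assumptions building on the example above:

>>> import torch
>>> x = torch.randn(2, 128, 4096)  # [batch_size, seq_length, embed_dim]
>>> encoder_embeds = torch.randn(2, 64, 4096)  # hypothetical encoder output
>>> out = fused_layer(x, encoder_input=encoder_embeds)  # same shape as x
>>> out.shape
torch.Size([2, 128, 4096])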

fusion_params() List[str][source]

Return the names of the fusion layer’s parameters.

reset_cache()[source]

Reset both layers’ key value caches.

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: int, decoder_max_seq_len: int) None[source]

Set up the key value caches for both layers.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (int) – maximum cache sequence length for cross-attention layer.

  • decoder_max_seq_len (int) – maximum cache sequence length for self-attention layer.
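
A usage sketch, continuing the example above; the batch size, dtype, and maximum sequence lengths are illustrative assumptions:

>>> fused_layer.setup_caches(
...     batch_size=2,
...     dtype=torch.bfloat16,
...     encoder_max_seq_len=4096,
...     decoder_max_seq_len=8192,
... )
>>> fused_layer.caches_are_setup()
True
>>> fused_layer.reset_cache()  # clear both layers' caches between generations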
