FusionEmbedding

class torchtune.modules.model_fusion.FusionEmbedding(vocab_size: int, fusion_vocab_size: int, embed_dim: int)[source]

Fusion embedding supports training additional special tokens while keeping the original embedding frozen. When fusing new models with a language model, some additional tokens may be needed to support the fused language model. For example, adding a vision encoder might require additional tokens like <|image|> to indicate an image's position in text, and an embedding must be learned for this token. The FusionEmbedding keeps the original embeddings frozen while learning a much smaller second embedding for the additional tokens. During the forward pass, this module routes each token to the appropriate embedding table.

Use this as a drop-in replacement for torch.nn.Embedding in your model.

Example

>>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128)
>>> model = TransformerDecoder(tok_embeddings=embedding, ...)
>>>
>>> # Original model state_dict still works
>>> model.load_state_dict(..., strict=False)

Note

This module assumes all tokens in the range [0, vocab_size) are part of the original embedding table, and all new tokens in the range [vocab_size, vocab_size + fusion_vocab_size) are part of the fusion embedding table.
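The routing implied by these two token ranges can be sketched in plain PyTorch. This is a minimal illustration of the idea, not torchtune's actual implementation: tokens below vocab_size look up the frozen base table, and the rest are re-indexed into the small trainable fusion table.

```python
import torch
import torch.nn as nn

vocab_size, fusion_vocab_size, embed_dim = 100, 10, 8

base = nn.Embedding(vocab_size, embed_dim)
base.requires_grad_(False)  # original embeddings stay frozen
fusion = nn.Embedding(fusion_vocab_size, embed_dim)  # small trainable table

tokens = torch.tensor([[5, 99, 100, 109]])  # last two ids fall in the fusion range
mask = tokens >= vocab_size

out = torch.empty(*tokens.shape, embed_dim)
out[~mask] = base(tokens[~mask])                  # base-range tokens
out[mask] = fusion(tokens[mask] - vocab_size)     # shift ids into [0, fusion_vocab_size)
```

Only the fusion table accumulates gradients here, which mirrors the frozen/trainable split described above.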

Parameters:
  • vocab_size (int) – language model vocab size

  • fusion_vocab_size (int) – additional tokens for the fused model

  • embed_dim (int) – embedding dimension of the two embedding tables

forward(input: Tensor) → Tensor[source]
Parameters:

input (torch.Tensor) – input integer tensor with shape [batch_size x seq_length]

Returns:

output tensor embedding with shape [batch_size x seq_length x embed_dim]

Return type:

Tensor

fusion_params() → List[str][source]

Return fusion embedding parameters.
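A common use of a parameter-name list like this is to select only the trainable fusion parameters when building an optimizer. The sketch below uses a hypothetical ToyFusionEmbedding stand-in (with assumed attribute names embedding and fusion_embedding) rather than torchtune's real class, to show the pattern:

```python
import torch.nn as nn

class ToyFusionEmbedding(nn.Module):
    """Hypothetical stand-in for FusionEmbedding, for illustration only."""

    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)            # frozen base
        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)  # trainable

    def fusion_params(self):
        # Only the fusion table should be trained.
        return ["fusion_embedding.weight"]

module = ToyFusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=8)
# Keep only the parameters named by fusion_params(); pass these to an optimizer.
trainable = [p for n, p in module.named_parameters() if n in module.fusion_params()]
```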
