FusionEmbedding¶
- class torchtune.modules.model_fusion.FusionEmbedding(vocab_size: int, fusion_vocab_size: int, embed_dim: int)[source]¶
Fusion embedding supports training additional special tokens while keeping the original embedding frozen. When fusing new models with a language model, some additional tokens may be needed to support the fused language model. For example, adding a vision encoder might require a special token such as
<|image|>
to indicate an image's position in the text, along with a learned embedding for that token. The FusionEmbedding keeps the original embeddings frozen while learning a much smaller second embedding for the additional tokens. During the forward pass, this module routes each token to the appropriate embedding table. Use this as a drop-in replacement for
torch.nn.Embedding
in your model.
Example
>>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128)
>>> model = TransformerDecoder(tok_embeddings=embedding, ...)
>>>
>>> # Original model state_dict still works
>>> model.load_state_dict(..., strict=False)
Note
This module assumes all tokens in the range [0, vocab_size) are part of the original embedding table and all new tokens in the range [vocab_size, vocab_size + fusion_vocab_size) are part of the fusion embedding table.
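The routing described above can be sketched with a toy module. This is an illustrative re-implementation of the idea, not torchtune's actual code: tokens below vocab_size index the frozen base table, and tokens at or above it are shifted down and index the small trainable fusion table.

```python
import torch
import torch.nn as nn

class ToyFusionEmbedding(nn.Module):
    """Minimal sketch of the token-routing idea (not torchtune's implementation)."""

    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
        # The original table stays frozen; only the fusion table trains.
        self.embedding.requires_grad_(False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # True where a token belongs to the original vocabulary.
        mask = input < self.vocab_size
        # Clamp/shift indices so each table only ever sees valid ids.
        base_ids = input.clamp(max=self.vocab_size - 1)
        fusion_ids = (input - self.vocab_size).clamp(min=0)
        # Select per-token between the two lookups.
        return torch.where(
            mask.unsqueeze(-1),
            self.embedding(base_ids),
            self.fusion_embedding(fusion_ids),
        )
```

Note this sketch computes both lookups and selects between them, which is simple but does redundant work; an optimized version would gather each subset separately.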
- Parameters:
- vocab_size (int) – language model vocab size
- fusion_vocab_size (int) – number of additional tokens to support in the fused model
- embed_dim (int) – embedding dimension of both embedding tables
- forward(input: Tensor) Tensor [source]¶
- Parameters:
input (torch.Tensor) – input integer tensor with shape [batch_size x seq_length]
- Returns:
- output tensor embedding with shape
[batch_size x seq_length x embed_dim]
- Return type:
Tensor