torchtune.modules
Modeling Components and Building Blocks
Multi-headed attention layer with support for grouped query attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

This class implements the feed-forward network derived from Llama2.

Standalone

This class implements Rotary Positional Embeddings (RoPE) proposed in https://arxiv.org/abs/2104.09864.

Root Mean Square Normalization in fp32.

Wrapper around

Implements a basic learnable gate to scale layer outputs.

A tied linear layer, without bias, that shares the same weight as another linear layer.

Transformer layer derived from the Llama2 model.

Cross-attention Transformer layer following the same conventions as the TransformerSelfAttentionLayer.

Transformer Decoder derived from the Llama2 architecture.

Implementation of the ViT architecture (https://arxiv.org/abs/2010.11929), with support for tile-cropped images, outputting of hidden layers, and optional CLS projection.

A module that applies layer dropout to the input tensor of an underlying module.

Prepare a model's layers for layer dropout by wrapping each layer with a ModuleLayerDropoutWrapper.
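Several of the building blocks above reduce to short formulas. The following framework-free sketch illustrates the math behind four of them: RMS normalization, rotary embeddings applied to adjacent feature pairs, the SwiGLU-style feed-forward used in Llama2, and the query-to-KV head grouping used by GQA. All function names are hypothetical and this is not torchtune's code. It operates on single Python vectors instead of batched tensors, and because Python floats are already double precision, the fp32 upcast mentioned for RMSNorm is not modeled.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: Matrix[out_index][in_index]


def rms_norm(x: Vector, scale: Vector, eps: float = 1e-6) -> Vector:
    """Divide x by its root mean square, then apply a learned per-feature scale."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * s for v, s in zip(x, scale)]


def apply_rope(x: Vector, pos: int, base: float = 10_000.0) -> Vector:
    """Rotate each adjacent pair (x[2i], x[2i+1]) by the angle pos * base ** (-2i / d)."""
    d = len(x)
    assert d % 2 == 0, "RoPE needs an even feature dimension"
    out: Vector = []
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * math.cos(theta) - b * math.sin(theta),
                a * math.sin(theta) + b * math.cos(theta)]
    return out


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def swiglu_ffn(x: Vector, w1: Matrix, w2: Matrix, w3: Matrix) -> Vector:
    """Llama2-style gated feed-forward: w2 @ (silu(w1 @ x) * (w3 @ x))."""
    gate = [silu(v) for v in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])


def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Under GQA, each run of num_heads // num_kv_heads consecutive query heads shares one KV head."""
    assert num_heads % num_kv_heads == 0
    return q_head // (num_heads // num_kv_heads)
```

A quick sanity check on the RoPE part: the query/key dot product depends only on the offset between positions, so `apply_rope(q, 3)` against `apply_rope(k, 1)` gives the same score as `apply_rope(q, 7)` against `apply_rope(k, 5)`, up to rounding. For GQA with 8 query heads and 2 KV heads, heads 0 to 3 read KV head 0 and heads 4 to 7 read KV head 1.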
Losses
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.

The Kullback-Leibler divergence loss for valid indexes.

Forward KL with chunked outputs that saves memory by only upcasting one chunk at a time.
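The key point of the chunked losses is that the final value is still one mean over all non-ignored tokens, even though the logits are visited slice by slice. The plain-Python sketch below (hypothetical function names) assumes that -100 marks positions to ignore, following the default `ignore_index` of PyTorch's cross-entropy, and computes KL(teacher || student) in its textbook form; a production distillation loss may drop the teacher-entropy term, which does not depend on the student and so does not change the gradient. The fp32 upcast is only indicated in a comment, since Python floats are already double precision.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # assumed label value for positions excluded from the loss


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def chunked_cross_entropy(logits: List[List[float]], labels: List[int], num_chunks: int) -> float:
    """Mean cross-entropy over non-ignored tokens, visiting the sequence one chunk at a time."""
    n = len(labels)
    size = max(1, math.ceil(n / num_chunks))
    total, count = 0.0, 0
    for start in range(0, n, size):
        chunk = zip(logits[start:start + size], labels[start:start + size])
        # In a real model, only this chunk's logits would be upcast to fp32 here.
        for row, label in chunk:
            if label == IGNORE_INDEX:
                continue
            total -= log_softmax(row)[label]
            count += 1
    # Normalize once by the global token count, not by averaging per-chunk means.
    return total / max(count, 1)


def masked_forward_kl(student: List[List[float]], teacher: List[List[float]], labels: List[int]) -> float:
    """Mean KL(teacher || student) over positions whose label is not IGNORE_INDEX."""
    total, count = 0.0, 0
    for s_row, t_row, label in zip(student, teacher, labels):
        if label == IGNORE_INDEX:
            continue
        log_p_t, log_p_s = log_softmax(t_row), log_softmax(s_row)
        total += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
        count += 1
    return total / max(count, 1)
```

Because every chunk feeds one shared sum and count, `chunked_cross_entropy` returns the same value for any `num_chunks`; in a tensor implementation only peak memory would change. The same chunking applies unchanged to the KL loss.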
Base Tokenizers
Base tokenizers are tokenizer models that perform the direct encoding of text into token IDs and decoding of token IDs into text. These are typically byte pair encodings that underlie the model-specific tokenizers.
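Byte pair encoding applies a ranked list of learned merges: at each step the adjacent symbol pair with the best rank is fused, until no ranked pair remains. The toy sketch below is a simplification for illustration only. It works on characters rather than bytes, uses a hypothetical hand-written merge list and vocabulary, and ignores the pre-tokenization regexes, byte fallback, and special tokens that real tokenizers such as SentencePiece or tiktoken handle.

```python
from typing import Dict, List, Tuple


def bpe_encode(text: str, merges: List[Tuple[str, str]], vocab: Dict[str, int]) -> List[int]:
    """Greedy BPE: repeatedly fuse the adjacent pair with the lowest merge rank."""
    ranks = {pair: i for i, pair in enumerate(merges)}
    symbols = list(text)  # start from single characters
    while len(symbols) > 1:
        # (rank, position) for every adjacent pair; unknown pairs rank as infinity.
        candidates = [(ranks.get((a, b), float("inf")), i)
                      for i, (a, b) in enumerate(zip(symbols, symbols[1:]))]
        best_rank, i = min(candidates)  # lowest rank, leftmost on ties
        if best_rank == float("inf"):
            break  # no learned merge applies any more
        symbols[i:i + 2] = [symbols[i] + symbols[i + 1]]
    return [vocab[s] for s in symbols]  # every final symbol must be in the vocabulary


def bpe_decode(ids: List[int], vocab: Dict[str, int]) -> str:
    inverse = {i: s for s, i in vocab.items()}
    return "".join(inverse[i] for i in ids)
```

With the merges `("l", "o")`, `("lo", "w")`, `("e", "r")`, the word "lower" becomes the two symbols "low" and "er", and decoding their IDs recovers "lower".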
A lightweight wrapper around SentencePieceProcessor that additionally handles trimming leading whitespace.

A lightweight wrapper around tiktoken Encoding.

Abstract tokenizer that implements model-specific special token logic in the

Abstract token encoding model that implements
Tokenizer Utilities
These are helper methods that can be used by any tokenizer.
Tokenize a list of messages one at a time, then concatenate them, returning a list of tokens and a list of masks.

Parse the
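Tokenizing message by message keeps token boundaries aligned with message boundaries, so every token can inherit its message's mask flag. The sketch below assumes a hypothetical `Message` type with a boolean `masked` field (True meaning "exclude from the loss") and an arbitrary `encode` callable; it inserts no special tokens.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Message:
    """Hypothetical stand-in for a chat message."""
    role: str
    content: str
    masked: bool  # True if this message's tokens should be excluded from the loss


def tokenize_messages(messages: List[Message],
                      encode: Callable[[str], List[int]]) -> Tuple[List[int], List[bool]]:
    tokens: List[int] = []
    mask: List[bool] = []
    for message in messages:
        ids = encode(message.content)  # tokenize one message at a time
        tokens.extend(ids)
        mask.extend([message.masked] * len(ids))  # exactly one mask entry per token
    return tokens, mask
```

With a toy character-code encoder, a masked user message "hi" followed by an unmasked assistant reply "ok" yields tokens `[104, 105, 111, 107]` and mask `[True, True, False, False]`. Note that the concatenated result can differ from encoding the joined string, because subword merges never cross message boundaries.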
PEFT Components
LoRA linear layer as introduced in LoRA: Low-Rank Adaptation of Large Language Models.

DoRA linear layer as introduced in DoRA: Weight-Decomposed Low-Rank Adaptation of Large Language Models.

Interface for an

Return the subset of parameters from a model that correspond to an adapter.

Set trainable parameters for an nn.Module based on a state dict of adapter parameters.

Return the subset of the full state_dict from a model that corresponds to an adapter.

A more memory-efficient way to validate that LoRA state dict loading was done properly.

Temporarily disable the adapters in a model.
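The LoRA paper's update is y = W x + (alpha / r) B A x, where W is frozen, A is randomly initialized, and B starts at zero so the adapted layer initially matches the base layer. The plain-Python sketch below shows that forward pass, a filter that returns only adapter parameters, and a context manager that temporarily disables the adapter. The class and function names, and the `lora_` name prefix used to recognize adapter parameters, are hypothetical and not torchtune's API.

```python
import random
from contextlib import contextmanager
from typing import Dict, Iterator, List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


class LoRALinearSketch:
    """y = W x + (alpha / rank) * B (A x), with W frozen and only A and B trained."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float, seed: int = 0):
        rng = random.Random(seed)
        # Frozen "pretrained" weight (random here, purely for illustration).
        self.weight = [[rng.gauss(0.0, 0.02) for _ in range(in_dim)] for _ in range(out_dim)]
        # A is random and B is zero, so the adapted layer initially matches the base layer.
        self.lora_a = [[rng.gauss(0.0, 0.02) for _ in range(in_dim)] for _ in range(rank)]
        self.lora_b = [[0.0] * rank for _ in range(out_dim)]
        self.scaling = alpha / rank
        self.disabled = False

    def forward(self, x: List[float]) -> List[float]:
        base = matvec(self.weight, x)
        if self.disabled:
            return base
        delta = matvec(self.lora_b, matvec(self.lora_a, x))
        return [b + self.scaling * d for b, d in zip(base, delta)]

    def named_parameters(self) -> Dict[str, Matrix]:
        return {"weight": self.weight, "lora_a": self.lora_a, "lora_b": self.lora_b}


def adapter_params(layer: LoRALinearSketch) -> Dict[str, Matrix]:
    """Keep only adapter parameters, identified here by a hypothetical "lora_" name prefix."""
    return {name: p for name, p in layer.named_parameters().items() if name.startswith("lora_")}


@contextmanager
def adapters_disabled(layer: LoRALinearSketch) -> Iterator[LoRALinearSketch]:
    """Temporarily run the layer as the frozen base layer, restoring the previous state on exit."""
    previous = layer.disabled
    layer.disabled = True
    try:
        yield layer
    finally:
        layer.disabled = previous
```

Only the `rank * (in_dim + out_dim)` adapter entries would be handed to an optimizer, which is where LoRA's memory savings come from; disabling the adapter is handy for computing reference outputs from the unmodified base model.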
Fusion Components
Components for building models that are a fusion of two or more pretrained models.
DeepFusion is a type of fused model architecture where a pretrained encoder is combined with a pretrained decoder (LLM) in the internal decoder layers.

Fusion layer as introduced in Flamingo: a Visual Language Model for Few-Shot Learning.

Fusion embedding supports training additional special tokens while keeping the original embedding frozen.

Add the method fusion_params to an nn.Module that marks all of the module's parameters as fusion params.

Return the subset of parameters from a model that correspond to fused modules.
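Two ideas from this section can be sketched in a few lines. A fusion embedding can route token IDs below the original vocabulary size to the frozen pretrained table and shift larger IDs into a small trainable table for new special tokens. A Flamingo-style fusion layer adds the cross-attention output through a tanh gate, so a gate initialized at zero leaves the pretrained decoder's behavior unchanged at the start of training. Both pieces below are hypothetical, framework-free illustrations, not torchtune's classes.

```python
import math
from typing import List

Vector = List[float]


class FusionEmbeddingSketch:
    """IDs < base_vocab read the frozen table; IDs >= base_vocab read a small trainable table."""

    def __init__(self, base_table: List[Vector], fusion_table: List[Vector]):
        self.base_table = base_table      # pretrained, kept frozen
        self.fusion_table = fusion_table  # new special tokens, trainable
        self.base_vocab = len(base_table)

    def lookup(self, ids: List[int]) -> List[Vector]:
        out: List[Vector] = []
        for i in ids:
            if i < self.base_vocab:
                out.append(self.base_table[i])
            else:
                out.append(self.fusion_table[i - self.base_vocab])  # shift into the new table
        return out


def gated_cross_attention_residual(x: Vector, cross_attn_out: Vector, gate: float) -> Vector:
    """Flamingo-style gated residual: x + tanh(gate) * cross_attn_out."""
    g = math.tanh(gate)  # gate == 0.0 gives g == 0.0, so x passes through unchanged
    return [xi + g * ci for xi, ci in zip(x, cross_attn_out)]
```

Keeping the two embedding tables separate means gradients only ever reach the new rows, and the zero-initialized tanh gate lets the fused encoder's contribution grow gradually during training.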
Module Utilities
These utilities are common to, and can be used by, all modules.
A state_dict hook that replaces NF4 tensors with their restored higher-precision weight and optionally offloads the restored weight to CPU.

This context manager temporarily enables KV caching on a given model that does not already have KV caches set up.

This context manager temporarily disables KV caching on a given model that must already have KV caches set up.

Deletes KV caches from all attention layers in a model, and also ensures
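The KV-cache utilities above follow a set-up, use, tear-down pattern. The sketch below models that pattern with hypothetical classes (`KVCacheSketch`, `AttentionStub`) and a hypothetical context manager; it is not torchtune's API, and the stub layer only reports how many positions it can attend to rather than computing attention.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional, Sequence, Tuple

Vec = List[float]


class KVCacheSketch:
    """Append-only store of past keys and values for one attention layer."""

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.keys: List[Vec] = []
        self.values: List[Vec] = []

    def update(self, k: Vec, v: Vec) -> Tuple[List[Vec], List[Vec]]:
        if len(self.keys) >= self.max_len:
            raise RuntimeError("KV cache is full")
        self.keys.append(k)
        self.values.append(v)
        return self.keys, self.values


class AttentionStub:
    """Hypothetical layer that only reports how many positions it can attend to."""

    def __init__(self) -> None:
        self.kv_cache: Optional[KVCacheSketch] = None

    def step(self, k: Vec, v: Vec) -> int:
        if self.kv_cache is None:
            return 1  # no cache: only the current position is visible
        keys, _ = self.kv_cache.update(k, v)
        return len(keys)  # cached past positions plus the current one


@contextmanager
def temporary_kv_caches(layers: Sequence[AttentionStub], max_len: int) -> Iterator[Sequence[AttentionStub]]:
    """Set up fresh caches on every layer, then delete them on exit, even if an error occurs."""
    if any(layer.kv_cache is not None for layer in layers):
        raise RuntimeError("KV caches are already set up")
    for layer in layers:
        layer.kv_cache = KVCacheSketch(max_len)
    try:
        yield layers
    finally:
        for layer in layers:
            layer.kv_cache = None
```

Inside the `with` block, successive decoding steps on a layer see 1, 2, then 3 positions; after the block exits, every cache is gone and the model is back in its original, cache-free state, which is why the `finally` clause matters.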
Vision Transforms
Functions used for preprocessing images.
Loose interface for all data and model transforms.

Computes the cross-attention mask for text + image inputs.
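A cross-attention mask records, for each text position, which image-encoder tokens it may attend to. The sketch below is built on a simplifying assumption that is not taken from this page: each position attends to every encoder token of the most recent image placeholder at or before it, and positions before the first image attend to nothing. The function name, the placeholder token ID, and the fixed number of encoder tokens per image are all hypothetical.

```python
from typing import List


def cross_attention_mask(tokens: List[int], image_token_id: int, num_image_tokens: int) -> List[List[bool]]:
    """Return a len(tokens) x (num_images * num_image_tokens) boolean mask (True = may attend)."""
    image_positions = [i for i, t in enumerate(tokens) if t == image_token_id]
    num_images = len(image_positions)
    mask = [[False] * (num_images * num_image_tokens) for _ in tokens]
    for img, start in enumerate(image_positions):
        # This image stays "active" until the next image placeholder (or the end of the text).
        end = image_positions[img + 1] if img + 1 < num_images else len(tokens)
        cols = range(img * num_image_tokens, (img + 1) * num_image_tokens)
        for pos in range(start, end):
            for c in cols:
                mask[pos][c] = True
    return mask
```

For the token sequence `[1, 9, 2, 3, 9, 4]` with placeholder ID 9 and two encoder tokens per image, position 0 sees no image, positions 1 to 3 see the first image's two columns, and positions 4 and 5 see the second image's two columns.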