Memory Optimization Overview¶
Author: Salman Mohammadi
torchtune comes with a host of plug-and-play memory optimization components which give you lots of flexibility
to tune
our recipes to your hardware. This page provides a brief glossary of these components and how you might use them.
To make things easy, we’ve summarized these components in the following table:
Component | When to use?
---|---
Model Precision | You'll usually want to leave this as its default.
Activation Checkpointing | Use when you're memory constrained and need to handle larger batch sizes or longer context lengths. Be aware that it may slow down training speed.
Gradient Accumulation | Helpful when memory-constrained to simulate larger batch sizes. Often preferable to activation checkpointing for better training speed.
Lower Precision Optimizers | When you need to further reduce memory usage beyond using bf16.
Fusing Optimizer Step into Backward Pass | Helps reduce memory usage when using stateful optimizers, particularly when full-finetuning large models with high gradient memory usage. This is not compatible with gradient accumulation.
Low Rank Adaptation (LoRA) | When you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training.
Quantized Low Rank Adaptation (QLoRA) | When you need even more memory savings than LoRA, at the potential cost of some training speed. Useful for very large models or limited hardware.
Note
In its current state, this tutorial is focused on single-device optimizations. Check back soon as we update this page with the latest memory optimization features for distributed fine-tuning.
Model Precision¶
What’s going on here?
We use the term “precision” to refer to the underlying data type used to represent the model and optimizer parameters. We support two data types in torchtune:
Note
We recommend diving into Sebastian Raschka’s blogpost on mixed-precision techniques for a deeper understanding of concepts around precision and data formats.
- fp32, commonly referred to as "full-precision", uses 4 bytes per model and optimizer parameter.
- bfloat16, referred to as "half-precision", uses 2 bytes per model and optimizer parameter - effectively half the memory of fp32, and also improves training speed. Generally, if your hardware supports training with bfloat16, we recommend using it - this is the default setting for our recipes.
Note
Another common paradigm is "mixed-precision" training, where model weights are in bfloat16 (or fp16) and optimizer states are in fp32. Currently, we don't support mixed-precision training in torchtune.
Sounds great! How do I use it?
Simply use the dtype flag or config entry in all our recipes! For example, to use half-precision training in bf16, set dtype=bf16.
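For example, a command-line sketch re-using the single-device LoRA recipe and config names that appear later on this page (any of our recipes accepts the same override):

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
dtype=bf16

If you hit numerical issues on hardware without native bf16 support, you can fall back to full precision with dtype=fp32, at the cost of roughly double the memory.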
Activation Checkpointing¶
What’s going on here?
The relevant section in the PyTorch documentation explains this concept well. To quote:
Activation checkpointing is a technique that trades compute for memory. Instead of keeping tensors needed for backward alive until they are used in gradient computation during backward, forward computation in checkpointed regions omits saving tensors for backward and recomputes them during the backward pass.
This setting is helpful for when you’re memory-constrained, especially due to larger batch sizes or longer context lengths. However, these savings in memory come at the cost of training speed (i.e. tokens-per-second), and in most cases training can slow down quite a bit as a result of this activation recomputation.
Sounds great! How do I use it?
To enable activation checkpointing, use the enable_activation_checkpointing config entry or flag in any of our recipes, e.g. enable_activation_checkpointing=True.
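For example, a minimal sketch using the recipe and config names from the LoRA section below:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
enable_activation_checkpointing=True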
Activation Offloading¶
What’s going on here?
You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory efficiency technique that allows saving GPU VRAM by temporarily moving activations to CPU and bringing them back when needed in the backward pass.
See PyTorch autograd hook tutorial for more details about how this is implemented through saved_tensors_hooks.
This setting is especially helpful for larger batch sizes, or longer context lengths when you’re memory constrained.
However, these savings in memory can come at the cost of training speed (i.e. tokens per-second), as it takes runtime
and resources to move Tensors from GPU to CPU and back. The implementation in torchtune has the offload_with_streams
option to use multiple CUDA streams in order to overlap the extra communication with the computation to hide the extra
runtime. As the communication workload is variable depending on the number and size of tensors being offloaded, it is
common to not offload every single activation. In fact, one can use offloading in conjunction with activation
checkpointing, where all activations will either be recomputed later in the backward pass or brought back from the CPU.
Sounds great! How do I use it?
To enable activation offloading, use the enable_activation_offloading config entry or flag in our LoRA finetuning single-device recipe, e.g. enable_activation_offloading=True. To allow usage of streams, make sure you are on a torch version later than PyTorch 2.5.0.dev20240907 and specify offload_with_streams=True.
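For example, a sketch using the single-device LoRA recipe and config from this page (assuming a sufficiently recent PyTorch build for streams, as noted above):

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
enable_activation_offloading=True \
offload_with_streams=True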
Gradient Accumulation¶
What’s going on here?
Gradient accumulation allows you to simulate large batch sizes by accumulating gradients over several batches before updating model parameters using the optimizer. Concretely, the total number of samples used for a gradient update when using gradient accumulation is:
total_batch_size = batch_size * gradient_accumulation_steps
For example: with batch_size=1
and gradient_accumulation_steps=32
we get a total batch size of 32.
Note
For other components in torchtune which use “steps”, such as metric logging, or
learning rate schedulers
, a “step” is counted as a
single update to model parameters, rather than a single model forward pass with the data.
Suppose gradient_accumulation_steps = 4
and log_every_n_steps = 10
.
Metrics would be logged every 10 global steps, which translates to every 40 model forward passes.
For this reason, metric logging will appear less frequently when training with gradient accumulation,
and progress bars may update more slowly.
If you’re using one of our distributed recipes, simply multiply by the number of devices:
total_batch_size = batch_size * gradient_accumulation_steps * num_devices
Gradient accumulation is especially useful when you are memory constrained. In this case, accumulating gradients might give you better training speed than enabling activation checkpointing, since activation checkpointing reduces memory consumption at the cost of repeated computations.
Sounds great! How do I use it?
All of our finetuning recipes support simulating larger batch sizes by accumulating gradients. Just set the
gradient_accumulation_steps
flag or config entry.
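For example, to simulate a total batch size of 32 on a single device (recipe and config names re-used from elsewhere on this page for illustration):

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
batch_size=1 \
gradient_accumulation_steps=32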
Note
Gradient accumulation should always be set to 1 when fusing the optimizer step into the backward pass.
Lower Precision Optimizers¶
What’s going on here?
In addition to reducing model and optimizer precision during training, we can further reduce precision in our optimizer states.
All of our single-device fine-tuning recipes support lower-precision optimizers from the bitsandbytes library -
a good place to start might be the AdamW8bit
and PagedAdamW8bit
optimizers, which we’ve tested our recipes with.
Sounds great! How do I use it?
To use this in your recipes, make sure you have installed bitsandbytes (pip install bitsandbytes
). Then, enable
a low precision optimizer using the torchtune CLI:
tune run <RECIPE> --config <CONFIG> \
optimizer=bitsandbytes.optim.PagedAdamW
or by directly modifying a config file:
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 2e-5
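Similarly, to try one of the 8-bit optimizers mentioned above, you could swap in the corresponding component with the same override pattern (a sketch, using the doc's recipe/config placeholders):

tune run <RECIPE> --config <CONFIG> \
optimizer=bitsandbytes.optim.AdamW8bit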
Fusing Optimizer Step into Backward Pass¶
What’s going on here?
Stateful optimizers (e.g. optimizers which use momentum) are the default in modern deep learning due to their stable convergence properties. However, maintaining a state of gradient statistics comes at the cost of additional memory usage. An immediate alternative might be to turn to stateless optimizers such as stochastic gradient descent without momentum, which don’t require any additional memory usage, but will likely result in worse convergence during training.
Can we find a middle ground here? Let’s consider a technique which enables the use of “stateful” optimizers such as AdamW
without the memory overhead of gradient statistics, and without sacrificing their desirable convergence properties.
How is this possible, you might ask? By completely removing the buffer of gradients which are stored by the optimizer during its step().
To understand how this works, we encourage you to read through the relevant PyTorch tutorial on this concept: How to save memory by fusing the optimizer step into the backward pass.
Sounds great! How do I use it?
In torchtune, you can enable this feature using the optimizer_in_bwd
flag, which is currently only supported in our
single-device full finetune recipe. This feature works best when gradient memory is particularly large;
e.g. when using a stateful optimizer with a model with a lot of parameters, and when you don’t need to use
gradient accumulation.
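As a sketch, this might look like the following, assuming the single-device full finetune recipe (the config name here is illustrative - substitute the full finetune config for your model). Note that gradient_accumulation_steps is pinned to 1, per the note in the gradient accumulation section above:

tune run full_finetune_single_device --config llama3/8B_full_single_device \
optimizer_in_bwd=True \
gradient_accumulation_steps=1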
Parameter Efficient Fine-Tuning (PEFT)¶
Low Rank Adaptation (LoRA)¶
What’s going on here?
You can read our tutorial on finetuning Llama2 with LoRA to understand how LoRA works, and how to use it. Simply stated, LoRA greatly reduces the number of trainable parameters, thus saving significant gradient and optimizer memory during training.
Sounds great! How do I use it?
You can finetune using any of our recipes with the lora_
prefix, e.g. lora_finetune_single_device. These recipes utilize
LoRA-enabled model builders, which we support for all our models, and also use the lora_
prefix, e.g.
the torchtune.models.llama3.llama3()
model has a corresponding torchtune.models.llama3.lora_llama3()
.
We aim to provide a comprehensive set of configurations to allow you to get started with training with LoRA quickly,
just specify any config with _lora
in its name, e.g.:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device
There are two sets of parameters to customize LoRA to suit your needs. Firstly, the parameters which control which linear layers LoRA should be applied to in the model:
- lora_attn_modules: List[str] accepts a list of strings specifying which layers of the model to apply LoRA to:
  - q_proj applies LoRA to the query projection layer.
  - k_proj applies LoRA to the key projection layer.
  - v_proj applies LoRA to the value projection layer.
  - output_proj applies LoRA to the attention output projection layer.
  Whilst adding more layers to be fine-tuned may improve model accuracy, this will come at the cost of increased memory usage and reduced training speed.
- apply_lora_to_mlp: Bool applies LoRA to the MLP in each transformer layer.
- apply_lora_to_output: Bool applies LoRA to the model's final output projection. This is usually a projection to vocabulary space (e.g. in language models), but other modelling tasks may have different projections - classifier models will project to the number of classes, for example.
Note
Models which use tied embeddings (such as Gemma and Qwen2 1.5B and 0.5B) for the
final output projection do not support apply_lora_to_output
.
These are all specified under the model
flag or config entry, i.e.:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"]
model:
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
Secondly, parameters which control the scale of the impact of LoRA on the model:
- lora_rank: int affects the scale of the LoRA decomposition, where lora_rank << in_dim and lora_rank << out_dim - the dimensions of an arbitrary linear layer in the model. Concretely, lora_rank reduces the number of gradients stored in a linear fashion from in_dim * out_dim to lora_rank * (in_dim + out_dim). Typically, we have lora_rank in [8, 128].
- lora_alpha: float affects the magnitude of the LoRA updates. A larger alpha results in larger updates to the base model weights, potentially at the cost of training stability; conversely, a smaller alpha can stabilize training at the cost of slower learning. We provide default settings for these parameters which we've tested with all of our models, but we encourage you to adjust them to your specific use case. Typically, one jointly changes lora_rank and lora_alpha together, where lora_alpha ~= 2*lora_rank.
- lora_dropout introduces dropout in the LoRA layers to help regularize training. We default to 0.0 for all of our models.
As above, these parameters are also specified under the model
flag or config entry.
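For example, a sketch of overriding these from the command line (the specific values are illustrative rather than tuned recommendations, keeping lora_alpha ~= 2*lora_rank):

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.lora_rank=16 \
model.lora_alpha=32 \
model.lora_dropout=0.0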
Note
To get a deeper sense of how LoRA parameters affect memory usage during training, see the relevant section in our Llama2 LoRA tutorial.
Quantized Low Rank Adaptation (QLoRA)¶
What’s going on here?
QLoRA is an enhancement on top of LoRA that maintains the frozen model parameters from LoRA in 4-bit quantized precision, thereby reducing memory usage. This is enabled through a novel 4-bit NormalFloat (NF4) data type proposed by the authors, which allows for 4-8x less parameter memory usage whilst retaining model accuracy. You can read our tutorial on finetuning Llama2 with QLoRA for a deeper understanding of how it works.
When considering using QLoRA to reduce memory usage, it’s worth noting that QLoRA prevents accuracy degradation during quantization
by up-casting quantized parameters to the original higher precision datatype during model forward passes - this up-casting may
incur penalties to training speed. The relevant section in our QLoRA tutorial demonstrates the usage of torch.compile
to address this by speeding up training.
Sounds great! How do I use it?
You can finetune using QLoRA with any of our LoRA recipes, i.e. recipes with the lora_
prefix, e.g. lora_finetune_single_device. These recipes utilize
QLoRA-enabled model builders, which we support for all our models, and also use the qlora_
prefix, e.g.
the torchtune.models.llama3.llama3_8b()
model has a corresponding torchtune.models.llama3.qlora_llama3_8b()
.
We aim to provide a comprehensive set of configurations to allow you to get started with training with QLoRA quickly,
just specify any config with _qlora
in its name, e.g.:
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
All the rest of the LoRA parameters remain the same for QLoRA - check out the section above on LoRA to see how to configure them.
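For example, the same model-level overrides shown in the LoRA section apply directly to a QLoRA config:

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"]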