Memory Optimization Overview
Author: Salman Mohammadi
torchtune comes with a host of plug-and-play memory optimization components which give you lots of flexibility to tune our recipes to your hardware. This page provides a brief glossary of these components and how you might use them.
To make things easy, we’ve summarized these components in the following table:
Component | When to use? |
---|---|
Model Precision | You’ll usually want to leave this as its default. |
Activation Checkpointing | Use when you’re memory constrained and want to use a larger model, batch size or context length. Be aware that it will slow down training speed. |
Activation Offloading | Similar to activation checkpointing, this can be used when memory constrained, but may decrease training speed. This should be used alongside activation checkpointing. |
Gradient Accumulation | Helpful when memory-constrained to simulate larger batch sizes. Not compatible with optimizer in backward. Use it when you can already fit at least one sample without OOMing, but not enough of them. |
Lower Precision Optimizers | Use when you want to reduce the size of the optimizer state. This is relevant when training large models and using optimizers with momentum, like Adam. Note that lower precision optimizers may reduce training stability/accuracy. |
Fusing Optimizer Step into Backward Pass | Use it when you have large gradients and can fit a large enough batch size, since this is not compatible with gradient accumulation. |
Offloading Optimizer/Gradient states to CPU | Offloads optimizer states and (optionally) gradients to CPU, and performs optimizer steps on CPU. This can be used to significantly reduce GPU memory usage at the cost of CPU RAM and training speed. Use it only if the other techniques are not enough. |
Low Rank Adaptation (LoRA) | When you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training. This may reduce training accuracy. |
Quantized Low Rank Adaptation (QLoRA) | When you are training a large model, since quantization will save 1.5 bytes * (# of model parameters), at the potential cost of some training speed and accuracy. |
Weight-Decomposed Low-Rank Adaptation (DoRA) | A variant of LoRA that may improve model performance at the cost of slightly more memory. |
Note
In its current state, this tutorial is focused on single-device optimizations. Check back soon as we update this page with the latest memory optimization features for distributed fine-tuning.
Model Precision
What’s going on here?
We use the term “precision” to refer to the underlying data type used to represent the model and optimizer parameters. We support two data types in torchtune:
- fp32, commonly referred to as “full-precision”, uses 4 bytes per model and optimizer parameter.
- bfloat16, referred to as “half-precision”, uses 2 bytes per model and optimizer parameter, effectively half the memory of fp32, and also improves training speed. Generally, if your hardware supports training with bfloat16, we recommend using it; this is the default setting for our recipes.
Note
We recommend diving into Sebastian Raschka’s blogpost on mixed-precision techniques for a deeper understanding of concepts around precision and data formats.
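As a rough illustration of the bytes-per-parameter figures above, the sketch below estimates the memory taken by weights, gradients and optimizer states for a given parameter count. It is a back-of-the-envelope calculation under our own assumptions: an AdamW-style optimizer with two state tensors per parameter, gradients and optimizer states in the same dtype as the weights, and activations and framework overhead ignored. The 8B parameter count is only an example.

```python
# Rough sketch: parameter-related memory for a given precision.
# Assumptions (illustrative only): gradients and optimizer states share the
# weights' dtype, an AdamW-style optimizer keeps 2 states per parameter,
# and activations / framework overhead are ignored.
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2}


def estimate_param_memory_gb(num_params: int, dtype: str, optimizer_states: int = 2) -> float:
    """Return weights + gradients + optimizer states in GB (1 GB = 1e9 bytes)."""
    copies = 1 + 1 + optimizer_states  # weights, gradients, optimizer states
    return num_params * BYTES_PER_ELEMENT[dtype] * copies / 1e9


num_params = 8_000_000_000  # e.g. an 8B-parameter model
for dtype in ("fp32", "bf16"):
    print(f"{dtype}: ~{estimate_param_memory_gb(num_params, dtype):.0f} GB")
```

Under these assumptions, moving from fp32 to bf16 halves the estimate, from roughly 128 GB to roughly 64 GB for 8B parameters.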
Note
Another common paradigm is “mixed-precision” training, where model weights are in bfloat16 (or fp16) and optimizer states are in fp32. Currently, we don’t support mixed-precision training in torchtune.
Sounds great! How do I use it?
Simply use the dtype flag or config entry in all our recipes! For example, to use half-precision training in bf16, set dtype=bf16.
Activation Checkpointing
What’s going on here?
The relevant section in the PyTorch documentation explains this concept well. To quote:
Activation checkpointing is a technique that trades compute for memory. Instead of keeping tensors needed for backward alive until they are used in gradient computation during backward, forward computation in checkpointed regions omits saving tensors for backward and recomputes them during the backward pass.
This setting is helpful when you’re memory-constrained, especially due to larger batch sizes or longer context lengths. However, these memory savings come at the cost of training speed (i.e. tokens per second), and in most cases training can slow down quite a bit as a result of this activation recomputation.
Sounds great! How do I use it?
To enable activation checkpointing, use enable_activation_checkpointing=True.
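To make the compute-for-memory trade concrete, here is a minimal sketch of the underlying PyTorch mechanism, torch.utils.checkpoint.checkpoint, applied per block of a toy model. It is not torchtune's implementation of enable_activation_checkpointing; the Block and TinyModel classes are hypothetical stand-ins for transformer layers.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """A small residual MLP block standing in for a transformer layer (hypothetical)."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class TinyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4, use_checkpointing: bool = False):
        super().__init__()
        self.layers = nn.ModuleList([Block(dim) for _ in range(n_layers)])
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.use_checkpointing and self.training:
                # Activations inside the block are not saved for backward;
                # they are recomputed from the block's input during backward.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


torch.manual_seed(0)
model = TinyModel(use_checkpointing=True)
x = torch.randn(8, 16)
model(x).sum().backward()  # gradients match the non-checkpointed model
```

The gradients are identical to running without checkpointing; what changes is that fewer activations are held in memory between the forward and backward passes, and each block's forward runs twice.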
Activation Offloading
What’s going on here?
You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory efficiency technique that allows saving GPU VRAM by temporarily moving activations to CPU and bringing them back when needed in the backward pass.
See the PyTorch autograd hook tutorial for more details about how this is implemented through torch.autograd.graph.saved_tensors_hooks().
This setting is especially helpful for larger batch sizes, or longer context lengths when you’re memory constrained. While of course it takes runtime and resources to move Tensors from GPU to CPU and back, the implementation in torchtune uses multiple CUDA streams (when available) in order to overlap the extra communication with the computation to hide the extra runtime. As the communication workload is variable depending on the number and size of tensors being offloaded, we do not recommend using it unless Activation Checkpointing is also enabled, in which case only the checkpointed tensors will be offloaded.
Sounds great! How do I use it?
To enable activation offloading, use enable_activation_offloading=True. If you are using a PyTorch version later than 2.5.0, multiple CUDA streams will be used automatically.
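The core idea can be sketched with torch.autograd.graph.saved_tensors_hooks: a pack hook moves each tensor saved for backward to CPU, and an unpack hook brings it back when the backward pass needs it. This minimal sketch is synchronous and single-stream, so it omits the CUDA-stream overlap described above; it is an illustration of the mechanism, not torchtune's implementation. PyTorch also ships a ready-made context manager for this pattern, torch.autograd.graph.save_on_cpu.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


def pack_to_cpu(tensor: torch.Tensor):
    # Called when autograd saves a tensor for backward: keep a CPU copy
    # together with the device it came from.
    return tensor.device, tensor.to("cpu")


def unpack_from_cpu(packed):
    # Called when backward needs the tensor: move it back to its original device.
    original_device, cpu_tensor = packed
    return cpu_tensor.to(original_device)


def forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x @ w)


torch.manual_seed(0)
x = torch.randn(4, 8, device=device, requires_grad=True)
w = torch.randn(8, 8, device=device, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    out = forward(x, w)  # tensors saved for backward pass through pack_to_cpu
out.sum().backward()     # ...and come back through unpack_from_cpu
offloaded_grad = w.grad.clone()
```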
Gradient Accumulation
What’s going on here?
Gradient accumulation allows you to simulate large batch sizes by accumulating gradients over several batches before updating model parameters using the optimizer. Concretely, the total number of samples used for a gradient update when using gradient accumulation is:
total_batch_size = batch_size * gradient_accumulation_steps
For example, with batch_size=1 and gradient_accumulation_steps=32, we get a total batch size of 32.
Note
For other components in torchtune which use “steps”, such as metric logging or learning rate schedulers, a “step” is counted as a single update to model parameters, rather than a single model forward pass with the data. Suppose gradient_accumulation_steps = 4 and log_every_n_steps = 10. Metrics would be logged every 10 global steps, which translates to every 40 model forward passes. For this reason, metric logging will appear less frequently when training with gradient accumulation, and progress bars may update more slowly.
If you’re using one of our distributed recipes, simply multiply by the number of devices:
total_batch_size = batch_size * gradient_accumulation_steps * num_devices
Gradient accumulation is especially useful when you can fit at least one sample in your GPU. In this case, artificially increasing the batch by accumulating gradients might give you faster training speeds than using other memory optimization techniques that trade-off memory for speed, like activation checkpointing.
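A minimal plain-PyTorch sketch of the idea follows. It assumes a loss that is averaged over each micro-batch and equal-sized micro-batches, so dividing each micro-batch loss by gradient_accumulation_steps makes the accumulated gradient equal to the gradient over the whole total batch. The accumulate_gradients helper is hypothetical; torchtune's recipes may normalize the loss differently (for example by token count), so treat this as an illustration of the principle only.

```python
import torch
import torch.nn as nn


def accumulate_gradients(model, loss_fn, inputs, targets, batch_size, gradient_accumulation_steps):
    """Hypothetical helper: run several micro-batches, summing their gradients into .grad."""
    for step in range(gradient_accumulation_steps):
        start = step * batch_size
        micro_x = inputs[start:start + batch_size]
        micro_y = targets[start:start + batch_size]
        loss = loss_fn(model(micro_x), micro_y)
        # Dividing by the number of micro-batches makes the accumulated gradient
        # match the gradient of the mean loss over the whole total batch.
        (loss / gradient_accumulation_steps).backward()  # .grad accumulates in place


torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

batch_size, gradient_accumulation_steps = 2, 4
total_batch_size = batch_size * gradient_accumulation_steps  # 8 samples per update
inputs, targets = torch.randn(total_batch_size, 10), torch.randn(total_batch_size, 1)

optimizer.zero_grad()
accumulate_gradients(model, loss_fn, inputs, targets, batch_size, gradient_accumulation_steps)
optimizer.step()  # one parameter update ("step") per total batch
```

Only one micro-batch's activations are alive at a time, which is where the memory saving comes from; the price is that the parameters are updated once every gradient_accumulation_steps forward passes.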
Sounds great! How do I use it?
All of our finetuning recipes support simulating larger batch sizes by accumulating gradients. Just set the gradient_accumulation_steps flag or config entry.
Note
Gradient accumulation should always be set to 1 when fusing the optimizer step into the backward pass.
Optimizers
Lower Precision Optimizers
What’s going on here?
In addition to reducing model and optimizer precision during training, we can further reduce precision in our optimizer states. All of our recipes support lower-precision optimizers from the torchao library. For single device recipes, we also support bitsandbytes.
A good place to start might be the torchao.prototype.low_bit_optim.AdamW8bit and bitsandbytes.optim.PagedAdamW8bit optimizers.
Both reduce memory by quantizing the optimizer state dict. Paged optimizers will also offload to CPU if there isn’t enough GPU memory available. In practice, you can expect higher memory savings from bnb’s PagedAdamW8bit but higher training speed from torchao’s AdamW8bit.
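To see why quantizing the optimizer state matters for large models, the sketch below estimates the memory held by AdamW's two state tensors (the first and second moment estimates) at different state precisions. It assumes roughly 1 byte per element for 8-bit states and ignores the small per-block scaling overhead that real 8-bit optimizers carry; the 8B parameter count is only an example.

```python
# Rough sketch: memory held by AdamW's two per-parameter state tensors.
# Assumption: 8-bit states take ~1 byte per element; per-block scale
# overhead of real 8-bit optimizers is ignored.
STATE_BYTES = {"fp32": 4.0, "bf16": 2.0, "8bit": 1.0}


def adamw_state_gb(num_params: int, state_dtype: str) -> float:
    states_per_param = 2  # first and second moment estimates
    return num_params * states_per_param * STATE_BYTES[state_dtype] / 1e9


for state_dtype in ("fp32", "bf16", "8bit"):
    print(state_dtype, adamw_state_gb(8_000_000_000, state_dtype), "GB")
```

Under these assumptions, 8-bit states for 8B parameters come to about 16 GB, versus about 32 GB in bf16 and 64 GB in fp32.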
Sounds great! How do I use it?
To use this in your recipes, make sure you have installed torchao (pip install torchao) or bitsandbytes (pip install bitsandbytes). Then, enable a low precision optimizer using the torchtune CLI:
tune run <RECIPE> --config <CONFIG> \
optimizer=torchao.prototype.low_bit_optim.AdamW8bit
tune run <RECIPE> --config <CONFIG> \
optimizer=bitsandbytes.optim.PagedAdamW8bit
or by directly modifying a config file:
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
Fusing Optimizer Step into Backward Pass
What’s going on here?
Stateful optimizers (e.g. optimizers which use momentum) are the default in modern deep learning due to their stable convergence properties. However, maintaining a state of gradient statistics comes at the cost of additional memory usage. An immediate alternative might be to turn to stateless optimizers such as stochastic gradient descent without momentum, which don’t require any additional memory usage, but will likely result in worse convergence during training.
Can we find a middle ground here? Let’s consider a technique which lets us keep using “stateful” optimizers such as AdamW, along with their desirable convergence properties, while still cutting memory usage.
How is this possible, you might ask? By completely removing the buffer of gradients that would otherwise be held in memory until the optimizer’s step(): each parameter is updated, and its gradient freed, as soon as that gradient has been computed during the backward pass.
To understand how this works, we encourage you to read through the relevant PyTorch tutorial on this concept: How to save memory by fusing the optimizer step into the backward pass.
Sounds great! How do I use it?
In torchtune, you can enable this feature using the optimizer_in_bwd flag. This feature works best when using a stateful optimizer with a model with a lot of parameters, and when you don’t need to use gradient accumulation. You won’t see meaningful impact when finetuning LoRA recipes, since in this case the number of parameters being updated is small.
Offloading Optimizer/Gradient states to CPU
What’s going on here?
We’ve mentioned above the concept of optimizer states (memory used by stateful optimizers to maintain a state of gradient statistics) and model gradients (tensors used to store gradients when we perform model backward passes). We support using CPU offloading in our single-device recipes through the CPUOffloadOptimizer from torchao.
This optimizer can wrap any base optimizer and works by keeping the optimizer states and performing the optimizer step on CPU, thus reducing GPU memory usage by the size of the optimizer states. We can also offload gradients to the CPU by using offload_gradients=True.
If finetuning on a single device, another option is to use the PagedAdamW8bit from bitsandbytes, mentioned above, which will only offload to CPU when there is not enough GPU memory available.
Sounds great! How do I use it?
To use this optimizer in your recipes, set the optimizer key in your config to torchao.prototype.low_bit_optim.CPUOffloadOptimizer, which will use the torch.optim.AdamW optimizer with fused=True as the base optimizer. For example, to use this optimizer to offload both optimizer states and gradients to CPU:
tune run <RECIPE> --config <CONFIG> \
optimizer=torchao.prototype.low_bit_optim.CPUOffloadOptimizer \
optimizer.offload_gradients=True \
optimizer.lr=4e-5
or by directly modifying a config file:
optimizer:
  _component_: torchao.prototype.low_bit_optim.CPUOffloadOptimizer
  offload_gradients: True
  # additional keyword arguments can be passed to torch.optim.AdamW
  lr: 4e-5
or using it directly in your code, which allows you to change the base optimizer:
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
from torch.optim import Adam
optimizer = CPUOffloadOptimizer(
model.parameters(), # your model here
Adam,
lr=1e-5,
fused=True
)
Some helpful hints from the torchao CPUOffloadOptimizer page:
- The CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) use full bf16 training so that parameters, gradients, and optimizer states are in bf16; and (2) give GPU more work per optimizer step to amortize the offloading time (e.g. larger batch size with activation checkpointing, gradient accumulation).
- Gradient accumulation should always be set to 1 when offload_gradients=True, as gradients are cleared on GPU every backward pass.
- This optimizer works by keeping a copy of parameters and pre-allocating gradient memory on CPU. Therefore, expect your RAM usage to increase by 4x model size.
This optimizer is only supported for single-device recipes. To use CPU offloading in distributed recipes, use fsdp_cpu_offload=True instead. See torch.distributed.fsdp.FullyShardedDataParallel for more details, and FSDP1 vs FSDP2 to see how they differ.
Parameter Efficient Fine-Tuning (PEFT)
Low Rank Adaptation (LoRA)
What’s going on here?
You can read our tutorial on finetuning Llama2 with LoRA to understand how LoRA works, and how to use it. Simply stated, LoRA greatly reduces the number of trainable parameters, thus saving significant gradient and optimizer memory during training.
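To show where the savings come from, here is a minimal, hypothetical LoRA linear layer in plain PyTorch: the pretrained weight is frozen, and only two small low-rank factors are trained, with their product scaled by alpha / rank. It is a sketch of the technique, not torchtune's LoRALinear; the class name and initialization choices are our own assumptions.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalLoRALinear(nn.Module):
    """Hypothetical sketch (not torchtune's LoRALinear): frozen W plus scaled B @ A."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        # Frozen pretrained weight: no gradients or optimizer state are kept for it.
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim) / math.sqrt(in_dim), requires_grad=False)
        # Trainable low-rank factors; B starts at zero so the layer initially equals the base layer.
        self.lora_a = nn.Parameter(torch.randn(rank, in_dim) / math.sqrt(in_dim))
        self.lora_b = nn.Parameter(torch.zeros(out_dim, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self.weight)
        update = F.linear(F.linear(x, self.lora_a), self.lora_b)  # x @ A^T @ B^T
        return base + self.scaling * update


layer = MinimalLoRALinear(in_dim=64, out_dim=32, rank=4)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
print(trainable, frozen)  # 384 2048
```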
Sounds great! How do I use it?
You can finetune using any of our recipes with the lora_ prefix, e.g. lora_finetune_single_device. These recipes utilize LoRA-enabled model builders, which we support for all our models and which also use the lora_ prefix; e.g. the torchtune.models.llama3.llama3() model has a corresponding torchtune.models.llama3.lora_llama3().
We aim to provide a comprehensive set of configurations to allow you to get started with training with LoRA quickly: just specify any config with _lora in its name, e.g.:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device
There are two sets of parameters to customize LoRA to suit your needs. Firstly, the parameters which control which linear layers LoRA should be applied to in the model:
- lora_attn_modules: List[str] accepts a list of strings specifying which layers of the model to apply LoRA to:
  - q_proj applies LoRA to the query projection layer.
  - k_proj applies LoRA to the key projection layer.
  - v_proj applies LoRA to the value projection layer.
  - output_proj applies LoRA to the attention output projection layer.
  Whilst adding more layers to be fine-tuned may improve model accuracy, this will come at the cost of increased memory usage and reduced training speed.
- apply_lora_to_mlp: Bool applies LoRA to the MLP in each transformer layer.
- apply_lora_to_output: Bool applies LoRA to the model’s final output projection. This is usually a projection to vocabulary space (e.g. in language models), but other modelling tasks may have different projections; classifier models will project to the number of classes, for example.
Note
Models which use tied embeddings (such as Gemma and Qwen2 1.5B and 0.5B) for the final output projection do not support apply_lora_to_output.
These are all specified under the model flag or config entry, i.e.:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj","output_proj"]
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj", "output_proj"]
Secondly, parameters which control the scale of the impact of LoRA on the model:
- lora_rank: int affects the scale of the LoRA decomposition, where lora_rank << in_dim and lora_rank << out_dim (the dimensions of an arbitrary linear layer in the model). Concretely, LoRA reduces the number of trainable parameters, and hence stored gradients, for such a layer from in_dim * out_dim to lora_rank * (in_dim + out_dim), which grows linearly with lora_rank. Typically, we have lora_rank in [8, 256].
- lora_alpha: float affects the magnitude of the LoRA updates. A larger alpha results in larger updates to the base model weights, potentially at the cost of training stability; conversely, a smaller alpha can stabilize training at the cost of slower learning. We provide default settings for these parameters which we’ve tested with all of our models, but we encourage you to adjust them to your specific use case. Typically, one changes lora_rank and lora_alpha together, where lora_alpha ~= 2*lora_rank.
- lora_dropout introduces dropout in the LoRA layers to help regularize training. We default to 0.0 for all of our models.
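The parameter-count formula for lora_rank can be checked with a little arithmetic. The sketch below compares full finetuning against LoRA for a single 4096 x 4096 projection; that layer size is chosen purely for illustration, and lora_param_counts is a hypothetical helper.

```python
def lora_param_counts(in_dim: int, out_dim: int, lora_rank: int) -> tuple[int, int]:
    """Trainable parameters for one linear layer: (full finetuning, LoRA)."""
    full = in_dim * out_dim
    lora = lora_rank * (in_dim + out_dim)  # A: (rank, in_dim), B: (out_dim, rank)
    return full, lora


for rank in (8, 64, 256):
    full, lora = lora_param_counts(4096, 4096, rank)
    print(f"rank={rank}: {lora:,} trainable vs {full:,} ({100 * lora / full:.2f}%)")
```

Even at the top of the typical range, the LoRA factors are a small fraction of the full weight, and the count scales linearly with lora_rank.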
As above, these parameters are also specified under the model flag or config entry:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj","output_proj"] \
model.lora_rank=32 \
model.lora_alpha=64
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj", "output_proj"]
  lora_rank: 32
  lora_alpha: 64
Note
To get a deeper sense of how LoRA parameters affect memory usage during training, see the relevant section in our Llama2 LoRA tutorial.
Quantized Low Rank Adaptation (QLoRA)
What’s going on here?
QLoRA is a memory enhancement on top of LoRA that maintains the frozen model parameters from LoRA in 4-bit quantized precision, thereby reducing memory usage. This is enabled through a novel 4-bit NormalFloat (NF4) data type proposed by the authors, which allows for 4-8x less parameter memory usage whilst retaining model accuracy. You can read our tutorial on finetuning Llama2 with QLoRA for a deeper understanding of how it works.
When considering using QLoRA to reduce memory usage, it’s worth noting that QLoRA is slower than LoRA and may not be worth it if the model you are finetuning is small. In numbers, QLoRA saves roughly 1.5 bytes * (# of model parameters). Also, although QLoRA quantizes the model, it minimizes accuracy degradation by up-casting quantized parameters to the original higher precision datatype during model forward passes; this up-casting may incur penalties to training speed.
The relevant section in our QLoRA tutorial demonstrates the usage of torch.compile to address this by speeding up training.
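One way to see where the figure of roughly 1.5 bytes per parameter comes from: a bf16 weight takes 2 bytes, while a 4-bit NF4 weight takes 0.5 bytes. The sketch below assumes the frozen base weights would otherwise be stored in bf16 and ignores the small overhead of the quantization constants; the function name is hypothetical.

```python
def qlora_base_weight_savings_gb(num_params: int) -> float:
    """Approximate memory saved by storing frozen base weights in NF4 instead of bf16."""
    bf16_bytes = 2.0  # 16 bits per parameter
    nf4_bytes = 0.5   # 4 bits per parameter (quantization constants ignored)
    return num_params * (bf16_bytes - nf4_bytes) / 1e9


print(qlora_base_weight_savings_gb(8_000_000_000))  # 12.0 (GB, for 8B parameters)
```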
Sounds great! How do I use it?
You can finetune using QLoRA with any of our LoRA recipes, i.e. recipes with the lora_ prefix, e.g. lora_finetune_single_device. These recipes utilize QLoRA-enabled model builders, which we support for all our models and which use the qlora_ prefix; e.g. the torchtune.models.llama3.llama3_8b() model has a corresponding torchtune.models.llama3.qlora_llama3_8b().
We aim to provide a comprehensive set of configurations to allow you to get started with training with QLoRA quickly: just specify any config with _qlora in its name.
All the rest of the LoRA parameters remain the same for QLoRA - check out the section above on LoRA to see how to configure these parameters.
To configure from the command line:
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
model.lora_rank=32 \
model.lora_alpha=64
or, by modifying a config:
model:
  _component_: torchtune.models.llama3.qlora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
  lora_rank: 32
  lora_alpha: 64
Weight-Decomposed Low-Rank Adaptation (DoRA)
What’s going on here?
DoRA is another PEFT technique which builds on top of LoRA by further decomposing the pre-trained weights into two components: magnitude and direction. The magnitude component is a learnable vector that adjusts the scale of the weights, while the direction component corresponds to the original LoRA decomposition and updates the orientation of weights.
DoRA adds a small overhead to LoRA training due to the addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low ranks.
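The decomposition can be sketched in a few lines of PyTorch. Here the direction is the LoRA-updated weight W0 + scaling * (B @ A), normalized per output row of the (out_dim, in_dim) weight, and a learnable magnitude vector with one entry per output row rescales it. This is a minimal sketch under those assumptions, not torchtune's DoRALinear, whose internals may differ; dora_weight is a hypothetical helper.

```python
import torch


def dora_weight(base_weight, lora_a, lora_b, magnitude, scaling=1.0):
    """Hypothetical sketch of DoRA's reparameterization of a (out_dim, in_dim) weight.

    direction = W0 + scaling * (B @ A), normalized per output row;
    the learnable `magnitude` (one value per output row) rescales each row.
    """
    directional = base_weight + scaling * (lora_b @ lora_a)          # (out_dim, in_dim)
    row_norms = torch.linalg.norm(directional, dim=1, keepdim=True)  # (out_dim, 1)
    return magnitude.unsqueeze(1) * directional / row_norms


torch.manual_seed(0)
out_dim, in_dim, rank = 6, 10, 2
base_weight = torch.randn(out_dim, in_dim)         # frozen W0
lora_a = torch.randn(rank, in_dim)                 # trainable
lora_b = torch.zeros(out_dim, rank)                # trainable, zero-initialized
magnitude = torch.linalg.norm(base_weight, dim=1)  # trainable, initialized to the row norms of W0
merged = dora_weight(base_weight, lora_a, lora_b, magnitude)
```

At initialization, B is zero and the magnitude equals the row norms of W0, so the reparameterized weight equals W0. Once B is trained, each row's norm is still exactly the corresponding magnitude entry, which is what lets scale and direction be learned separately; the extra magnitude vector is the small overhead mentioned above.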
Sounds great! How do I use it?
Much like LoRA and QLoRA, you can finetune using DoRA with any of our LoRA recipes. We use the same model builders for LoRA as we do for DoRA, so you can use the lora_ version of any model builder with use_dora=True. For example, to finetune torchtune.models.llama3.llama3_8b() with DoRA, you would use torchtune.models.llama3.lora_llama3_8b() with use_dora=True:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.use_dora=True
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  use_dora: True
Since DoRA extends LoRA, the parameters for customizing LoRA are identical. You can also quantize the base model weights like in Quantized Low Rank Adaptation (QLoRA) by using quantize_base=True to reap even more memory savings!
tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
model.lora_rank=16 \
model.lora_alpha=32 \
model.use_dora=True \
model.quantize_base=True
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
  lora_rank: 16
  lora_alpha: 32
  use_dora: True
  quantize_base: True
Note
Under the hood, we’ve enabled DoRA by adding the DoRALinear module, which we swap in for LoRALinear when use_dora=True.