Memory Optimization Overview

Author: Salman Mohammadi

torchtune comes with a host of plug-and-play memory optimization components which give you lots of flexibility to tune our recipes to your hardware. This page provides a brief glossary of these components and how you might use them. To make things easy, we’ve summarized these components in the following table:

Memory optimization components

| Component | When to use? |
|---|---|
| Model Precision | You’ll usually want to leave this as its default bfloat16. It uses 2 bytes per model parameter instead of 4 bytes when using float32. |
| Activation Checkpointing | Use when you’re memory constrained and want to use a larger model, batch size or context length. Be aware that it will slow down training speed. |
| Activation Offloading | Similar to activation checkpointing, this can be used when memory constrained, but may decrease training speed. This should be used alongside activation checkpointing. |
| Gradient Accumulation | Helpful when memory-constrained to simulate larger batch sizes. Not compatible with optimizer in backward. Use it when you can already fit at least one sample without OOMing, but not enough of them. |
| Lower Precision Optimizers | Use when you want to reduce the size of the optimizer state. This is relevant when training large models and using optimizers with momentum, like Adam. Note that lower precision optimizers may reduce training stability/accuracy. |
| Fusing Optimizer Step into Backward Pass | Use it when you have large gradients and can fit a large enough batch size, since this is not compatible with gradient_accumulation_steps. |
| Offloading Optimizer/Gradient states to CPU | Offloads optimizer states and (optionally) gradients to CPU, and performs optimizer steps on CPU. This can be used to significantly reduce GPU memory usage at the cost of CPU RAM and training speed. Prioritize using it only if the other techniques are not enough. |
| Low Rank Adaptation (LoRA) | Use when you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training. This may reduce training accuracy. |
| Quantized Low Rank Adaptation (QLoRA) | Use when you are training a large model, since quantization will save 1.5 bytes * (# of model parameters), at the potential cost of some training speed and accuracy. |
| Weight-Decomposed Low-Rank Adaptation (DoRA) | A variant of LoRA that may improve model performance at the cost of slightly more memory. |

Note

In its current state, this tutorial is focused on single-device optimizations. Check in soon as we update this page for the latest memory optimization features for distributed fine-tuning.

Model Precision

What’s going on here?

We use the term “precision” to refer to the underlying data type used to represent the model and optimizer parameters. We support two data types in torchtune:

  • fp32, commonly referred to as “full-precision”, uses 4 bytes per model and optimizer parameter.

  • bfloat16, referred to as “half-precision”, uses 2 bytes per model and optimizer parameter - effectively half the memory of fp32, and also improves training speed. Generally, if your hardware supports training with bfloat16, we recommend using it - this is the default setting for our recipes.

Note

We recommend diving into Sebastian Raschka’s blogpost on mixed-precision techniques for a deeper understanding of concepts around precision and data formats.

Note

Another common paradigm is “mixed-precision” training: where model weights are in bfloat16 (or fp16), and optimizer states are in fp32. Currently, we don’t support mixed-precision training in torchtune.
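To get a feel for what the 4-byte versus 2-byte difference means in practice, here is a rough back-of-the-envelope sketch. It only counts the model weights (not gradients, optimizer states or activations), and the 8-billion-parameter model size is a hypothetical figure chosen for illustration.

```python
# Bytes used per parameter for the two dtypes described above.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Estimate the memory taken by the model weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


# Hypothetical model size: 8 billion parameters.
num_params = 8_000_000_000
fp32_gib = weight_memory_gib(num_params, "fp32")
bf16_gib = weight_memory_gib(num_params, "bf16")

print(f"fp32: {fp32_gib:.1f} GiB, bf16: {bf16_gib:.1f} GiB")
# prints "fp32: 29.8 GiB, bf16: 14.9 GiB"
```

The same halving applies to anything else stored in the chosen dtype, which is why bfloat16 is usually the first setting to check.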

Sounds great! How do I use it?

Simply use the dtype flag or config entry in all our recipes! For example, to use half-precision training in bf16, set dtype=bf16.

Activation Checkpointing

What’s going on here?

The relevant section in the PyTorch documentation explains this concept well. To quote:

Activation checkpointing is a technique that trades compute for memory. Instead of keeping tensors needed for backward alive until they are used in gradient computation during backward, forward computation in checkpointed regions omits saving tensors for backward and recomputes them during the backward pass.

This setting is helpful for when you’re memory-constrained, especially due to larger batch sizes or longer context lengths. However, these savings in memory come at the cost of training speed (i.e. tokens-per-second), and in most cases training can slow down quite a bit as a result of this activation recomputation.

Sounds great! How do I use it?

To enable activation checkpointing, use enable_activation_checkpointing=True.

Activation Offloading

What’s going on here?

You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory efficiency technique that allows saving GPU VRAM by temporarily moving activations to CPU and bringing them back when needed in the backward pass.

See PyTorch autograd hook tutorial for more details about how this is implemented through torch.autograd.graph.saved_tensors_hooks().

This setting is especially helpful for larger batch sizes, or longer context lengths when you’re memory constrained. While of course it takes runtime and resources to move Tensors from GPU to CPU and back, the implementation in torchtune uses multiple CUDA streams (when available) in order to overlap the extra communication with the computation to hide the extra runtime. As the communication workload is variable depending on the number and size of tensors being offloaded, we do not recommend using it unless Activation Checkpointing is also enabled, in which case only the checkpointed tensors will be offloaded.

Sounds great! How do I use it?

To enable activation offloading, use the enable_activation_offloading config entry or flag in our lora finetuning single device recipe, e.g. enable_activation_offloading=True. To allow usage of streams, make sure you are on a torch version equal to or later than PyTorch.

Gradient Accumulation

What’s going on here?

Gradient accumulation allows you to simulate large batch sizes by accumulating gradients over several batches before updating model parameters using the optimizer. Concretely, the total number of samples used for a gradient update when using gradient accumulation is:

total_batch_size = batch_size * gradient_accumulation_steps

For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32.

Note

For other components in torchtune which use “steps”, such as metric logging, or learning rate schedulers, a “step” is counted as a single update to model parameters, rather than a single model forward pass with the data. Suppose gradient_accumulation_steps = 4 and log_every_n_steps = 10. Metrics would be logged every 10 global steps, which translates to every 40 model forward passes. For this reason, metric logging will appear less frequently when training with gradient accumulation, and progress bars may update more slowly.

If you’re using one of our distributed recipes, simply multiply by the number of devices:

total_batch_size = batch_size * gradient_accumulation_steps * num_devices

Gradient accumulation is especially useful when you can fit at least one sample on your GPU, but not as many as you would like in a single batch. In this case, artificially increasing the batch size by accumulating gradients might give you faster training speeds than using other memory optimization techniques that trade speed for memory, like activation checkpointing.
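The batch-size arithmetic in this section, including the way “steps” are counted for logging, can be checked with a few lines of Python. The helper names below are hypothetical and are not part of torchtune; the 8-device figure is also just an illustration.

```python
def total_batch_size(batch_size: int,
                     gradient_accumulation_steps: int,
                     num_devices: int = 1) -> int:
    """Number of samples contributing to one optimizer update."""
    return batch_size * gradient_accumulation_steps * num_devices


def forward_passes_between_logs(log_every_n_steps: int,
                                gradient_accumulation_steps: int) -> int:
    """A logging 'step' is one parameter update, i.e. several forward passes."""
    return log_every_n_steps * gradient_accumulation_steps


# Single device: batch_size=1, gradient_accumulation_steps=32.
single_device = total_batch_size(1, 32)

# Hypothetical distributed run on 8 devices with the same settings.
eight_devices = total_batch_size(1, 32, num_devices=8)

# gradient_accumulation_steps=4 and log_every_n_steps=10.
passes_per_log = forward_passes_between_logs(10, 4)

print(single_device, eight_devices, passes_per_log)
# prints "32 256 40"
```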

Sounds great! How do I use it?

All of our finetuning recipes support simulating larger batch sizes by accumulating gradients. Just set the gradient_accumulation_steps flag or config entry.

Note

Gradient accumulation should always be set to 1 when fusing the optimizer step into the backward pass.

Optimizers

Lower Precision Optimizers

What’s going on here?

In addition to reducing model and optimizer precision during training, we can further reduce precision in our optimizer states. All of our recipes support lower-precision optimizers from the torchao library. For single device recipes, we also support bitsandbytes.

A good place to start might be the torchao.prototype.low_bit_optim.AdamW8bit and bitsandbytes.optim.PagedAdamW8bit optimizers. Both reduce memory by quantizing the optimizer state dict. Paged optimizers will also offload to CPU if there isn’t enough GPU memory available. In practice, you can expect higher memory savings from bnb’s PagedAdamW8bit but higher training speed from torchao’s AdamW8bit.
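As a rough illustration of why quantizing the optimizer state helps, the sketch below estimates the size of AdamW's state for a hypothetical 8-billion-parameter model trained in full. It relies on the fact that Adam-style optimizers keep two running statistics per trainable parameter, and it assumes an 8-bit optimizer stores about 1 byte per state value, ignoring the small extra bookkeeping that quantization needs. Treat the numbers as estimates, not measurements.

```python
# Adam/AdamW track two statistics for every trainable parameter:
# a running mean of gradients and a running mean of squared gradients.
ADAM_STATES_PER_PARAM = 2


def optimizer_state_gib(num_params: int, bytes_per_state: float) -> float:
    """Estimate the optimizer state size in GiB."""
    return num_params * ADAM_STATES_PER_PARAM * bytes_per_state / 1024**3


# Hypothetical model: 8 billion trainable parameters.
num_params = 8_000_000_000

fp32_state = optimizer_state_gib(num_params, 4)  # states kept in fp32
bf16_state = optimizer_state_gib(num_params, 2)  # states kept in bf16
int8_state = optimizer_state_gib(num_params, 1)  # assumption: ~1 byte per value

print(f"fp32: {fp32_state:.1f} GiB, bf16: {bf16_state:.1f} GiB, "
      f"8-bit: {int8_state:.1f} GiB")
# prints "fp32: 59.6 GiB, bf16: 29.8 GiB, 8-bit: 14.9 GiB"
```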

Sounds great! How do I use it?

To use this in your recipes, make sure you have installed torchao (pip install torchao) or bitsandbytes (pip install bitsandbytes). Then, enable a low precision optimizer using the torchtune CLI:

tune run <RECIPE> --config <CONFIG> \
optimizer=torchao.prototype.low_bit_optim.AdamW8bit
tune run <RECIPE> --config <CONFIG> \
optimizer=bitsandbytes.optim.PagedAdamW8bit

or by directly modifying a config file:

optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5

Fusing Optimizer Step into Backward Pass

What’s going on here?

Stateful optimizers (e.g. optimizers which use momentum) are the default in modern deep learning due to their stable convergence properties. However, maintaining a state of gradient statistics comes at the cost of additional memory usage. An immediate alternative might be to turn to stateless optimizers such as stochastic gradient descent without momentum, which don’t require any additional memory usage, but will likely result in worse convergence during training.

Can we find a middle ground here? Let’s consider a technique which enables the use of “stateful” optimizers such as AdamW without the memory overhead of gradient statistics, and without sacrificing their desirable convergence properties. How is this possible, you might ask? By completely removing the buffer of gradients which are stored by the optimizer during its step().

To understand how this works, we encourage you to read through the relevant PyTorch tutorial on this concept: How to save memory by fusing the optimizer step into the backward pass.

Sounds great! How do I use it?

In torchtune, you can enable this feature using the optimizer_in_bwd flag. This feature works best when using a stateful optimizer with a model with a lot of parameters, and when you don’t need to use gradient accumulation. You won’t see meaningful impact when finetuning LoRA recipes, since in this case the number of parameters being updated is small.

Offloading Optimizer/Gradient states to CPU

What’s going on here?

We’ve mentioned above the concept of optimizer states - memory used by the stateful optimizers to maintain a state of gradient statistics, and model gradients - tensors used to store gradients when we perform model backwards passes. We support using CPU offloading in our single-device recipes through the CPUOffloadOptimizer from torchao.

This optimizer can wrap any base optimizer and works by keeping the optimizer states and performing the optimizer step on CPU, thus reducing GPU memory usage by the size of the optimizer states. Additionally, we can also offload gradients to the CPU by using offload_gradients=True.

If finetuning on a single device, another option is to use the PagedAdamW8bit from bitsandbytes, mentioned above, which will only offload to CPU when there is not enough GPU memory available.

Sounds great! How do I use it?

To use this optimizer in your recipes, set the optimizer key in your config to torchao.prototype.low_bit_optim.CPUOffloadOptimizer, which will use the torch.optim.AdamW optimizer with fused=True as the base optimizer. For example, to use this optimizer to offload both optimizer states and gradients to CPU:

tune run <RECIPE> --config <CONFIG> \
optimizer=torchao.prototype.low_bit_optim.CPUOffloadOptimizer \
optimizer.offload_gradients=True \
optimizer.lr=4e-5

or by directly modifying a config file:

optimizer:
  _component_: torchao.prototype.low_bit_optim.CPUOffloadOptimizer
  offload_gradients: True
  # additional key-word arguments can be passed to torch.optim.AdamW
  lr: 4e-5

or using it directly in your code, which allows you to change the base optimizer:

from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
from torch.optim import Adam

optimizer = CPUOffloadOptimizer(
    model.parameters(), # your model here
    Adam,
    lr=1e-5,
    fused=True
)

Some helpful hints from the torchao CPUOffloadOptimizer page:

  • The CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) use full bf16 training so that parameters, gradients, and optimizer states are in bf16; and (2) give GPU more work per optimizer step to amortize the offloading time (e.g. larger batch size with activation checkpointing, gradient accumulation).

  • Gradient accumulation should always be set to 1 when offload_gradients=True, as gradients are cleared on GPU every backward pass.

  • This optimizer works by keeping a copy of parameters and pre-allocating gradient memory on CPU. Therefore, expect your RAM usage to increase by 4x model size.

  • This optimizer is only supported for single-device recipes. To use CPU-offloading in distributed recipes, use fsdp_cpu_offload=True instead. See torch.distributed.fsdp.FullyShardedDataParallel for more details and FSDP1 vs FSDP2 to see how they differ.

Parameter Efficient Fine-Tuning (PEFT)

Low Rank Adaptation (LoRA)

What’s going on here?

You can read our tutorial on finetuning Llama2 with LoRA to understand how LoRA works, and how to use it. Simply stated, LoRA greatly reduces the number of trainable parameters, thus saving significant gradient and optimizer memory during training.

Sounds great! How do I use it?

You can finetune using any of our recipes with the lora_ prefix, e.g. lora_finetune_single_device. These recipes utilize LoRA-enabled model builders, which we support for all our models, and also use the lora_ prefix, e.g. the torchtune.models.llama3.llama3() model has a corresponding torchtune.models.llama3.lora_llama3(). We aim to provide a comprehensive set of configurations so you can get started with LoRA training quickly. Just specify any config with _lora in its name, e.g.:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device

There are two sets of parameters to customize LoRA to suit your needs. Firstly, the parameters which control which linear layers LoRA should be applied to in the model:

  • lora_attn_modules: List[str] accepts a list of strings specifying which layers of the model to apply LoRA to:

    • q_proj applies LoRA to the query projection layer.

    • k_proj applies LoRA to the key projection layer.

    • v_proj applies LoRA to the value projection layer.

    • output_proj applies LoRA to the attention output projection layer.

    Whilst adding more layers to be fine-tuned may improve model accuracy, this will come at the cost of increased memory usage and reduced training speed.

  • apply_lora_to_mlp: Bool applies LoRA to the MLP in each transformer layer.

  • apply_lora_to_output: Bool applies LoRA to the model’s final output projection. This is usually a projection to vocabulary space (e.g. in language models), but other modelling tasks may have different projections - classifier models will project to the number of classes, for example.

Note

Models which use tied embeddings (such as Gemma and Qwen2 1.5B and 0.5B) for the final output projection do not support apply_lora_to_output.

These are all specified under the model flag or config entry, i.e.:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device  \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj","output_proj"]

or, by modifying a config:

model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj","output_proj"]

Secondly, parameters which control the scale of the impact of LoRA on the model:

  • lora_rank: int affects the scale of the LoRA decomposition, where lora_rank << in_dim and lora_rank << out_dim - the dimensions of an arbitrary linear layer in the model. Concretely, lora_rank reduces the number of gradients stored in a linear fashion from in_dim * out_dim to lora_rank * (in_dim + out_dim). Typically, we have lora_rank in [8, 256].

  • lora_alpha: float affects the magnitude of the LoRA updates. A larger alpha results in larger updates to the base model weights, potentially at the cost of training stability. Conversely, a smaller alpha can stabilize training at the cost of slower learning. We provide default settings for these parameters which we’ve tested with all of our models, but we encourage you to adjust them to your specific use case. Typically, one changes lora_rank and lora_alpha together, where lora_alpha ~= 2*lora_rank.

  • lora_dropout introduces dropout in the LoRA layers to help regularize training. We default to 0.0 for all of our models.
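The effect of lora_rank on a single linear layer can be worked through numerically. The sketch below uses a hypothetical 4096 x 4096 projection and the rank and alpha values from the examples in this section. The alpha / rank scaling factor is the usual LoRA convention and is assumed here for illustration; check the LoRA layer implementation if you need the exact behaviour.

```python
def full_weight_params(in_dim: int, out_dim: int) -> int:
    """Parameters (and gradients) in a fully trained linear layer."""
    return in_dim * out_dim


def lora_params(in_dim: int, out_dim: int, lora_rank: int) -> int:
    """Trainable parameters in the two low-rank LoRA matrices."""
    return lora_rank * (in_dim + out_dim)


def lora_scale(lora_alpha: float, lora_rank: int) -> float:
    """Assumed convention: LoRA updates are scaled by alpha / rank."""
    return lora_alpha / lora_rank


# Hypothetical layer size for illustration.
in_dim, out_dim = 4096, 4096
lora_rank, lora_alpha = 32, 64

full = full_weight_params(in_dim, out_dim)
low_rank = lora_params(in_dim, out_dim, lora_rank)
scale = lora_scale(lora_alpha, lora_rank)

print(f"full: {full:,}  lora: {low_rank:,}  reduction: {full // low_rank}x")
# prints "full: 16,777,216  lora: 262,144  reduction: 64x"
print(f"scale: {scale}")
# prints "scale: 2.0"
```

Under that convention, keeping lora_alpha ~= 2*lora_rank holds the scaling factor at about 2 as you change the rank.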

As above, these parameters are also specified under the model flag or config entry:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device  \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj","output_proj"] \
model.lora_rank=32 \
model.lora_alpha=64

or, by modifying a config:

model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj","output_proj"]
  lora_rank: 32
  lora_alpha: 64

Note

To get a deeper sense of how LoRA parameters affect memory usage during training, see the relevant section in our Llama2 LoRA tutorial.

Quantized Low Rank Adaptation (QLoRA)

What’s going on here?

QLoRA is a memory enhancement on top of LoRA that maintains the frozen model parameters from LoRA in 4-bit quantized precision, thereby reducing memory usage. This is enabled through a novel 4-bit NormalFloat (NF4) data type proposed by the authors, which allows for 4-8x less parameter memory usage whilst retaining model accuracy. You can read our tutorial on finetuning Llama2 with QLoRA for a deeper understanding of how it works.

When considering using QLoRA to reduce memory usage, it’s worth noting that QLoRA is slower than LoRA and may not be worth it if the model you are finetuning is small. In numbers, QLoRA saves roughly 1.5 bytes * (# of model parameters). Also, although QLoRA quantizes the model, it minimizes accuracy degradation by up-casting quantized parameters to the original higher precision datatype during model forward passes - this up-casting may incur penalties to training speed. The relevant section in our QLoRA tutorial demonstrates the usage of torch.compile to address this by speeding up training.
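The 1.5 bytes per parameter figure follows from moving the frozen base weights from bf16 (2 bytes) to a 4-bit format (0.5 bytes). Here is a small worked example for a hypothetical 8-billion-parameter model. It ignores any extra metadata the quantized format stores, so read it as an approximation.

```python
BF16_BYTES_PER_PARAM = 2.0
NF4_BYTES_PER_PARAM = 0.5  # 4 bits per parameter


def qlora_savings_gib(num_params: int) -> float:
    """Approximate memory saved by storing frozen weights in 4-bit, in GiB."""
    saved_per_param = BF16_BYTES_PER_PARAM - NF4_BYTES_PER_PARAM
    return num_params * saved_per_param / 1024**3


# Hypothetical model size: 8 billion parameters.
savings = qlora_savings_gib(8_000_000_000)

print(f"bytes saved per parameter: {BF16_BYTES_PER_PARAM - NF4_BYTES_PER_PARAM}")
# prints "bytes saved per parameter: 1.5"
print(f"approximate savings: {savings:.1f} GiB")
# prints "approximate savings: 11.2 GiB"
```

For a model with only a few hundred million parameters the same calculation gives well under 1 GiB, which is why QLoRA may not be worth its slowdown on small models.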

Sounds great! How do I use it?

You can finetune using QLoRA with any of our LoRA recipes, i.e. recipes with the lora_ prefix, e.g. lora_finetune_single_device. These recipes utilize QLoRA-enabled model builders, which we support for all our models, and also use the qlora_ prefix, e.g. the torchtune.models.llama3.llama3_8b() model has a corresponding torchtune.models.llama3.qlora_llama3_8b(). We aim to provide a comprehensive set of configurations so you can get started with QLoRA training quickly. Just specify any config with _qlora in its name.

All the rest of the LoRA parameters remain the same for QLoRA - check out the section above on LoRA to see how to configure these parameters.

To configure from the command line:

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
model.lora_rank=32 \
model.lora_alpha=64

or, by modifying a config:

model:
  _component_: torchtune.models.llama3.qlora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
  lora_rank: 32
  lora_alpha: 64

Weight-Decomposed Low-Rank Adaptation (DoRA)

What’s going on here?

DoRA is another PEFT technique which builds on top of LoRA by further decomposing the pre-trained weights into two components: magnitude and direction. The magnitude component is a vector of scalars that adjusts the scale, while the direction component corresponds to the original LoRA decomposition and updates the orientation of weights.

DoRA adds a small overhead to LoRA training due to the addition of the magnitude parameter, but it has been shown to improve the performance of LoRA, particularly at low ranks.
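To make the magnitude/direction split concrete, here is a minimal pure-Python sketch that decomposes a tiny weight matrix and rebuilds it. It is not the DoRALinear implementation: taking one norm per row (one per output unit) is an assumption made for illustration, and the function names are hypothetical.

```python
import math


def decompose(weight):
    """Split each row into a magnitude (its L2 norm) and a unit-length direction."""
    magnitudes = [math.sqrt(sum(x * x for x in row)) for row in weight]
    directions = [[x / m for x in row] for row, m in zip(weight, magnitudes)]
    return magnitudes, directions


def recompose(magnitudes, directions):
    """Rebuild the weight matrix by rescaling each direction by its magnitude."""
    return [[m * x for x in row] for row, m in zip(directions, magnitudes)]


# A tiny 2 x 2 weight matrix (out_dim = 2, in_dim = 2).
weight = [[3.0, 4.0], [0.0, 2.0]]

magnitudes, directions = decompose(weight)
rebuilt = recompose(magnitudes, directions)

print(magnitudes)  # one scalar per row
print(directions)  # each row has unit length
```

In this sketch the magnitude adds only one extra value per row on top of the LoRA matrices, which is consistent with the small overhead described above.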

Sounds great! How do I use it?

Much like LoRA and QLoRA, you can finetune using DoRA with any of our LoRA recipes. We use the same model builders for LoRA as we do for DoRA, so you can use the lora_ version of any model builder with use_dora=True. For example, to finetune torchtune.models.llama3.llama3_8b() with DoRA, you would use torchtune.models.llama3.lora_llama3_8b() with use_dora=True:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.use_dora=True

or, by modifying a config:

model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  use_dora: True

Since DoRA extends LoRA, the parameters for customizing LoRA are identical. You can also quantize the base model weights like in Quantized Low Rank Adaptation (QLoRA) by using quantize_base=True to reap even more memory savings!

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
model.lora_rank=16 \
model.lora_alpha=32 \
model.use_dora=True \
model.quantize_base=True

or, by modifying a config:

model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  apply_lora_to_mlp: True
  lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
  lora_rank: 16
  lora_alpha: 32
  use_dora: True
  quantize_base: True

Note

Under the hood, we’ve enabled DoRA by adding the DoRALinear module, which replaces LoRALinear when use_dora=True.
