• Docs >
  • Memory Optimization Overview
Shortcuts

Memory Optimization Overview

Author: Salman Mohammadi

torchtune comes with a host of plug-and-play memory optimization components which give you lots of flexibility to tune our recipes to your hardware. This page provides a brief glossary of these components and how you might use them. To make things easy, we’ve summarized these components in the following table:

Memory optimization components

Component

When to use?

Model Precision

You’ll usually want to leave this as its default bfloat16. If you’re struggling with training stability or accuracy due to precision, fp32 may help, but will significantly increase memory usage and decrease training speed.

Activation Checkpointing

Use when you’re memory constrained and need to handle larger batch sizes or longer context lengths. Be aware that it may slow down training speed.

Gradient Accumulation

Helpful when memory-constrained to simulate larger batch sizes. Often preferable to activation checkpointing for better training speed.

Lower Precision Optimizers

When you need to further reduce memory usage beyond using bf16 by reducing the precision in the optimizer states. Note that lower precision optimizers may reduce training stability/accuracy.

Fusing Optimizer Step into Backward Pass

Helps reduce memory usage when using stateful optimizers, particularly when full-finetuning large models with high gradient memory usage. This is not compatible with gradient_accumulation_steps, so training may slow down due to reduced model throughput.

Low Rank Adaptation (LoRA)

When you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training.

Quantized Low Rank Adaptation (QLoRA)

When you need even more memory savings than LoRA, at the potential cost of some training speed. Useful for very large models or limited hardware.

Note

In its current state, this tutorial is focused on single-device optimizations. Check in soon as we update this page for the latest memory optimization features for distributed fine-tuning.

Model Precision

What’s going on here?

We use the term “precision” to refer to the underlying data type used to represent the model and optimizer parameters. We support two data types in torchtune:

Note

We recommend diving into Sebastian Raschka’s blogpost on mixed-precision techniques for a deeper understanding of concepts around precision and data formats.

  • fp32, commonly referred to as “full-precision”, uses 4 bytes per model and optimizer parameter.

  • bfloat16, referred to as “half-precision”, uses 2 bytes per model and optimizer parameter - effectively half the memory of fp32, and also improves training speed. Generally, if your hardware supports training with bfloat16, we recommend using it - this is the default setting for our recipes.

Note

Another common paradigm is “mixed-precision” training: where model weights are in bfloat16 (or fp16), and optimizer states are in fp32. Currently, we don’t support mixed-precision training in torchtune.

Sounds great! How do I use it?

Simply use the dtype flag or config entry in all our recipes! For example, to use half-precision training in bf16, set dtype=bf16.

Activation Checkpointing

What’s going on here?

The relevant section in the PyTorch documentation explains this concept well. To quote:

Activation checkpointing is a technique that trades compute for memory. Instead of keeping tensors needed for backward alive until they are used in gradient computation during backward, forward computation in checkpointed regions omits saving tensors for backward and recomputes them during the backward pass.

This setting is helpful for when you’re memory-constrained, especially due to larger batch sizes or longer context lengths. However, these savings in memory come at the cost of training speed (i.e. tokens-per-second), and in most cases training can slow down quite a bit as a result of this activation recomputation.

Sounds great! How do I use it?

To enable activation checkpointing, use the enable_activation_checkpointing config entry or flag in any of our recipes, e.g. enable_activation_checkpointing=True.

Gradient Accumulation

What’s going on here?

Gradient accumulation allows you to simulate large batch sizes by accumulating gradients over several batches before updating model parameters using the optimizer. Concretely, the total number of samples used for a gradient update is when using gradient accumulation is:

total_batch_size = batch_size * gradient_accumulation_steps

For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32.

Note

For other components in torchtune which use “steps”, such as metric logging, or learning rate schedulers, a “step” is counted as a single update to model parameters, rather than a single model forward pass with the data. Suppose gradient_accumulation_steps = 4 and log_every_n_steps = 10. Metrics would be logged every 10 global steps, which translates to every 40 model forward passes. For this reason, metric logging will appear less frequently when training with gradient accumulation, and progress bars may update more slowly.

If you’re using one of our distributed recipes, simply multiply by the number of devices:

total_batch_size = batch_size * gradient_accumulation_steps * num_devices

Gradient accumulation is especially useful when you are memory constrained. In this case, accumulating gradients might give you better training speed than enabling activation checkpointing, since activation checkpointing reduces memory consumption at the cost of repeated computations.

Sounds great! How do I use it?

All of our finetuning recipes support simulating larger batch sizes by accumulating gradients. Just set the gradient_accumulation_steps flag or config entry.

Note

Gradient accumulation should always be set to 1 when fusing the optimizer step into the backward pass.

Lower Precision Optimizers

What’s going on here?

In addition to reducing model and optimizer precision during training, we can further reduce precision in our optimizer states. All of our single-device fine-tuning recipes support lower-precision optimizers from the bitsandbytes library - a good place to start might be the AdamW8bit and PagedAdamW8bit optimizers, which we’ve tested our recipes with.

Sounds great! How do I use it?

To use this in your recipes, make sure you have installed bitsandbytes (pip install bitsandbytes). Then, enable a low precision optimizer using the torchtune CLI:

tune run <RECIPE> --config <CONFIG> \
optimizer=bitsandbytes.optim.PagedAdamW

or by directly modifying a config file:

optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 2e-5

Fusing Optimizer Step into Backward Pass

What’s going on here?

Stateful optimizers (e.g. optimizers which use momentum) are the default in modern deep learning due to their stable convergence properties. However, maintaining a state of gradient statistics comes at the cost of additional memory usage. An immediate alternative might be to turn to stateless optimizers such as stochastic gradient descent without momentum, which don’t require any additional memory usage, but will likely result in worse convergence during training.

Can we find a middle ground here? Let’s consider a technique which enables the use of “stateful” optimizers such as AdamW without the memory overhead of gradient statistics, and without sacrificing their desirable convergence properties. How is this possible, you might ask? By completely removing the buffer of gradients which are stored by the optimizer during its step().

To understand how this works, we encourage you to read through the relevant PyTorch tutorial on this concept: How to save memory by fusing the optimizer step into the backward pass.

Sounds great! How do I use it?

In torchtune, you can enable this feature using the optimizer_in_bwd flag, which is currently only supported in our single-device full finetune recipe. This feature works best when gradient memory is particularly large; e.g. when using a stateful optimizer with a model with a lot of parameters, and when you don’t need to use gradient accumulation.

Parameter Efficient Fine-Tuning (PEFT)

Low Rank Adaptation (LoRA)

What’s going on here?

You can read our tutorial on finetuning Llama2 with LoRA to understand how LoRA works, and how to use it. Simply stated, LoRA greatly reduces the number of trainable parameters, thus saving significant gradient and optimizer memory during training.

Sounds great! How do I use it?

You can finetune using any of our recipes with the lora_ prefix, e.g. lora_finetune_single_device. These recipes utilize LoRA-enabled model builders, which we support for all our models, and also use the lora_ prefix, e.g. the torchtune.models.llama3.llama3() model has a corresponding torchtune.models.llama3.lora_llama3(). We aim to provide a comprehensive set of configurations to allow you to get started with training with LoRA quickly, just specify any config with _lora in its name, e.g:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device

There are two sets of parameters to customize LoRA to suit your needs. Firstly, the parameters which control which linear layers LoRA should be applied to in the model:

  • lora_attn_modules: List[str] accepts a list of strings specifying which layers of the model to apply LoRA to:

    • q_proj applies LoRA to the query projection layer.

    • k_proj applies LoRA to the key projection layer.

    • v_proj applies LoRA to the value projection layer.

    • output_proj applies LoRA to the attention output projection layer.

    Whilst adding more layers to be fine-tuned may improve model accuracy, this will come at the cost of increased memory usage and reduced training speed.

  • apply_lora_to_mlp: Bool applies LoRA to the MLP in each transformer layer.

  • apply_lora_to_output: Bool applies LoRA to the model’s final output projection. This is usually a projection to vocabulary space (e.g. in language models), but other modelling tasks may have different projections - classifier models will project to the number of classes, for example

Note

Models which use tied embeddings (such as Gemma and Qwen2 1.5B and 0.5B) for the final output projection do not support apply_lora_to_output.

These are all specified under the model flag or config entry, i.e:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device  \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"]
model:
  apply_lora_to_mlp: True
  model.lora_attn_modules: ["q_proj", "k_proj", "v_proj"]

Secondly, parameters which control the scale of the impact of LoRA on the model:

  • lora_rank: int affects the scale of the LoRA decomposition, where lora_rank << in_dim and lora_rank << out_dim - the dimensions of an arbitrary linear layer in the model. Concretely, lora_rank reduces the number of gradients stored in a linear fashion from in_dim * out_dim to lora_rank * (in_dim + out_dim). Typically, we have lora_rank in [8, 128].

  • lora_alpha: float affects the magnitude of the LoRA updates. A larger alpha results in larger updates to the base model weights , potentially at the cost of training stability, conversely, smaller alpha can stabilize training at the cost of slower learning. We provide default settings for these parameters which we’ve tested with all of our models, but we encourage you to adjust them to your specific use case. Typically, one jointly changes lora_rank and lora_alpha together, where lora_alpha ~= 2*lora_rank.

  • lora_dropout introduces dropout in the LoRA layers to help regularize training. We default to 0.0 for all of our models.

As above, these parameters are also specified under the model flag or config entry.

Note

To get a deeper sense of how LoRA parameters affect memory usage during training, see the relevant section in our Llama2 LoRA tutorial.

Quantized Low Rank Adaptation (QLoRA)

What’s going on here?

QLoRA is an enhancement on top of LoRA that maintains the frozen model parameters from LoRA in 4-bit quantized precision, thereby reducing memory usage. This is enabled through a novel 4-bit NormalFloat (NF4) data type proposed by the authors, which allows for 4-8x less parameter memory usage whilst retaining model accuracy. You can read our tutorial on finetuning Llama2 with QLoRA for a deeper understanding of how it works.

When considering using QLoRA to reduce memory usage, it’s worth noting that QLoRA prevents accuracy degradation during quantization by up-casting quantized parameters to the original higher precision datatype during model forward passes - this up-casting may incur penalties to training speed. The relevant section in our QLoRA tutorial demonstrates the usage of torch.compile to address this by speeding up training.

Sounds great! How do I use it?

You can finetune using QLoRA with any of our LoRA recipes, i.e. recipes with the lora_ prefix, e.g. lora_finetune_single_device. These recipes utilize QLoRA-enabled model builders, which we support for all our models, and also use the qlora_ prefix, e.g. the torchtune.models.llama3.llama3_8b() model has a corresponding torchtune.models.llama3.qlora_llama3_8b(). We aim to provide a comprehensive set of configurations to allow you to get started with training with QLoRA quickly, just specify any config with _qlora in its name, e.g:

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device

All the rest of the LoRA parameters remain the same for QLoRA - check out the section above on LoRA to see how to configure.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources