How to save memory by fusing the optimizer step into the backward pass

Hello there! This tutorial aims to showcase one way of reducing the memory footprint of a training loop by reducing the memory taken by the gradients. Say you have a model and you’re interested in ways to optimize memory to avoid Out of Memory (OOM) errors or simply to squeeze more out of your GPU. Well, you might be in luck (if gradients take up a portion of your memory and you do not need to do gradient accumulation). We will explore the following:

  1. What takes up memory during your training or finetuning loop,

  2. How to capture and visualize memory snapshots to determine the bottleneck,

  3. The new Tensor.register_post_accumulate_grad_hook(hook) API, and finally,

  4. How everything fits together in 10 lines to achieve memory savings.

To run this tutorial, you will need:

  • PyTorch 2.1.0 or newer with torchvision

  • 1 CUDA GPU if you’d like to run the memory visualizations locally. Otherwise, the technique itself provides similar benefits on any device.

Let us start by importing the required modules and models. We will use a vision transformer model from torchvision, but feel free to substitute with your own model. We will also use torch.optim.Adam as our optimizer, but, again, feel free to substitute with your own optimizer.

import torch
from torchvision import models
from pickle import dump

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
Downloading: "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vit_l_16-852ce7e3.pth

100%|##########| 1.13G/1.13G [00:21<00:00, 57.9MB/s]

Now let’s define our typical training loop. You should use real images when training, but for the purposes of this tutorial, we are passing in fake inputs and not worrying about loading any actual data.

IMAGE_SIZE = 224

def train(model, optimizer):
  # create our fake image input: tensor shape is batch_size, channels, height, width
  fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

  # call our forward and backward
  loss = model.forward(fake_image)
  loss.sum().backward()

  # optimizer update
  optimizer.step()
  optimizer.zero_grad()

Memory usage during training

We are about to look at some memory snapshots, so we should be prepared to analyze them properly. Typically, training memory consists of the following (a rough sizing sketch follows the list):

  • Model parameters (size P)

  • Activations that are saved for the backward pass (size A)

  • Gradients, which are the same size as the model parameters, so size G = P.

  • Optimizer state, which is proportional to the size of the parameters. In this case, the state for Adam requires 2x the model parameters, so size O = 2P.

  • Intermediate tensors, which are allocated throughout the compute. We will not worry about them for now as they are usually small and ephemeral.
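
To make these sizes concrete for the model above, here is a rough back-of-the-envelope estimate. This is an illustrative sketch, not part of the original tutorial code; it assumes the model defined earlier and the fact that Adam keeps two state buffers per parameter. Activations (A) depend on batch size and architecture, so we skip them here.

# Rough sizing sketch: estimate P, G, and O for the ViT-L/16 + Adam setup above.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
P = param_bytes / 1e9        # parameters, in GB
G = P                        # gradients match the parameters in size
O = 2 * P                    # Adam keeps exp_avg and exp_avg_sq per parameter
print(f"P ~ {P:.2f} GB, G ~ {G:.2f} GB, O ~ {O:.2f} GB")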

Capturing and visualizing memory snapshots

Let’s get us a memory snapshot! As your code runs, consider what you may expect the CUDA memory timeline to look like.

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps
for _ in range(3):
  train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)

Now open up the snapshot in the CUDA Memory Visualizer at https://pytorch.org/memory_viz by dragging and dropping the snapshot.pickle file. Does the memory timeline match your expectations?

[Figure: snapshot.pickle loaded into the CUDA Memory Visualizer]

The model parameters have already been loaded in memory before the training step, so we see a chunk of memory devoted to the weights right off the bat. As we start our forward pass, memory is allocated gradually for the activations, or the tensors we are saving to be able to compute gradients in the backward pass. Once we start the backward pass, the activations are gradually freed while memory of the gradients starts building up.

Lastly, as the optimizer kicks in, its state will be lazily initialized, so we should see the optimizer state memory gradually increase during the optimizer step of the first training loop only. In future loops, the optimizer memory will remain and be updated in-place. The memory for the gradients is then freed accordingly at the end of every training loop when zero_grad is called.
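
To see the lazy-initialization claim in isolation, here is a tiny standalone sketch (illustrative only, using a throwaway parameter rather than the model above):

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p])
print(len(opt.state))        # 0: no exp_avg/exp_avg_sq buffers exist at construction
p.grad = torch.zeros_like(p)
opt.step()
print(len(opt.state))        # 1: state was allocated lazily on the first step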

Where is the memory bottleneck in this training loop? Or, in other words, where is the peak memory?

The peak memory usage is during the optimizer step! Note that the memory then consists of ~1.2GB of parameters, ~1.2GB of gradients, and ~2.4GB (2 * 1.2GB) of optimizer state, as expected. The last ~1.2GB comes from the Adam optimizer requiring memory for intermediates, totaling ~6GB of peak memory. Technically, you can remove the need for that last 1.2GB of optimizer intermediates by setting Adam(model.parameters(), foreach=False), which trades runtime for memory. If switching off the foreach runtime optimization saves enough memory for you, great, but please read on if you’re curious how this tutorial can help you do better! With the technique we will soon introduce, we will reduce peak memory by removing the need for the ~1.2GB of gradients memory as well as the optimizer intermediates memory. Now, what would you expect the new peak memory to be? The answer will be revealed in the next snapshot.
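
By the way, if you would rather confirm the ~6GB peak programmatically instead of reading it off the timeline, the CUDA caching allocator’s statistics report the same figure. Here is an optional sketch; it re-runs one training step on the setup above:

torch.cuda.reset_peak_memory_stats()
train(model, optimizer)
torch.cuda.synchronize()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")  # roughly ~6 GB for this setup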

DISCLAIMER: This technique is not for everyone

Before we get too excited, we have to consider whether this technique is applicable for your use case. This is NOT a silver bullet! The technique of fusing the optimizer step into the backward pass only targets reducing gradient memory (and, as a side effect, optimizer intermediates memory). Thus, the larger the share of memory taken up by the gradients, the greater the potential savings. In our example above, the gradients eat up 20% of the memory pie, which is quite sizable!

This may not be the case for you: for example, if your weights are already tiny (say, due to applying LoRA), then the gradients do not take up much space in your training loop and the wins are far less exciting. In that case, you should first try other techniques like activation checkpointing, distributed training, quantization, or reducing the batch size. Then, when gradients become part of the bottleneck again, come back to this tutorial!

Still here? Cool, let’s introduce our new register_post_accumulate_grad_hook(hook) API on Tensor.

Tensor.register_post_accumulate_grad_hook(hook) API and our technique

Our technique relies on not having to save the gradients during backward(). Instead, once a gradient has been accumulated, we will immediately apply the optimizer to the corresponding parameter and drop that gradient entirely! This removes the need for holding onto a big buffer of gradients until the optimizer step.

So how can we unlock the behavior of applying the optimizer more eagerly? In our 2.1 release, we’ve added a new API, torch.Tensor.register_post_accumulate_grad_hook(), which allows us to add a hook onto a Tensor that fires once its .grad field has been accumulated. We will encapsulate the optimizer step into this hook. How?
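
Before wiring it into the full training loop, here is a minimal standalone sketch of the API on a throwaway tensor (illustrative only), just to show when the hook fires:

t = torch.randn(3, requires_grad=True)

def print_hook(param) -> None:
  # called right after .grad has been accumulated for this tensor
  print("grad accumulated:", param.grad)

t.register_post_accumulate_grad_hook(print_hook)
(t * 2).sum().backward()   # prints the accumulated gradient of t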

How everything fits together in 10 lines

Remember our model and optimizer setup from the beginning? I’ll leave them commented out below so we don’t spend resources rerunning the code.

# model = models.vit_l_16(weights='DEFAULT').cuda()
# optimizer = torch.optim.Adam(model.parameters())
# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers,
# one for every parameter, so we can reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
  optimizer_dict[parameter].step()
  optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
  p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
  # create our fake image input: tensor shape is batch_size, channels, height, width
  fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

  # call our forward and backward
  loss = model.forward(fake_image)
  loss.sum().backward()

  # optimizer update --> no longer needed!
  # optimizer.step()
  # optimizer.zero_grad()

That took about 10 lines of changes in our sample model, which is neat. However, for real models, it could be a fairly intrusive change to switch out the optimizer for an optimizer dictionary, especially for those who use ``LRScheduler``s or manipulate optimizer configuration throughout the training epochs. Working out this API with those features will be more involved and will likely require moving more configuration into global state, but it should not be impossible. That said, a next step for PyTorch is to make this API easier to adopt with LRSchedulers and other features you are already used to.
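
As a rough illustration of what that might look like, here is an untested sketch under the assumption that you want one scheduler per per-parameter optimizer; the StepLR arguments are placeholders, not a recommendation:

# Hypothetical sketch: pair each per-parameter optimizer with its own scheduler.
scheduler_dict = {
  p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=10, gamma=0.1)
  for p in model.parameters()
}

def step_schedulers() -> None:
  # call once per epoch from your training loop
  for scheduler in scheduler_dict.values():
    scheduler.step()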

But let me get back to convincing you that this technique is worth it. We will consult our friend, the memory snapshot.

# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
  train(model)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot-opt-in-bwd.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)

Yes, take some time to drag your snapshot into the CUDA Memory Visualizer.

[Figure: snapshot-opt-in-bwd.pickle loaded into the CUDA Memory Visualizer]

Several major observations:
  1. There is no more optimizer step! Right…we fused that into the backward.

  2. Likewise, the backward pass now takes longer and there are more scattered allocations for intermediates. This is expected, as the optimizer step requires intermediates.

  3. Most importantly! The peak memory is lower! It is now ~4GB (which I hope maps closely to your earlier expectation).

Note that there is no longer any big chunk of memory allocated for the gradients compared to before, accounting for ~1.2GB of memory savings. Instead, each gradient is freed very quickly after it has been computed by moving the optimizer step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings comes from breaking the optimizer apart into per-parameter optimizers, so the intermediates have shrunk proportionally. This detail is less important than the gradient memory savings, as you could get the optimizer-intermediates savings by simply setting foreach=False, without this technique.

You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start of the backward pass, when we still have activations in memory, whereas before, the peak was during the optimizer step, when the activations had already been freed. The ~0.4GB difference between ~4.0GB and ~3.6GB is thus due to the activations memory. One can then imagine coupling this technique with activation checkpointing for even more memory wins.
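
For completeness, here is a tiny, self-contained sketch of activation checkpointing, which recomputes activations during the backward pass instead of storing them. The two-stage toy model below is hypothetical and purely for illustration; it is not part of this tutorial’s setup:

from torch.utils.checkpoint import checkpoint

# hypothetical toy model, for illustration only
stage1 = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU()).cuda()
stage2 = torch.nn.Linear(128, 10).cuda()

x = torch.randn(16, 128, device="cuda")
h = checkpoint(stage1, x, use_reentrant=False)  # stage1 activations are recomputed in backward
loss = stage2(h).sum()
loss.backward()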

Conclusion

In this tutorial, we learned about the memory-saving technique of fusing the optimizer into the backward pass through the new Tensor.register_post_accumulate_grad_hook() API, and when to apply it (when gradient memory is a significant part of the total). Along the way, we also learned about memory snapshots, which are generally useful in memory optimization.
