.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/optimizer_step_in_backward_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_intermediate_optimizer_step_in_backward_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_optimizer_step_in_backward_tutorial.py:


How to save memory by fusing the optimizer step into the backward pass
========================================================================

Hello there! This tutorial aims to showcase one way of reducing the memory
footprint of a training loop by reducing the memory taken by the *gradients*.
Say you have a model and you're interested in ways to optimize memory to
avoid ``Out of Memory`` (OOM) errors or simply to squeeze more out of your
GPU. Well, you *might* be in luck (if gradients take up a sizable portion of
your memory and you do not need to do gradient accumulation). We will explore
the following:

1. What takes up memory during your training or finetuning loop,
2. How to capture and visualize memory snapshots to determine the bottleneck,
3. The new ``Tensor.register_post_accumulate_grad_hook(hook)`` API, and finally,
4. How everything fits together in 10 lines to achieve memory savings.

To run this tutorial, you will need:

*  PyTorch 2.1.0 or newer with ``torchvision``
*  1 CUDA GPU if you'd like to run the memory visualizations locally.
   Otherwise, this technique would benefit similarly on any device.

Let us start by importing the required modules and models. We will use a
vision transformer model from torchvision, but feel free to substitute
with your own model. We will also use ``torch.optim.Adam`` as our optimizer,
but, again, feel free to substitute with your own optimizer.

.. GENERATED FROM PYTHON SOURCE LINES 31-39

.. code-block:: default

    import torch
    from torchvision import models
    from pickle import dump

    model = models.vit_l_16(weights='DEFAULT').cuda()
    optimizer = torch.optim.Adam(model.parameters())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vit_l_16-852ce7e3.pth

The technique itself fits in roughly 10 lines. PyTorch 2.1 introduced
``Tensor.register_post_accumulate_grad_hook(hook)``, which runs a hook on a
parameter as soon as its gradient has finished accumulating during the
backward pass. By giving every parameter its own small optimizer and stepping
that optimizer from the hook, each gradient can be applied and freed
immediately instead of lingering until the end of the training step.

.. code-block:: default

    # Instead of one optimizer over all parameters, create one optimizer per
    # parameter (``foreach=False``, since each optimizer now only sees a single
    # tensor) and step it from a post-accumulate-grad hook.
    IMAGE_SIZE = 224  # vit_l_16 expects 224x224 inputs

    optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

    def optimizer_hook(parameter) -> None:
        optimizer_dict[parameter].step()
        optimizer_dict[parameter].zero_grad()

    # Register the hook onto every parameter
    for p in model.parameters():
        p.register_post_accumulate_grad_hook(optimizer_hook)

    # Now remember our previous ``train()`` function? Since the optimizer has been
    # fused into the backward, we can remove the optimizer step and zero_grad calls.
    def train(model):
        # create our fake image input: tensor shape is batch_size, channels, height, width
        fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

        # call our forward and backward
        loss = model.forward(fake_image)
        loss.sum().backward()

        # optimizer update --> no longer needed!
        # optimizer.step()
        # optimizer.zero_grad()

.. GENERATED FROM PYTHON SOURCE LINES 199-211

That took about 10 lines of changes in our sample model, which is neat.
However, for real models, it could be a fairly intrusive change to switch
out the optimizer for an optimizer dictionary, especially for those who use
``LRScheduler``\ s or manipulate optimizer configuration throughout the
training epochs. Working out this API with those changes will be more
involved and will likely require moving more configuration into global
state, but it should not be impossible.
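For example, if your loop uses an ``LRScheduler``, one workable (if manual)
pattern is to keep a parallel dictionary of per-parameter schedulers alongside
``optimizer_dict``. The sketch below is not part of the original recipe; the
name ``scheduler_dict`` and the ``StepLR`` hyperparameters are illustrative
assumptions.

.. code-block:: default

    # Hypothetical sketch: per-parameter LR schedulers to pair with the
    # per-parameter optimizers above. ``scheduler_dict`` and the StepLR
    # hyperparameters are illustrative choices, not part of this tutorial.
    from torch.optim.lr_scheduler import StepLR

    scheduler_dict = {
        p: StepLR(optimizer_dict[p], step_size=30, gamma=0.1)
        for p in model.parameters()
    }

    # The optimizer steps still happen inside the backward pass via the hooks;
    # the schedulers are stepped wherever you previously called
    # scheduler.step(), for example once per epoch.
    def step_schedulers():
        for scheduler in scheduler_dict.values():
            scheduler.step()

The same dictionary-of-objects pattern extends to other per-optimizer
configuration, such as adjusting learning rates or weight decay mid-training.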
That said, a next step for PyTorch is to make this API easier to adopt with
LRSchedulers and other features you are already used to. But let me get back
to convincing you that this technique is worth it. We will consult our friend,
the memory snapshot.

.. GENERATED FROM PYTHON SOURCE LINES 211-231

.. code-block:: default

    # delete optimizer memory from before to get a clean slate for the next
    # memory snapshot
    del optimizer

    # tell CUDA to start recording memory allocations
    torch.cuda.memory._record_memory_history(enabled='all')

    # train 3 steps. note that we no longer pass the optimizer into train()
    for _ in range(3):
        train(model)

    # save a snapshot of the memory allocations
    s = torch.cuda.memory._snapshot()
    with open("snapshot-opt-in-bwd.pickle", "wb") as f:
        dump(s, f)

    # tell CUDA to stop recording memory allocations now
    torch.cuda.memory._record_memory_history(enabled=None)

.. GENERATED FROM PYTHON SOURCE LINES 232-269

Yes, take some time to drag your snapshot into the CUDA Memory Visualizer
(https://pytorch.org/memory_viz).

.. figure:: /_static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg
   :alt: snapshot.png loaded into CUDA Memory Visualizer

Several major observations:

1. There is no more optimizer step! Right...we fused that into the backward.
2. Likewise, the backward drags longer and there are more random allocations
   for intermediates. This is expected, as the optimizer step requires
   intermediates.
3. Most importantly! The peak memory is lower! It is now ~4GB (which I hope
   maps closely to your earlier expectation).

Note that there is no longer any big chunk of memory allocated for the
gradients compared to before, accounting for ~1.2GB of memory savings.
Instead, we free each gradient very quickly after it has been computed by
moving the optimizer step as far ahead as we can. Woohoo! By the way, the
other ~1.2GB of memory savings comes from breaking apart the optimizer into
per-parameter optimizers, so the optimizer intermediates have proportionally
shrunk. This detail is *less important* than the gradient memory savings, as
you can get the optimizer intermediates savings from just turning
``foreach=False``, even without this technique.

You may be correctly wondering: if we saved 2.4GB of memory, why is the peak
memory NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near
the start of the backward step, when we still have activations in memory,
whereas before, the peak was during the optimizer step, when the activations
had already been freed. The ~0.4GB difference between the ~4.0GB we observe
and the ~3.6GB naive estimate is thus due to that activations memory. One can
then imagine that this technique can be coupled with activations checkpointing
for more memory wins.
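If you would like a quick numerical sanity check to complement the snapshots,
PyTorch's CUDA memory statistics can report the peak allocation directly. The
following is a small sketch and not part of the original tutorial; it assumes
the hook-registered ``model`` and the fused ``train()`` defined above.

.. code-block:: default

    # Sketch (not from the original tutorial): confirm the lower peak memory
    # numerically instead of, or in addition to, the memory visualizer.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    for _ in range(3):
        train(model)

    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"peak allocated CUDA memory: {peak_gib:.2f} GiB")

Running the same check with the original optimizer-in-the-training-loop setup
should report a correspondingly higher peak.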
Conclusion
""""""""""

In this tutorial, we learned about the memory saving technique of fusing the
optimizer into the backward step through the new
``Tensor.register_post_accumulate_grad_hook()`` API and *when* to apply this
technique (when gradients memory is significant). Along the way, we also
learned about memory snapshots, which are generally useful in memory
optimization.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 28.538 seconds)


.. _sphx_glr_download_intermediate_optimizer_step_in_backward_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: optimizer_step_in_backward_tutorial.py <optimizer_step_in_backward_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: optimizer_step_in_backward_tutorial.ipynb <optimizer_step_in_backward_tutorial.ipynb>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_