(beta) Running the compiled optimizer with an LR Scheduler
Created On: May 21, 2024 | Last Updated: May 21, 2024 | Last Verified: Nov 05, 2024
Author: Michael Lazos
The optimizer is a key algorithm for training any deep learning model.
In this example, we will show how to pair the optimizer, which has been compiled using torch.compile, with an LR scheduler to accelerate training convergence.
Note
This tutorial requires PyTorch 2.3.0 or later.
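If you are unsure which PyTorch version is installed, a minimal check (a sketch; the simple version-string parsing below is only illustrative) looks like this:
import torch
# The tensor LR support used below requires PyTorch 2.3.0 or later.
major, minor = (int(v) for v in torch.__version__.split(".")[:2])
assert (major, minor) >= (2, 3), "This tutorial requires PyTorch 2.3.0 or later."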
Model Setup
For this example, we’ll use a simple sequence of linear layers.
import torch
# Create simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
# run forward pass
output = model(input)
# run backward to populate the grads for our optimizer below
output.sum().backward()
Setting up and running the compiled optimizer with LR Scheduler
In this section, we'll use the Adam optimizer with the LinearLR scheduler and create a helper function to wrap the step() call for each of them in torch.compile().
Note
torch.compile is only supported on CUDA devices that have a compute capability of 7.0 or higher.
# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)
# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing the
# optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()
# Warmup runs to compile the function
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
tensor(0.0047)
tensor(0.0060)
tensor(0.0073)
tensor(0.0087)
tensor(0.0100)
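The message about grad tensors being copied comes from the CUDA graphs integration used by the compiled optimizer. As the message itself suggests, if the gradient tensors keep the same addresses across runs (as they do here, since we reuse the grads populated earlier), the copy can be elided by marking them as static. A minimal sketch of that, using the decorator named in the message:
import torch._dynamo
# Sketch: mark each gradient tensor's address as static so the cudagraphs
# backend can skip the per-run copy. This assumes the ``.grad`` tensors are
# allocated once and reused across iterations.
for p in model.parameters():
    if p.grad is not None:
        torch._dynamo.decorators.mark_static_address(p.grad)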
Extension: What happens with a non-tensor LR?
For the curious, we will show how to peek into what happens with torch.compile when we don't wrap the LR in a tensor.
# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()
# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)
# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
    fn()
[rank0]:V0418 18:40:36.286000 634 torch/_dynamo/guards.py:2791] [33/1] [__recompiles] Recompiling function wrapper in /usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py:473
[rank0]:V0418 18:40:36.286000 634 torch/_dynamo/guards.py:2791] [33/1] [__recompiles] triggered by the following guard failure(s):
[rank0]:V0418 18:40:36.286000 634 torch/_dynamo/guards.py:2791] [33/1] [__recompiles] - 33/0: Cache line invalidated because L['args'][0] got deallocated
[rank0]:V0418 18:40:36.304000 634 torch/_dynamo/guards.py:2791] [34/1] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:210
[rank0]:V0418 18:40:36.304000 634 torch/_dynamo/guards.py:2791] [34/1] [__recompiles] triggered by the following guard failure(s):
[rank0]:V0418 18:40:36.304000 634 torch/_dynamo/guards.py:2791] [34/1] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
[rank0]:V0418 18:40:39.484000 634 torch/_dynamo/guards.py:2791] [34/2] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:210
[rank0]:V0418 18:40:39.484000 634 torch/_dynamo/guards.py:2791] [34/2] [__recompiles] triggered by the following guard failure(s):
[rank0]:V0418 18:40:39.484000 634 torch/_dynamo/guards.py:2791] [34/2] [__recompiles] - 34/1: L['self'].param_groups[0]['lr'] == 0.003333333333333333
[rank0]:V0418 18:40:39.484000 634 torch/_dynamo/guards.py:2791] [34/2] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
[rank0]:V0418 18:40:41.906000 634 torch/_dynamo/guards.py:2791] [34/3] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:210
[rank0]:V0418 18:40:41.906000 634 torch/_dynamo/guards.py:2791] [34/3] [__recompiles] triggered by the following guard failure(s):
[rank0]:V0418 18:40:41.906000 634 torch/_dynamo/guards.py:2791] [34/3] [__recompiles] - 34/2: L['self'].param_groups[0]['lr'] == 0.004666666666666667
[rank0]:V0418 18:40:41.906000 634 torch/_dynamo/guards.py:2791] [34/3] [__recompiles] - 34/1: L['self'].param_groups[0]['lr'] == 0.003333333333333333
[rank0]:V0418 18:40:41.906000 634 torch/_dynamo/guards.py:2791] [34/3] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
[rank0]:V0418 18:40:44.299000 634 torch/_dynamo/guards.py:2791] [34/4] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:210
[rank0]:V0418 18:40:44.299000 634 torch/_dynamo/guards.py:2791] [34/4] [__recompiles] triggered by the following guard failure(s):
[rank0]:V0418 18:40:44.299000 634 torch/_dynamo/guards.py:2791] [34/4] [__recompiles] - 34/3: L['self'].param_groups[0]['lr'] == 0.006000000000000001
[rank0]:V0418 18:40:44.299000 634 torch/_dynamo/guards.py:2791] [34/4] [__recompiles] - 34/2: L['self'].param_groups[0]['lr'] == 0.004666666666666667
[rank0]:V0418 18:40:44.299000 634 torch/_dynamo/guards.py:2791] [34/4] [__recompiles] - 34/1: L['self'].param_groups[0]['lr'] == 0.003333333333333333
[rank0]:V0418 18:40:44.299000 634 torch/_dynamo/guards.py:2791] [34/4] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] Recompiling function step in /usr/local/lib/python3.10/dist-packages/torch/optim/adam.py:210
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] triggered by the following guard failure(s):
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] - 34/4: L['self'].param_groups[0]['lr'] == 0.007333333333333335
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] - 34/3: L['self'].param_groups[0]['lr'] == 0.006000000000000001
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] - 34/2: L['self'].param_groups[0]['lr'] == 0.004666666666666667
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] - 34/1: L['self'].param_groups[0]['lr'] == 0.003333333333333333
[rank0]:V0418 18:40:46.697000 634 torch/_dynamo/guards.py:2791] [34/5] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
With this example, we can see that we recompile the optimizer a few times due to the guard failure on the lr in param_groups[0].
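To confirm that wrapping the LR in a tensor avoids these recompiles, one option (a sketch; torch._dynamo.reset() simply clears the compilation caches so we start fresh) is to reset and repeat the experiment with a tensor LR while the recompile logging above is still enabled:
# Sketch: clear the compiler caches, then rerun with a tensor LR.
# With recompile logging still on, no ``__recompiles`` messages should
# appear after the initial compilation.
torch._dynamo.reset()
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()
for _ in range(5):
    fn()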
Conclusion
In this tutorial we showed how to pair the optimizer compiled with torch.compile
with an LR Scheduler to accelerate training convergence. We used a model consisting
of a simple sequence of linear layers with the Adam optimizer paired
with a LinearLR scheduler to demonstrate the LR changing across iterations.
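For reference, here is a minimal sketch of how the compiled step and scheduler could slot into a full training loop; the "loss" below (the sum of the outputs) is only a placeholder objective:
# Sketch of a training loop around the compiled ``fn`` defined above.
for _ in range(5):
    opt.zero_grad()
    out = model(torch.rand(1024, device="cuda"))
    out.sum().backward()   # placeholder loss: just the sum of the outputs
    fn()                   # compiled optimizer step + scheduler step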
See also:
Compiled optimizer tutorial - an intro into the compiled optimizer.
Compiling the optimizer with PT2 - deeper technical details on the compiled optimizer.
Total running time of the script: ( 0 minutes 15.786 seconds)