This post is a follow-up to our first entry in this multi-part blog series on accelerating generative AI models with pure, native PyTorch, with a focus on latency and elastic scalability. We use torch.compile and torch.export to create highly optimized, low-latency versions of SAM2 that can be quickly scaled up on new instances.
By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, reduced precision, batched prompts and GPU preprocessing, we observe up to a 13x improvement in p90 execution latency and queue times compared to regular eager-mode PyTorch.
We collect our final results and demonstrate the improvement in a realistic deployment on auto-scaling cloud infrastructure from Modal.
Execution latency (ms / improvement)

| Task | p50: eager float32 | p50: AOTI float16 | p90: eager float32 | p90: AOTI float16 |
|---|---|---|---|---|
| AMG | 741 | 112 (6.6x) | 1140 | 176 (6.5x) |
| SPS | 98 | 20 (4.9x) | 130 | 28 (4.6x) |
| MPS | 269 | 38 (7.1x) | 714 | 52 (13.7x) |

Queue time (ms / improvement)

| Task | p50: eager float32 | p50: AOTI float16 | p90: eager float32 | p90: AOTI float16 |
|---|---|---|---|---|
| AMG | 201 | 41 (4.9x) | 815 | 327 (2.6x) |
| SPS | 31 | 33 (0.9x) | 441 | 49 (9.0x) |
| MPS | 40 | 37 (1.1x) | 942 | 75 (12.6x) |
The Tasks
The first post focused on processing a small number of varying prompts (points of interest) per image. These points represented the center points of the ground truth masks. For this post, we focus on a broader set of tasks: single-prompt segmentation (SPS), multi-prompt segmentation (MPS), and automatic mask generation (AMG), which generates the full set of masks for the input image without a given set of prompts. The first post focused on MPS only.
The little star in the image represents a user prompt. For AMG there are no prompts and masks are filtered down heuristically from a dense grid of initial candidate prompts (guesses). For SPS and MPS user prompts are derived from the center points of AMG masks. For SPS we choose the mask with the largest area.
Note that SAM2 uses a different backbone than SAM1. In particular, we only consider the largest and most accurate sam2.1_hiera_large backbone for this blog.
We aggregate the scripts needed to reproduce the results in torchao’s example folder and incrementally upstream the more stable parts of our changes to the SAM2 model from torchao into the main SAM2 repository. So if you are interested in the cutting-edge variant or would like to contribute experimental features, please don’t hesitate to reach out to the torchao repository and team. For the more stable and latest model version, please head over to the SAM2 repository directly.
Overview
We categorize the changes presented here into two groups. Fast changes constrain themselves to techniques that are not meant to affect model accuracy. Furious changes sacrifice some numerical accuracy for additional speed by making use of approximations such as low-precision data types.
Approximations may slightly lower precision metrics in favor of significantly improved performance while still passing an end-to-end check based on mean intersection over union (mIoU).
To measure the performance improvements we processed 1000 images, selected at random from the SAM2 validation dataset. We look at the p50 and p90 latency per image. To measure accuracy we consider the mIoU. Most notably, for the AMG task we also define a fail count metric: we consider a comparison failed if the number of generated masks differs. This turns out to be a fairly unstable quantity, and we can see that the other tasks are not as sensitive to small numerical changes as AMG.
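As a rough illustration of the accuracy check, here is a minimal sketch of computing mIoU over matched boolean masks. The function and the assumption that masks are already matched one-to-one are illustrative; the actual comparison lives in the torchao benchmark scripts.

```python
import torch

def miou(masks_a: torch.Tensor, masks_b: torch.Tensor) -> torch.Tensor:
    # masks_a, masks_b: matched boolean Tensors of shape (num_masks, H, W)
    intersection = (masks_a & masks_b).flatten(1).sum(dim=1).float()
    union = (masks_a | masks_b).flatten(1).sum(dim=1).float()
    # Guard against empty masks to avoid dividing by zero.
    iou = intersection / union.clamp(min=1)
    return iou.mean()
```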
The Setup
We are running the offline experiments on a regular H100 devserver, which is a fairly beefy and performant machine.
However, we try to look at these tasks with realistic constraints. In particular, we would like to emulate a server-side inference environment. That means we don’t use DataLoader to hide the latency of image preprocessing or decoding routines.
For the latency calculations we include decoding, segmentation and conversion of masks to a dictionary of run-length encoded masks. Put differently, we exclude loading the images into in-memory host bytearrays and storing the resulting dictionaries as JSON files on disk. This is meant to emulate a more realistic setting.
More concretely, consider the code below for the routines we include in our measurements. For any task, gen_masks produces a batched boolean Tensor bitmask that represents the corresponding object masks. We then compress this bitmask into a run-length encoded (RLE) format that can be used to transfer the results back from a remote server much more efficiently.
image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]
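For reference, below is a minimal sketch of what run-length encoding a single boolean mask can boil down to. This is an illustrative, uncompressed COCO-style RLE; the helper name and the exact output format of rle_dict_from_masks in the torchao scripts may differ.

```python
import torch

def rle_from_mask(mask: torch.Tensor) -> dict:
    # mask: boolean Tensor of shape (H, W); flatten in column-major order
    # to match the COCO-style RLE convention.
    h, w = mask.shape
    flat = mask.t().flatten().to(torch.uint8)
    # Find positions where the value changes and derive run lengths.
    change = torch.where(flat[1:] != flat[:-1])[0] + 1
    boundaries = torch.cat([torch.tensor([0], device=flat.device), change,
                            torch.tensor([flat.numel()], device=flat.device)])
    runs = (boundaries[1:] - boundaries[:-1]).tolist()
    # COCO RLE starts counting with a run of zeros (background).
    if flat[0] == 1:
        runs = [0] + runs
    return {"size": [h, w], "counts": runs}
```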
Optimizations
ao: eager code optimizations
The most effective tool for this work is the PyTorch autograd profiler combined with record_function. We used the profiler repeatedly to observe the program and confirm the effectiveness of any changes. It’s also important to keep in mind that the profiler itself has overhead: the more data you collect, such as stack traces, the more overhead you introduce, which might skew the collected trace. But it is excellent for finding synchronization points, gaps between kernels, and GPU kernels that take a long time.
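For context, a typical profiling setup with named regions might look roughly like the following. The region names and model calls are placeholders standing in for the real SAM2 routines.

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             with_stack=False) as prof:  # stack traces add overhead, so keep them off by default
    with record_function("image_encoder"):
        features = model.image_encoder(image_tensors)   # placeholder call
    with record_function("mask_prediction"):
        masks = model.predict_masks(features)           # placeholder call

# Export a trace that can be inspected in chrome://tracing or Perfetto.
prof.export_chrome_trace("sam2_trace.json")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```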
GPU traces help you understand bottlenecks that are not necessarily easily addressed by torch.compile. We found that automatic mask generation in particular is dominated by the data structure used to store the masks and by the routine used to convert the masks to a run-length encoded compressed format. We also found that a large part of AMG performance is dominated by the large number of masks created as a single batch. Sometimes candidate masks can be filtered down to fewer candidates earlier in the postprocessing stage by reordering operations, which in turn significantly speeds up the later operations.
In order to confirm the accuracy of our implementation we first compare without any changes in settings and using float32 precision. We see that mIoU is unchanged and the masks match perfectly when using the exact same settings. This means that these eager mode changes did not affect the accuracy of these tasks.
AMG
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count |
|---|---|---|---|---|
| Baseline | 864 | 1144 | 4350 | reference |
| AO | 693 | 786 | 4010 | 1 / 0 |
ao: batching prompts
Another lossless performance optimization that we were able to apply is batching the user input prompt calculations. When optimizing for latency at batch size 1 on a server-grade GPU such as an H100, we are often left with a lot of spare memory. We can easily trade off that memory for more performance by processing more points of interest (also called user prompts) at once. Remember that SAM2 is split into two parts: first, the backbone (image encoder); second, the prediction and decoding of masks based on a set of user prompts / points of interest. It is this second part where we may expect a larger or even varying number of inputs, and it is this second part where we apply batching.
This causes a large increase in memory, but also much better latency. The baseline generates one mask per prompt in a loop. For AMG the baseline processes 64 prompts at once and all that is needed is to change it to 1024, which is the number of candidate prompts generated. For SPS we process one prompt at a time, but it’s still included below for completeness.
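Conceptually, the change replaces a per-prompt loop with a single batched forward pass. The sketch below uses hypothetical predict_one / predict_batched helpers purely for illustration; the actual batched path in the torchao code differs (MPS, for example, uses the MapTensor approach described below).

```python
import torch

# Before: one forward pass per prompt -- low memory, but the GPU is underutilized.
masks = [predict_one(image_embedding, prompt) for prompt in prompts]

# After: stack all prompts and run a single batched forward pass.
# Memory grows with the number of prompts, but latency per prompt drops sharply.
prompt_batch = torch.stack(prompts)  # (num_prompts, ...)
masks = predict_batched(image_embedding, prompt_batch)
```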
AMG
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count |
|---|---|---|---|---|
| Baseline | 864 | 1144 | 4350 | reference |
| AO + batching | 613 | 706 | 33786 | 0.9999995 / 0 |
SPS
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU |
|---|---|---|---|---|
| Baseline | 116 | 181 | 1337 | reference |
| AO | 110 | 170 | 1339 | 1 |
MPS
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU |
|---|---|---|---|---|
| Baseline | 276 | 681 | 1337 | reference |
| AO + batching | 126 | 225 | 8021 | 0.9999992 |
As a technical side note: to enable batching for MPS without a significant manual rewrite of the code base to support multiple prompts at the same time, we used a Tensor subclass we call MapTensor. A MapTensor allows us to pass a batch of N prompts, but have it advertise a batch size of 1. Any operation is then automatically broadcast to the wrapped Tensor and propagated throughout the prediction part of the model. This works because individual prompt predictions are independent of one another. This is very similar to torch.vmap.
center_points_torch = to_map_tensor(center_points_torch)
center_points_label_torch = to_map_tensor(center_points_label_torch)
masks, scores, _ = mask_generator.predictor.predict(
point_coords=center_points_torch,
point_labels=center_points_label_torch,
multimask_output=True,
return_logits=False,
return_type="torch",
)
# Unwrapping MapTensor
masks = masks.elems
scores = scores.elems
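For the curious, here is a greatly simplified sketch of the wrapper-subclass idea. This is not the actual torchao MapTensor implementation, which handles broadcasting rules and many operator corner cases; it only illustrates the unwrap, dispatch and rewrap pattern.

```python
import torch
from torch.utils._pytree import tree_map

class MapTensor(torch.Tensor):
    # Wraps a batch of N prompt Tensors while presenting the metadata of a
    # single element, so unmodified batch-size-1 model code keeps working.
    @staticmethod
    def __new__(cls, elems):
        return torch.Tensor._make_wrapper_subclass(
            cls, elems.shape[1:], dtype=elems.dtype, device=elems.device)

    def __init__(self, elems):
        self.elems = elems  # the real (N, ...) payload

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda t: t.elems if isinstance(t, MapTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # Rewrap Tensor outputs so the batch keeps propagating through the model.
        return MapTensor(out) if isinstance(out, torch.Tensor) else out

def to_map_tensor(t: torch.Tensor) -> "MapTensor":
    return MapTensor(t)
```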
fast: fullgraph compilation
Just as with our first post, we first remove GPU syncs and graph breaks to make use of fullgraph compiled model code with max-autotune kernels where appropriate. After some rewriting, we are able to compile the image encoder and the prediction of masks.
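In code, this amounts to something like the following sketch. The attribute names on the mask generator are assumptions for illustration; the important parts are fullgraph=True, which turns remaining graph breaks into hard errors, and the max-autotune mode.

```python
# Compile the two hot paths end to end. fullgraph=True surfaces any remaining
# graph breaks as errors instead of silently falling back to eager.
mask_generator.predictor.model.image_encoder = torch.compile(
    mask_generator.predictor.model.image_encoder,
    mode="max-autotune",
    fullgraph=True,
    dynamic=False,
)
mask_generator.predictor._predict_masks = torch.compile(
    mask_generator.predictor._predict_masks,
    mode="max-autotune",
    fullgraph=True,
    dynamic=False,
)
```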
We run the experiments twice to get a sense of the overhead due to compilation: once in an environment with an empty TORCHINDUCTOR_CACHE_DIR and then again while ingesting the artifacts from the previous run. In particular, auto-tuning can take a long time and happens on the first call in a pristine environment. We call the second run “warm”. The first iteration is typically expected to be slow due to various other initialization processes, but compilation increases it significantly, even if an existing cache is used and the exact same shapes are fed again. Having said that, an overhead of a few seconds on the very first call in a warm environment is often still tolerable.
Most of these drawbacks can be mitigated and compiling causes a significant improvement in latency and reduction in memory.
AMG
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count | first iteration (ms) |
|---|---|---|---|---|---|
| AO + batching | 613 | 706 | 33786 | 0.9999995 / 0 | 1125 |
| + compile (cold) | 423 | 513 | 29349 | skipped | 404866 |
| + compile (warm) | 439 | 530 | 29349 | 0.994 / 190 | 8544 |
The number of masks produced per image can vary slightly when using automatic mask generation. There is ambiguity in the number of masks per object the model may produce. For example, a car may be subdivided into frame, windows and doors, or treated as a whole. When a modification causes the number of masks to change, we consider the comparison failed, and we only calculate the mIoU on masks with an exact match. This does not apply to the other tasks. We found that the number of masks generated is very sensitive to small numerical changes. The other tasks use the same code, and MPS in particular can help us further verify correctness.
SPS
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU | first iteration (ms) |
|---|---|---|---|---|---|
| AO | 110 | 170 | 1339 | 1 | 562 |
| + compile (cold) | 102 | 158 | 1343 | skipped | 319954 |
| + compile (warm) | 100 | 160 | 1302 | 0.9999 | 8947 |
MPS
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU | first iteration (ms) |
|---|---|---|---|---|---|
| AO + batching | 126 | 225 | 8021 | 0.9999992 | 504 |
| + compile (cold) | 129 | 215 | 8021 | skipped | 333308 |
| + compile (warm) | 113 | 213 | 8021 | 0.998 | 8617 |
furious: TF32, float16 and GPU preprocessing
We found that float16 is the right level of precision for a few significant subcomponents of the model. In particular, the image encoder and mask decoder weights can be converted entirely to float16. We can also use TensorFloat32 precision for the remaining float32 matrix operations. It should be possible to further reduce the precision, and we may address this in a future post. We also move image preprocessing, such as image normalization, onto the GPU in furious mode. We can’t use GPU decoding (nvJPEG) routines, because its output differs too much from CPU decoding and the model suffers a significant degradation in mIoU, so image decoding still happens on the CPU.
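A sketch of what enabling these settings can look like is shown below. The submodule names follow the SAM2 model structure, but the exact wiring in the torchao code (behind its fast/furious flags) differs.

```python
import torch

# Allow TensorFloat32 for the remaining float32 matmuls and convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Convert the heavyweight submodules entirely to float16.
model.image_encoder = model.image_encoder.to(torch.float16)
model.sam_mask_decoder = model.sam_mask_decoder.to(torch.float16)
```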
AMG
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count |
|---|---|---|---|---|
| AO + batching + compile (warm) | 439 | 530 | 29349 | 0.994 / 190 |
| + furious | 165 | 240 | 28335 | 0.978 / 306 |
This causes a significant degradation in mIoU for the AMG task, but doesn’t affect the other tasks. After an in-depth investigation, we still chalk this up to numerical instability and reordering of operations. More work is needed to further investigate this and it may not be interesting to run the AMG task in lower precision. The other tasks, however, benefit drastically in latency with minimal changes in mIoU.
SPS
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU |
|---|---|---|---|---|
| AO + compile (warm) | 100 | 160 | 1302 | 0.9999 |
| + furious | 32 | 63 | 861 | 0.9997 |
MPS
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU |
|---|---|---|---|---|
| AO + batching + compile (warm) | 113 | 213 | 8021 | 0.998 |
| + furious | 36 | 64 | 4222 | 0.997 |
AOTInductor’s (AOTI) ahead-of-time compilation via torch.export
When scaling elastically, it is often not possible to accommodate long startup times. That means the first iteration cannot be slow; we must deliver results quickly. This is when torch.compile’s current compilation overhead can get in the way. To address this, we can use AOTInductor’s (AOTI) ahead-of-time compilation via torch.export. AOTI lets us compile the model on a representative input and store the resulting code in a binary that is quick to load and run.
AOTI via torch.export is a new feature and we currently can’t export everything that is compilable. We’ve been able to export the image encoder for all tasks, but have only been able to export the mask prediction for the AMG and SPS tasks, due to the varying number of prompts in MPS. torch.export also supports dynamic shapes, but we would need to invest a bit more time to prepare the code for it.
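The rough flow for the image encoder looks like the sketch below. The input shape is an assumption, and the AOTI packaging APIs have been evolving across recent PyTorch releases, so the exact calls may differ from the torchao export script.

```python
import torch
from torch._inductor import aoti_compile_and_package, aoti_load_package

# Export on a representative input. Without dynamic shapes, the exported
# program is specialized to these shapes.
example_input = (torch.randn(1, 3, 1024, 1024, dtype=torch.float16, device="cuda"),)
exported = torch.export.export(image_encoder, example_input)

# Ahead-of-time compile and package the result into a single artifact on disk.
package_path = aoti_compile_and_package(exported, package_path="image_encoder.pt2")

# Later, e.g. on a freshly started replica: load and run without recompiling.
compiled_encoder = aoti_load_package(package_path)
features = compiled_encoder(*example_input)
```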
AMG: AO + batching + furious
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count | first iteration (ms) |
|---|---|---|---|---|---|
| + compile (warm) | 165 | 240 | 28335 | 0.978 / 306 | 10341 |
| + load export (cold) | 162 | 233 | 27927 | 0.974 / 308 | 906 |
SPS: AO + furious
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU | first iteration (ms) |
|---|---|---|---|---|---|
| + compile (warm) | 32 | 63 | 861 | 0.9997 | 7989 |
| + load export (cold) | 35 | 66 | 1686 | 0.9997 | 763 |
Note that loading the exported model significantly increases memory for SPS. It likely only increases peak memory utilization: initialization of the eager model really needs to be delayed until after loading the exported model, to avoid holding two copies of the weights in memory at once. This is something we could address, but the memory consumption is nowhere near the limit. We don’t see an increase in the other tasks, because AMG and MPS peak memory is dominated by processing batches of masks. One way to reduce that could be to operate on masks in the RLE format (or some other sparse format) earlier on, but for now there is no reason to, given the current memory consumption and our focus on latency.
MPS: AO + batching + furious
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU | first iteration (ms) |
|---|---|---|---|---|---|
| + compile (warm) | 36 | 64 | 4222 | 0.997 | 9626 |
| + load export (cold) | 43 | 72 | 3813 | 0.997 | 747 |
Using export by itself doesn’t seem to benefit from extensive warmup and can be run in a pristine, new inductor cache directory. But again, we do not evict the CUDA cache or other caches. In the section on Modal, we run some of these experiments in a pristine environment.
When only processing 1000 images in a new process, using export can really be worth it to save on compile and other cold-start overhead.
bonus: More GPU preprocessing
At this point, the latency is fairly low. In particular, for the SPS and MPS tasks we are processing each image in around 30ms to 40ms. Let’s bring back the pseudo-code from the setup section.
image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]
Further profiling showed that at this point decode_img_bytes takes about 10ms. In particular, it uses torchvision’s ToTensor transform to convert from a numpy array to a scaled, float32 torch.Tensor. The bytes passed to ToTensor have already been decoded and converted to a numpy ndarray. By slightly rewriting ToTensor using torchvision’s v2 API, and moving the smaller uint8 Tensor to the GPU before scaling, we gain another 10ms in latency. Had we not included decode_img_bytes in our analysis, we would have missed this opportunity, which has real-world impact on server-side inference.
from torchvision.transforms import v2

image_tensor = torch.from_numpy(image_tensor)  # uint8 HWC numpy array -> Tensor
image_tensor = image_tensor.permute((2, 0, 1))  # HWC -> CHW
image_tensor = image_tensor.cuda()  # move the small uint8 Tensor to the GPU first
image_tensor = v2.ToDtype(torch.float32, scale=True)(image_tensor)  # scale to float32 on the GPU
Note in particular that using pinned memory to perform asynchronous data transfers doesn’t pay off here, since the time it takes to move the Tensor into pinned memory isn’t worth the gain in asynchronicity for this data movement. For future work, we might explore further improvements here by using more advanced, direct memory-transfer techniques.
AMG: AO + batching + furious
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count | first iteration (ms) |
|---|---|---|---|---|---|
| + load export (cold) | 162 | 233 | 27927 | 0.974 / 308 | 906 |
| + load export (warm) | 157 | 230 | 27927 | 0.974 / 308 | 799 |
| + load export (warm) + preproc | 136 | 208 | 27950 | 0.977 / 311 | 908 |
SPS: AO + furious
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU | first iteration (ms) |
|---|---|---|---|---|---|
| + load export (cold) | 35 | 66 | 1686 | 0.9997 | 763 |
| + load export (warm) | 31 | 63 | 1686 | 0.9997 | 683 |
| + load export (warm) + preproc | 19 | 25 | 1711 | 0.9997 | 658 |
MPS: AO + batching + furious
| | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU | first iteration (ms) |
|---|---|---|---|---|---|
| + load export (cold) | 43 | 72 | 3813 | 0.997 | 747 |
| + load export (warm) | 53 | 81 | 3813 | 0.997 | 807 |
| + load export (warm) + preproc | 31 | 41 | 3837 | 0.997 | 671 |
This small change has a significant impact on the SPS and MPS tasks.
Deploying on Modal
Finally, we deployed our optimized inference onto Modal, a serverless infrastructure provider, to demonstrate that the benefits of these optimizations can be realized in a more realistic deployment setting.
In particular, compilation and AOTI via torch.export requires extra work. In a naïve deployment that work might be added to every single inference execution, adding latency that dwarfs any improvements from a faster model. This is particularly challenging with elastic or autoscaling infrastructure, where replicas of our inference service need to be regularly and automatically created and destroyed.
We share a deployment script in the torchao repository (cli_on_modal.py) to demonstrate one pattern for an elastic deployment. We build the exported models ahead of time and then upload them to distributed storage. Relative to eager execution, this adds a bit of extra work when replicas spin up since they need to read this data over a network, but this is far less costly than compilation or export.
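Schematically, the deployment follows the hypothetical sketch below. This is not the actual cli_on_modal.py; the app name, volume, container image and GPU type are assumptions chosen for illustration.

```python
import modal

app = modal.App("sam2-aoti-inference")
# Exported AOTI artifacts are built ahead of time and stored in a shared volume.
model_volume = modal.Volume.from_name("sam2-exported-models", create_if_missing=True)
image = modal.Image.debian_slim().pip_install("torch", "torchvision", "torchao")

@app.cls(gpu="H100", image=image, volumes={"/models": model_volume})
class SAM2Inference:
    @modal.enter()
    def load(self):
        # On replica start-up: read the pre-built artifacts over the network
        # instead of compiling or exporting on the spot.
        from torch._inductor import aoti_load_package
        self.image_encoder = aoti_load_package("/models/image_encoder.pt2")

    @modal.method()
    def predict(self, image_bytes: bytes) -> dict:
        # Decode on CPU, preprocess on GPU, run the optimized model and
        # return run-length encoded masks.
        ...
```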
We benchmarked this deployment with a large batch inference workload: sending 1000 images for concurrent processing. The deployment scales up to ten replicas on ten GPUs at peak and scales down to zero GPUs when inactive.
First, let’s look at the execution latencies.
Execution latency (ms / improvement)

| Task | p50: eager float32 | p50: AOTI float16 (Modal) | p50: AOTI float16 (Offline) | p90: eager float32 | p90: AOTI float16 (Modal) | p90: AOTI float16 (Offline) |
|---|---|---|---|---|---|---|
| AMG | 741 | 112 (6.6x) | 136 (5.4x) | 1140 | 176 (6.5x) | 208 (5.5x) |
| SPS | 98 | 20 (4.9x) | 19 (5.2x) | 130 | 28 (4.6x) | 25 (5.2x) |
| MPS | 269 | 38 (7.1x) | 31 (8.7x) | 714 | 52 (13.7x) | 41 (17.4x) |
We notice that execution latencies on Modal and Offline are fairly close, especially relative to the baseline, indicating that optimizing the deployment offline was a reasonable proxy for optimizing the deployment directly.
In addition to execution latency, our batch workload has queueing time, since there are fewer replicas than there are inputs, and so some inputs have to wait in line.
Queue time (ms / improvement)

| Task | p50: eager float32 | p50: AOTI float16 | p90: eager float32 | p90: AOTI float16 |
|---|---|---|---|---|
| AMG | 201 | 41 (4.9x) | 815 | 327 (2.6x) |
| SPS | 31 | 33 (0.9x) | 441 | 49 (9.0x) |
| MPS | 40 | 37 (1.1x) | 942 | 75 (12.6x) |
Even though the queueing system provided by the infrastructure is unchanged, the queue latencies also decrease when we use our optimized model – in the p90 case by a factor of 2 to 12. That’s because when we finish previous inputs faster (from reduced execution latency) we can pull our next inputs sooner (reducing their queueing time).
If you’re interested in optimizing SAM2 inference or deployments further, don’t hesitate to reach out to us at the torchao repository!
Conclusions
We rewrote Meta’s original SAM2 in pure PyTorch with little loss of accuracy and a strong focus on latency. We deployed our optimized inference onto Modal, a serverless infrastructure provider, to demonstrate that the benefits of these optimizations can be realized in a more realistic deployment setting.
By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, reduced precision, batched prompts and GPU preprocessing, we observe up to a 13x improvement in p90 execution latency and queue times compared to regular eager-mode PyTorch.
With elastic or autoscaling infrastructure, where replicas of our inference service need to be regularly and automatically created and destroyed, a naïve deployment of torch.compile can add work to inference execution that dwarfs any improvements from a faster model. By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, we are able to upload exported models ahead of time and read this data over a network, which enables us to get the benefits of compilation without significantly increased work.
For more details on how to reproduce the data in this blog post, check out the experiments folder of torchao. Please don’t hesitate to contact us or open an issue if you run into any technical issues.