by Team PyTorch

This post is a follow-up to our first entry in the multi-part blog series on accelerating generative AI models with pure, native PyTorch, this time focusing on latency and elastic scalability. We use torch.compile and torch.export to create highly optimized, low-latency versions of SAM2 that can be quickly scaled up on new instances.

By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, reduced precision, batched prompts, and GPU preprocessing, we observe up to a 13x improvement in p90 execution latency and queue times compared to regular eager mode PyTorch.

We calculate our final results and demonstrate the improvement in a realistic deployment on auto-scaling cloud infrastructure from Modal.

Execution latency:

| Task | eager float32 p50 (ms) | AOTI float16 p50 (ms / improvement) | eager float32 p90 (ms) | AOTI float16 p90 (ms / improvement) |
|------|------------------------|-------------------------------------|------------------------|-------------------------------------|
| AMG  | 741                    | 112 (6.6x)                          | 1140                   | 176 (6.5x)                          |
| SPS  | 98                     | 20 (4.9x)                           | 130                    | 28 (4.6x)                           |
| MPS  | 269                    | 38 (7.1x)                           | 714                    | 52 (13.7x)                          |

Queue time:

| Task | eager float32 p50 (ms) | AOTI float16 p50 (ms / improvement) | eager float32 p90 (ms) | AOTI float16 p90 (ms / improvement) |
|------|------------------------|-------------------------------------|------------------------|-------------------------------------|
| AMG  | 201                    | 41 (4.9x)                           | 815                    | 327 (2.6x)                          |
| SPS  | 31                     | 33 (0.9x)                           | 441                    | 49 (9.0x)                           |
| MPS  | 40                     | 37 (1.1x)                           | 942                    | 75 (12.6x)                          |

The Tasks

The first post focused on processing a small number of varying prompts (points of interest) per image, where these points represented the center points of the ground truth masks. For this post, we broaden the set of tasks to three: single prompt segmentation (SPS), multi prompt segmentation (MPS), and automatic mask generation (AMG), which generates the full set of masks for the input image without a given set of prompts. The first post covered MPS only.

[Figure: comparison of the three tasks (SPS, MPS, AMG) on an example image]

The little star in the image represents a user prompt. For AMG there are no prompts and masks are filtered down heuristically from a dense grid of initial candidate prompts (guesses). For SPS and MPS user prompts are derived from the center points of AMG masks. For SPS we choose the mask with the largest area.
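As a rough illustration of how these prompts can be derived from AMG masks (a minimal sketch; the center point here is computed as a centroid, which may differ from the exact definition used in the evaluation scripts):

import torch

def mask_center(mask: torch.Tensor) -> torch.Tensor:
    # mask: (H, W) bool Tensor; return its (x, y) centroid.
    ys, xs = torch.nonzero(mask, as_tuple=True)
    return torch.stack([xs.float().mean(), ys.float().mean()])

def prompts_from_amg_masks(amg_masks):
    # amg_masks: list of (H, W) bool Tensors produced by AMG.
    centers = torch.stack([mask_center(m) for m in amg_masks])  # MPS prompts
    areas = torch.tensor([int(m.sum()) for m in amg_masks])
    sps_prompt = centers[areas.argmax()]                        # SPS: center of the largest-area mask
    return centers, sps_prompt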

Note that SAM2 uses a different backbone than SAM1. In particular, we only consider the largest and most accurate sam2.1_hiera_large backbone for this blog.

We aggregate the scripts needed to reproduce the results in torchao’s example folder and incrementally upstream the more stable parts of the changes to the SAM2 model in torchao into the main SAM2 repository. So if you are interested in taking a look at the cutting-edge variant, or would like to contribute experimental features, please don’t hesitate to reach out to the torchao repository and team. For the more stable and latest model version, please head over to SAM2 directly.

Overview

We categorize the changes presented here into two groups. Fast changes constrain themselves to techniques that are not meant to affect model accuracy. Furious changes sacrifice some numerical accuracy for additional speed by making use of approximations such as low-precision data types.

Approximations may slightly lower precision metrics in favor of significantly improved performance while still passing an end-to-end check based on mean intersection over union (mIoU).

To measure the performance improvements, we process 1000 images selected at random from the SAM2 validation dataset and look at the p50 and p90 latency per image. To measure accuracy, we consider the mIoU. For the AMG task in particular, we also define a fail count metric: a comparison is considered failed if the number of generated masks differs. This turns out to be a fairly unstable quantity, and we can see that the other tasks are not as sensitive to small numeric changes as AMG.
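For reference, below is a minimal sketch of how mIoU could be computed over boolean masks that have already been matched one-to-one between two runs (this helper is an illustration, not the exact evaluation code used for the numbers in this post):

import torch

def miou(pred_masks: torch.Tensor, ref_masks: torch.Tensor) -> torch.Tensor:
    # pred_masks, ref_masks: (N, H, W) bool Tensors with masks matched by index.
    pred = pred_masks.flatten(1)
    ref = ref_masks.flatten(1)
    intersection = (pred & ref).sum(dim=1).float()
    union = (pred | ref).sum(dim=1).float()
    iou = intersection / union.clamp(min=1)  # guard against empty unions
    return iou.mean()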

The Setup

We are running the offline experiments on a regular H100 devserver, which is a fairly beefy and performant machine.

However, we try to look at these tasks with realistic constraints. In particular, we would like to emulate a server-side inference environment. That means we don’t use DataLoader to hide the latency of image preprocessing or decoding routines.

For the latency calculations we include decoding, segmentation and conversion of masks to a dictionary of run-length encoded masks. Or put differently, we exclude loading the images into in-memory host bytearrays and storing the resulting dictionaries as json files on disk. This is meant to emulate a more realistic setting.

More concretely, consider the code below for the routines we include in our measurements. For any task gen_masks produces a batched bool Tensor bitmask that represents the corresponding object masks. We then compress this bitmask into a run length encoded (rle) format that can be used to transfer back the results from a remote server much more efficiently.

image_tensors = decode_img_bytes(...)                 # decode image bytes into preprocessed image Tensors
masks = gen_masks(image_tensors, ...)                 # run SAM2 to produce a batched bool Tensor bitmask
rle_dicts = [rle_dict_from_masks(m) for m in masks]   # compress masks into run-length encoded dicts

Optimizations

ao: eager code optimizations

The most effective tool for this work is the PyTorch autograd profiler combined with record_function. We used the profiler repeatedly to observe the program and confirm the effectiveness of any changes. It’s also important to keep in mind that the profiler itself has overhead: the more data you collect, such as stack traces, the more overhead you introduce, which might skew the collected trace. But it is excellent for finding synchronization points, gaps between kernels, and GPU kernels that take a long time.
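A minimal sketch of that workflow, shown here with the torch.profiler front end (the same record_function regions also work with the autograd profiler), using the routines from the setup pseudo-code as placeholders:

from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             with_stack=False) as prof:      # stack traces add overhead; keep them off by default
    with record_function("gen_masks"):       # named region that shows up in the trace
        masks = gen_masks(image_tensors, ...)
    with record_function("rle_encode"):
        rle_dicts = [rle_dict_from_masks(m) for m in masks]

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")       # inspect in chrome://tracing or Perfetto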

GPU traces help you understand bottlenecks that are not necessarily easily addressed by compile. We found that AutomaticMaskGeneration in particular is dominated by the data structure used to store the masks and by the routine used to convert the masks to a run-length encoded compressed format. We also found a large part of AMG performance is dominated by the large number of masks created as a single batch. Sometimes candidate masks can be filtered down to fewer candidates earlier in the postprocessing stage by reordering operations. This in turn significantly speeds up the later operations.

In order to confirm the accuracy of our implementation, we first compare against the baseline without any changes in settings and using float32 precision. We see that the mIoU is unchanged and the masks match perfectly when using the exact same settings. This means that these eager mode changes did not affect the accuracy of these tasks.

AMG

| Configuration | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count |
|---------------|------------------|------------------|--------------|-------------------|
| Baseline      | 864              | 1144             | 4350         | reference         |
| AO            | 693              | 786              | 4010         | 1 / 0             |

ao: batching prompts

Another lossless performance optimization that we were able to apply is batching the user input prompt calculations. When optimizing for latency at batch size 1 on a server-grade GPU such as an H100 we are often left with a lot of spare memory. We can easily trade off that memory for more performance by processing more points of interest (also called user prompts) at once. Remember that SAM2 is split into two parts: First the backbone (image encoder), second the prediction and decoding of masks based on a set of user prompts / points of interest. It is the second part where we may expect a larger or even varying number of inputs and it is this second part where we apply batching.

This causes a large increase in memory, but also much better latency. The baseline generates one mask per prompt in a loop. For AMG the baseline processes 64 prompts at once and all that is needed is to change it to 1024, which is the number of candidate prompts generated. For SPS we process one prompt at a time, but it’s still included below for completeness.
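If you want to reproduce this on top of the upstream SAM2 API, the relevant knob is, to our understanding, the points_per_batch argument of the automatic mask generator (sam2_model below stands in for a loaded SAM2 model):

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Baseline: score the candidate prompts in chunks of 64 per forward pass.
mask_generator = SAM2AutomaticMaskGenerator(sam2_model, points_per_batch=64)

# Batched: process all 1024 candidate prompts (a 32x32 grid) in a single forward pass,
# trading extra GPU memory for fewer, larger kernels and lower latency.
mask_generator = SAM2AutomaticMaskGenerator(sam2_model, points_per_batch=1024)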

AMG

| Configuration | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count |
|---------------|------------------|------------------|--------------|-------------------|
| Baseline      | 864              | 1144             | 4350         | reference         |
| AO + batching | 613              | 706              | 33786        | 0.9999995 / 0     |

SPS

| Configuration | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU      |
|---------------|------------------|------------------|--------------|-----------|
| Baseline      | 116              | 181              | 1337         | reference |
| AO            | 110              | 170              | 1339         | 1         |

MPS

| Configuration | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU      |
|---------------|------------------|------------------|--------------|-----------|
| Baseline      | 276              | 681              | 1337         | reference |
| AO + batching | 126              | 225              | 8021         | 0.9999992 |

As a technical side note: to enable batching for MPS without a significant manual rewrite of the code base to support multiple prompts at the same time, we used a Tensor subclass we call MapTensor. A MapTensor allows us to pass a batch of N prompts, but have it advertise a batch size of 1. Any operation is then automatically broadcast to the wrapped Tensor and propagated throughout the prediction part of the model. This works because individual prompt predictions are independent of one another. This is very similar to torch.vmap.

# Wrap the batched prompts so they advertise a batch size of 1
center_points_torch = to_map_tensor(center_points_torch)
center_points_label_torch = to_map_tensor(center_points_label_torch)
masks, scores, _ = mask_generator.predictor.predict(
    point_coords=center_points_torch,
    point_labels=center_points_label_torch,
    multimask_output=True,
    return_logits=False,
    return_type="torch",
)
# Unwrapping MapTensor
masks = masks.elems
scores = scores.elems

fast: fullgraph compilation

Just as with our first post, we first remove GPU syncs and graph breaks to make use of fullgraph compiled model code with max-autotune kernels where appropriate. After some rewriting, we are able to compile the image encoder and the prediction of masks.
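A minimal sketch of the idea for the image encoder (the exact attributes compiled by the torchao scripts differ; predictor.model.image_encoder follows our reading of the SAM2 code base, and the mask prediction path is compiled analogously):

import torch

# Compile the image encoder as a single graph with auto-tuned kernels.
predictor.model.image_encoder = torch.compile(
    predictor.model.image_encoder,
    mode="max-autotune",
    fullgraph=True,   # surface graph breaks as errors instead of silently splitting the graph
    dynamic=False,
)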

We run the experiments twice to get a sense of the overhead due to compilation. We run it once in an environment with an empty TORCHINDUCTOR_CACHE_DIR and then again while ingesting the artifacts from the previous run. In particular, auto-tuning can take a long time and happens on the first call in a pristine environment. We call the second run “warm”. The first iteration is typically expected to be slow due to various other initialization processes, but compile increases it significantly, even if an existing cache is used and the exact same shapes are fed again. Having said that, an overhead of a few seconds in a warm environment is often still acceptable on the very first call.

Most of these drawbacks can be mitigated and compiling causes a significant improvement in latency and reduction in memory.

AMG

| Configuration    | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count | first iteration (ms) |
|------------------|------------------|------------------|--------------|-------------------|----------------------|
| AO + batching    | 613              | 706              | 33786        | 0.9999995 / 0     | 1125                 |
| + compile (cold) | 423              | 513              | 29349        | skipped           | 404866               |
| + compile (warm) | 439              | 530              | 29349        | 0.994 / 190       | 8544                 |

The number of masks produced per image can vary slightly when using automatic mask generation. There is ambiguity in the number of masks per object the model may produce. For example, a car may be subdivided into frames, windows and doors or treated as a whole. When a modification causes the number of masks to change, we consider the comparison failed, and we only calculate the mIoU on masks with an exact match. This does not apply to the other tasks. We found that the number of masks generated is very sensitive to small numerical changes. The other tasks use the same code and MPS in particular can help us further verify correctness.

SPS

| Configuration    | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU    | first iteration (ms) |
|------------------|------------------|------------------|--------------|---------|----------------------|
| AO               | 110              | 170              | 1339         | 1       | 562                  |
| + compile (cold) | 102              | 158              | 1343         | skipped | 319954               |
| + compile (warm) | 100              | 160              | 1302         | 0.9999  | 8947                 |

MPS

| Configuration    | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU      | first iteration (ms) |
|------------------|------------------|------------------|--------------|-----------|----------------------|
| AO + batching    | 126              | 225              | 8021         | 0.9999992 | 504                  |
| + compile (cold) | 129              | 215              | 8021         | skipped   | 333308               |
| + compile (warm) | 113              | 213              | 8021         | 0.998     | 8617                 |

furious: TF32, float16 and GPU preprocessing

We found that float16 is the right level of precision for a few significant subcomponents of the model. In particular, the image encoder and mask decoder weights can be converted entirely to float16. We can also use TensorFloat32 precision for the remaining float32 matrix operations. It should be possible to further reduce the precision and we may address this in a future post. With furious mode we also move image preprocessing, such as image normalization, onto the GPU. We can’t use GPU decoding (nvJPEG) routines, because the decoded outputs differ too much from the CPU path and the model suffers a significant degradation in mIoU, so image decoding still happens on the CPU.
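Roughly, the furious settings amount to the following sketch (the exact set of converted modules in torchao may differ, and the predictor.model.* attribute names are our reading of the SAM2 code base):

import torch

# Allow TensorFloat32 for the remaining float32 matmuls and convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Convert the heavy submodules to float16 weights and activations.
predictor.model.image_encoder = predictor.model.image_encoder.half()
predictor.model.sam_mask_decoder = predictor.model.sam_mask_decoder.half()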

AMG

| Configuration                  | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count |
|--------------------------------|------------------|------------------|--------------|-------------------|
| AO + batching + compile (warm) | 439              | 530              | 29349        | 0.994 / 190       |
| + furious                      | 165              | 240              | 28335        | 0.978 / 306       |

This causes a significant degradation in mIoU for the AMG task, but doesn’t affect the other tasks. After an in-depth investigation, we still chalk this up to numerical instability and the reordering of operations. More work is needed to investigate this further, and it may simply not be worthwhile to run the AMG task in lower precision. The other tasks, however, benefit drastically in latency with minimal changes in mIoU.

SPS

| Configuration       | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU   |
|---------------------|------------------|------------------|--------------|--------|
| AO + compile (warm) | 100              | 160              | 1302         | 0.9999 |
| + furious           | 32               | 63               | 861          | 0.9997 |

MPS

| Configuration                  | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU  |
|--------------------------------|------------------|------------------|--------------|-------|
| AO + batching + compile (warm) | 113              | 213              | 8021         | 0.998 |
| + furious                      | 36               | 64               | 4222         | 0.997 |

AOTInductor’s (AOTI) ahead-of-time compilation via torch.export

When scaling elastically, it is often not possible to accommodate long startup times. That means the first iteration cannot be slow; we must deliver results quickly. This is where torch.compile’s current compilation overhead can get in the way. To address this we can use AOTInductor’s (AOTI) ahead-of-time compilation via torch.export. AOTI lets us compile the model on a representative input and store the resulting code in a binary that is quick to load and run.

AOTI via torch.export is a new feature and we currently can’t export everything that is compilable. We’ve been able to export the image encoder for all tasks but have only been able to export the mask prediction for the AMG and SPS tasks due to varying prompts. torch.export also supports dynamic shapes, but we need to invest a bit more time to prepare the code for it.
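As a minimal sketch of the export-and-load flow (the packaging entry points shown here are the ones available in recent PyTorch releases and may differ from what the torchao scripts use; the input shape and dtype are illustrative):

import torch

# Ahead of time (e.g. on a build machine): export the image encoder for a
# representative input and compile it into a portable AOTI package.
example_input = (torch.randn(1, 3, 1024, 1024, dtype=torch.float16, device="cuda"),)
exported = torch.export.export(predictor.model.image_encoder, example_input)
torch._inductor.aoti_compile_and_package(exported, package_path="image_encoder.pt2")

# On a fresh instance: loading the package is fast compared to compiling from scratch.
image_encoder = torch._inductor.aoti_load_package("image_encoder.pt2")
features = image_encoder(*example_input)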

AMG: AO + batching + furious

| Configuration        | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count | first iteration (ms) |
|----------------------|------------------|------------------|--------------|-------------------|----------------------|
| + compile (warm)     | 165              | 240              | 28335        | 0.978 / 306       | 10341                |
| + load export (cold) | 162              | 233              | 27927        | 0.974 / 308       | 906                  |

SPS: AO + furious

| Configuration        | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU   | first iteration (ms) |
|----------------------|------------------|------------------|--------------|--------|----------------------|
| + compile (warm)     | 32               | 63               | 861          | 0.9997 | 7989                 |
| + load export (cold) | 35               | 66               | 1686         | 0.9997 | 763                  |

Note that loading the exported model significantly increases memory here. This is likely only an increase in peak memory utilization: to avoid holding two copies of the weights in memory at once, initialization of the eager model would have to be delayed until after the exported model has been loaded. This is something we could address, but the memory consumption is nowhere near the limit. We don’t see an increase in the other tasks, because AMG and MPS peak memory is dominated by processing batches of masks. One way to reduce that could be to operate on masks in the rle format (or some other sparse format) earlier on, but for now there is no reason to, given the current memory consumption and focus on latency.

MPS: AO + batching + furious

| Configuration        | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU  | first iteration (ms) |
|----------------------|------------------|------------------|--------------|-------|----------------------|
| + compile (warm)     | 36               | 64               | 4222         | 0.997 | 9626                 |
| + load export (cold) | 43               | 72               | 3813         | 0.997 | 747                  |

Using export by itself doesn’t seem to benefit from extensive warmup and can be run in a pristine new inductor cache directory. But again, we do not evict the CUDA cache or other caches. In the section on Modal, we are running some of these experiments in a pristine environment.

When processing only 1000 images in a new process, using export can really be worth it to save on compilation and other cold start overhead.

bonus: More GPU preprocessing

At this point, the latency is fairly low. In particular, for the SPS and MPS tasks we are processing at around 30ms to 40ms. Let’s bring back the pseudo-code from the setup section again.

image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]

Further profiling showed that at this point decode_img_bytes takes about 10ms. In particular, it uses torchvision’s ToTensor transform to convert from a numpy array to a scaled, float32 torch.Tensor. The bytes passed to ToTensor have already been decoded and converted to a numpy ndarray. By slightly rewriting this step using torchvision’s v2 API and moving the smaller uint8 Tensor to the GPU before scaling, we gain another 10ms in latency. Without including decode_img_bytes in our analysis we would have missed this opportunity, which has real-world impact on server-side inference.

import torch
from torchvision.transforms import v2

image_tensor = torch.from_numpy(image_tensor)                        # uint8 HWC ndarray -> Tensor
image_tensor = image_tensor.permute((2, 0, 1))                       # HWC -> CHW
image_tensor = image_tensor.cuda()                                   # move the small uint8 Tensor to the GPU first
image_tensor = v2.ToDtype(torch.float32, scale=True)(image_tensor)   # scale to float32 on the GPU

Note in particular that using pinned memory to perform asynchronous data transfers doesn’t help here, since the time it takes to move the Tensor into pinned memory isn’t worth the gain in asynchronicity for this data movement. For future work, we might want to explore further improvements here by using more advanced direct memory transfer techniques.

AMG: AO + batching + furious

| Configuration                  | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU / fail count | first iteration (ms) |
|--------------------------------|------------------|------------------|--------------|-------------------|----------------------|
| + load export (cold)           | 162              | 233              | 27927        | 0.974 / 308       | 906                  |
| + load export (warm)           | 157              | 230              | 27927        | 0.974 / 308       | 799                  |
| + load export (warm) + preproc | 136              | 208              | 27950        | 0.977 / 311       | 908                  |

SPS: AO + furious

| Configuration                  | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU   | first iteration (ms) |
|--------------------------------|------------------|------------------|--------------|--------|----------------------|
| + load export (cold)           | 35               | 66               | 1686         | 0.9997 | 763                  |
| + load export (warm)           | 31               | 63               | 1686         | 0.9997 | 683                  |
| + load export (warm) + preproc | 19               | 25               | 1711         | 0.9997 | 658                  |

MPS: AO + batching + furious

| Configuration                  | p50 latency (ms) | p90 latency (ms) | memory (MiB) | mIoU  | first iteration (ms) |
|--------------------------------|------------------|------------------|--------------|-------|----------------------|
| + load export (cold)           | 43               | 72               | 3813         | 0.997 | 747                  |
| + load export (warm)           | 53               | 81               | 3813         | 0.997 | 807                  |
| + load export (warm) + preproc | 31               | 41               | 3837         | 0.997 | 671                  |

This small change has a significant impact on the SPS and MPS tasks.

Deploying on Modal

Finally, we deployed our optimized inference onto Modal, a serverless infrastructure provider, to demonstrate that the benefits of these optimizations can be realized in a more realistic deployment setting.

In particular, compilation and AOTI via torch.export requires extra work. In a naïve deployment that work might be added to every single inference execution, adding latency that dwarfs any improvements from a faster model. This is particularly challenging with elastic or autoscaling infrastructure, where replicas of our inference service need to be regularly and automatically created and destroyed.

We share a deployment script in the torchao repository (cli_on_modal.py) to demonstrate one pattern for an elastic deployment. We build the exported models ahead of time and then upload them to distributed storage. Relative to eager execution, this adds a bit of extra work when replicas spin up since they need to read this data over a network, but this is far less costly than compilation or export.
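The overall shape of that pattern looks roughly like the sketch below (a simplified illustration, not the actual cli_on_modal.py; app, volume, and function names are made up):

import modal

app = modal.App("torchao-sam2")
# Exported AOTI packages are built ahead of time and uploaded to a shared Volume.
exports = modal.Volume.from_name("sam2-exported-models", create_if_missing=True)

@app.cls(gpu="H100", volumes={"/exports": exports})
class Sam2Service:
    @modal.enter()
    def load(self):
        import torch
        # Read the pre-built artifacts over the network instead of compiling on spin-up.
        self.image_encoder = torch._inductor.aoti_load_package("/exports/image_encoder.pt2")

    @modal.method()
    def segment(self, image_bytes: bytes):
        ...  # decode on CPU, preprocess on GPU, run the model, return rle dicts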

We benchmarked this deployment with a large batch inference workload: sending 1000 images for concurrent processing. The deployment scales up to ten replicas on ten GPUs at peak and scales down to zero GPUs when inactive.

First, let’s look at the execution latencies.

Execution latency, Modal vs. Offline:

| Task | eager float32 p50 (ms) | AOTI float16 p50, Modal (ms / improvement) | AOTI float16 p50, Offline (ms / improvement) | eager float32 p90 (ms) | AOTI float16 p90, Modal (ms / improvement) | AOTI float16 p90, Offline (ms / improvement) |
|------|------------------------|--------------------------------------------|----------------------------------------------|------------------------|--------------------------------------------|----------------------------------------------|
| AMG  | 741                    | 112 (6.6x)                                 | 136 (5.4x)                                   | 1140                   | 176 (6.5x)                                 | 208 (5.5x)                                   |
| SPS  | 98                     | 20 (4.9x)                                  | 19 (5.2x)                                    | 130                    | 28 (4.6x)                                  | 25 (5.2x)                                    |
| MPS  | 269                    | 38 (7.1x)                                  | 31 (8.7x)                                    | 714                    | 52 (13.7x)                                 | 41 (17.4x)                                   |

We notice that execution latencies on Modal and Offline are fairly close, especially relative to the baseline, indicating that optimizing the deployment offline was a reasonable proxy for optimizing the deployment directly.

In addition to execution latency, our batch workload has queueing time, since there are fewer replicas than there are inputs, and so some inputs have to wait in line.

| Task | eager float32 p50 (ms) | AOTI float16 p50 (ms / improvement) | eager float32 p90 (ms) | AOTI float16 p90 (ms / improvement) |
|------|------------------------|-------------------------------------|------------------------|-------------------------------------|
| AMG  | 201                    | 41 (4.9x)                           | 815                    | 327 (2.6x)                          |
| SPS  | 31                     | 33 (0.9x)                           | 441                    | 49 (9.0x)                           |
| MPS  | 40                     | 37 (1.1x)                           | 942                    | 75 (12.6x)                          |

Even though the queueing system provided by the infrastructure is unchanged, the queue latencies also decrease when we use our optimized model – in the p90 case by a factor of 2 to 12. That’s because when we finish previous inputs faster (from reduced execution latency) we can pull our next inputs sooner (reducing their queueing time).

If you’re interested in optimizing SAM2 inference or deployments further, don’t hesitate to reach out to us at the torchao repository!

Conclusions

We rewrote Meta’s original SAM2 in pure PyTorch with little loss of accuracy and a strong focus on latency. We deployed our optimized inference onto Modal, a serverless infrastructure provider, to demonstrate that the benefits of these optimizations can be realized in a more realistic deployment setting.

By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, reduced precision, batched prompts, and GPU preprocessing, we observe up to a 13x improvement in p90 execution latency and queue times compared to regular eager mode PyTorch.

With elastic or autoscaling infrastructure, where replicas of our inference service need to be regularly and automatically created and destroyed, a naïve deployment of torch.compile can add work to inference execution that dwarfs any improvements from a faster model. By utilizing AOTInductor’s (AOTI) ahead-of-time compilation via torch.export, we are able to upload exported models ahead of time and read this data over a network, which enables us to get the benefits of compilation without significantly increased work.

For more details on how to reproduce the data in this blog post, check out the experiments folder of torchao. Please don’t hesitate to contact us or open an issue if you run into any technical issues.