Compiling SAM2 using the dynamo backend

This example demonstrates the state-of-the-art Segment Anything Model 2 (SAM2) optimized using Torch-TensorRT.

Segment Anything Model 2 is a foundation model for promptable visual segmentation in images and videos. Install the following dependencies before compilation:

pip install -r requirements.txt

Certain custom modifications are required to ensure the model is exported successfully. To apply these changes, please install SAM2 using the following fork (Installation instructions).

In the custom SAM2 fork, the following modifications have been applied to remove graph breaks and improve latency, ensuring a more efficient Torch-TRT conversion (an illustrative sketch follows the list):

  • Consistent Data Types: Preserves input tensor dtypes, removing forced FP32 conversions.

  • Masked Operations: Uses mask-based indexing instead of directly selecting data, improving Torch-TRT compatibility.

  • Safe Initialization: Initializes tensors conditionally rather than concatenating to empty tensors.

  • Standard Functions: Avoids special contexts and custom LayerNorm, relying on built-in PyTorch functions for better stability.
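
The sketch below illustrates, with hypothetical tensors, the kind of change the last two points describe; the actual fork edits live inside the SAM2 source and may differ in detail:

import torch

x = torch.randn(8, 256)
keep = torch.tensor([True, False, True, True, False, True, False, True])

# Masked operations: zero out unwanted rows with a mask instead of gathering a
# data-dependent subset, which keeps tensor shapes static for Torch-TRT.
masked = x * keep.unsqueeze(-1).to(x.dtype)

# Safe initialization: build the output conditionally rather than concatenating
# onto an empty tensor.
chunks = [torch.randn(4, 256) for _ in range(3)]
out = torch.cat(chunks, dim=0) if chunks else torch.zeros(0, 256)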

Import the following libraries

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam_components import SAM2FullModel

matplotlib.use("Agg")

Define the SAM2 model

Load the facebook/sam2-hiera-large pretrained model using the SAM2ImagePredictor class. SAM2ImagePredictor provides utilities to preprocess images, store image features (via the set_image function), and predict masks (via the predict function).

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
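
For reference, the standard two-step workflow with SAM2ImagePredictor looks roughly like the following (a sketch of the upstream API rather than part of the compiled path; the point prompt matches the one used later in this demo):

import numpy as np
from PIL import Image

image = np.array(Image.open("./truck.jpg").convert("RGB"))
predictor.set_image(image)                   # step 1: compute and cache image features
masks, scores, logits = predictor.predict(   # step 2: predict masks for a point prompt
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)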

To ensure we export the entire model (image encoder and mask predictor) successfully, we create a standalone module, SAM2FullModel, which uses these utilities from the SAM2ImagePredictor class. SAM2FullModel performs feature extraction and mask prediction in a single step, instead of the two-step process of SAM2ImagePredictor (the set_image and predict functions).

class SAM2FullModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.image_encoder = model.forward_image
        self._prepare_backbone_features = model._prepare_backbone_features
        self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
        self.no_mem_embed = model.no_mem_embed
        self._features = None

        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

    def forward(self, image, point_coords, point_labels):
        # Run the image encoder and prepare the multi-scale backbone features
        backbone_out = self.image_encoder(image)
        _, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)

        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        # Reshape flattened vision features back into (1, C, H, W) feature maps
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
        ]

        # Encode the point prompt into sparse and dense embeddings
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels), boxes=None, masks=None
        )

        # Decode low-resolution masks and their IoU (confidence) predictions
        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=features["image_embed"][-1].unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=point_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
        return out

Initialize the SAM2 model with pretrained weights

Initialize the SAM2FullModel with the pretrained weights. Since we already initialized SAM2ImagePredictor, we can directly use the model from it (predictor.model). We cast the model to FP16 precision for faster performance.

encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()

Load an input image

Load the sample image (truck.jpg) provided in the repository. Here's the input image we are going to use:

[Figure: truck.jpg]

input_image = Image.open("./truck.jpg").convert("RGB")

In addition to the input image, we also provide prompts as inputs, which are used to predict the masks. A prompt can be a box, a point, or masks from a previous iteration of prediction. We use a single point prompt in this demo, similar to the original notebook in the SAM2 repository.

Preprocessing components

The following function implements the preprocessing components which transform the input image and the given point coordinates. We use the SAM2Transforms available via the SAM2ImagePredictor class. To read more about these transforms, refer to https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py

def preprocess_inputs(image, predictor):
    w, h = image.size
    orig_hw = [(h, w)]
    # Apply SAM2Transforms (resize/normalize) and add a batch dimension
    input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")

    # A single foreground point prompt at pixel coordinates (500, 375)
    point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
    point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")

    point_coords = torch.as_tensor(
        point_coords, dtype=torch.float, device=predictor.device
    )
    # Rescale the point coordinates to the model's input resolution
    unnorm_coords = predictor._transforms.transform_coords(
        point_coords, normalize=True, orig_hw=orig_hw[0]
    )
    labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
    if len(unnorm_coords.shape) == 2:
        unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]

    # Match the FP16 precision of the model
    input_image = input_image.half()
    unnorm_coords = unnorm_coords.half()

    return (input_image, unnorm_coords, labels)

Postprocessing components

The following functions implement postprocessing components which include plotting and visualizing masks and points. We use SAM2Transforms to postprocess these masks and sort them by confidence score.

def postprocess_masks(out, predictor, image):
    """Postprocess low-resolution masks and convert them for visualization."""
    orig_hw = (image.size[1], image.size[0])  # (height, width)
    masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
    masks = (masks > 0.0).squeeze(0).cpu().numpy()
    scores = out["iou_predictions"].squeeze(0).cpu().numpy()
    sorted_indices = np.argsort(scores)[::-1]
    return masks[sorted_indices], scores[sorted_indices]


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def visualize_masks(
    image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
    """Visualize and save masks overlaid on the original image."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(point_coords, point_labels, plt.gca())
        plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
        plt.close()

Preprocess the inputs

Preprocess the inputs. In the following snippet, torchtrt_inputs contains (input_image, unnormalized_coordinates, labels). The unnormalized coordinates represent the point prompt, and the label (= 1 in this demo) marks it as a foreground point.

torchtrt_inputs = preprocess_inputs(input_image, predictor)
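
As an optional sanity check, you can inspect the shapes and dtypes of the preprocessed inputs; the values in the comments are expectations based on SAM2's 1024x1024 input resolution and the single point prompt used here, and may differ slightly in your setup:

image_t, coords_t, labels_t = torchtrt_inputs
print(image_t.shape, image_t.dtype)    # e.g. torch.Size([1, 3, 1024, 1024]) torch.float16
print(coords_t.shape, coords_t.dtype)  # e.g. torch.Size([1, 1, 2]) torch.float16
print(labels_t.shape, labels_t.dtype)  # e.g. torch.Size([1, 1]) torch.int32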

Torch-TensorRT compilation

Export the model in non-strict mode and perform Torch-TensorRT compilation in FP16 precision. We enable FP32 matmul accumulation using use_fp32_acc=True to preserve accuracy relative to the original PyTorch model.

exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=torchtrt_inputs,
    min_block_size=1,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)
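
As a rough accuracy check (an optional addition, not part of the measured workflow), the TensorRT outputs can be compared against the eager FP16 PyTorch module on the same inputs:

with torch.no_grad():
    pyt_out = sam_model(*torchtrt_inputs)

# With FP16 execution and FP32 accumulation, the compiled outputs should stay
# close to the eager ones; large deviations indicate a conversion problem.
print(torch.max(torch.abs(pyt_out["low_res_masks"] - trt_out["low_res_masks"])))
print(torch.max(torch.abs(pyt_out["iou_predictions"] - trt_out["iou_predictions"])))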

Output visualization

Postprocess the Torch-TensorRT outputs and visualize the masks using the postprocessing components defined above. The output images are saved in your current directory.

trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
    input_image,
    trt_masks,
    trt_scores,
    torch.tensor([[500, 375]]),
    torch.tensor([1]),
    title_prefix="Torch-TRT",
)
The predicted masks are shown below.

[Figures: sam_mask1.png, sam_mask2.png, sam_mask3.png]
