Serving a Torch-TensorRT model with Triton
Optimization and deployment go hand in hand in a discussion about Machine Learning infrastructure. Once network-level optimizations are done to get the maximum performance, the next step is to deploy the model.
However, serving this optimized model comes with its own set of considerations and challenges, such as building infrastructure to support concurrent model executions, supporting clients over HTTP or gRPC, and more.
The Triton Inference Server addresses these challenges and more. Let’s walk step by step through the process of optimizing a model with Torch-TensorRT, deploying it on Triton Inference Server, and building a client to query the model.
Step 1: Optimize your model with Torch-TensorRT
Most Torch-TensorRT users will be familiar with this step. For the purpose of this demonstration, we will be using a ResNet50 model from Torch Hub.
We will be working in the //examples/triton directory, which contains the scripts used in this tutorial.
First, pull the NGC PyTorch Docker container. You may need to create an account and get an API key from here. Sign up and log in with your key (follow the instructions here after signing up).
# YY.MM is the year.month publishing tag of NVIDIA's PyTorch
# container, e.g. 24.08
# NOTE: Use the same publishing tag for the PyTorch container and the Triton containers
docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:YY.MM-py3
cd /scratch_space
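The same YY.MM tag appears in every docker command in this tutorial. As a small convenience, the hypothetical helper below (container_images is not part of any NVIDIA tooling) derives all three image names used here from one tag, so the export, server, and client containers stay on the same release. It only checks the tag's YY.MM format, not whether that release exists on NGC.

```python
import re

# YY.MM publishing tag, e.g. "24.08" (format check only)
TAG_PATTERN = re.compile(r"^\d{2}\.\d{2}$")


def container_images(tag: str) -> dict:
    """Hypothetical helper: image names for one publishing tag.

    Deriving all images from a single tag keeps the export (PyTorch) and
    serving/client (Triton) containers on the same release.
    """
    if not TAG_PATTERN.match(tag):
        raise ValueError(f"expected a YY.MM publishing tag, got {tag!r}")
    return {
        "export": f"nvcr.io/nvidia/pytorch:{tag}-py3",
        "server": f"nvcr.io/nvidia/tritonserver:{tag}-py3",
        "client": f"nvcr.io/nvidia/tritonserver:{tag}-py3-sdk",
    }


if __name__ == "__main__":
    for role, image in container_images("24.08").items():
        print(f"{role:>6}: {image}")
```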
With the container running, we can export the model into the correct directory in our Triton model repository. This export script uses the Dynamo frontend of Torch-TensorRT to compile the PyTorch model to TensorRT. Then we save the model using TorchScript, a serialization format that Triton supports.
import torch
import torch_tensorrt

# Skip Torch Hub's forked-repository validation check when loading from GitHub
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Load a pretrained ResNet50 in inference mode on the GPU
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True).eval().to("cuda")

# Compile with Torch-TensorRT for a fixed 1x3x224x224 input, allowing TensorRT to use FP16 kernels
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions={torch_tensorrt.dtype.f16},
)

# Trace the compiled module to TorchScript so Triton's PyTorch backend can load it
ts_trt_model = torch.jit.trace(trt_model, torch.rand(1, 3, 224, 224).to("cuda"))

# Save the model as version 1 of "resnet50" in the model repository
torch.jit.save(ts_trt_model, "/triton_example/model_repository/resnet50/1/model.pt")
You can run the script with the following command (from //examples/triton):
docker run --gpus all -it --rm -v ${PWD}:/triton_example nvcr.io/nvidia/pytorch:YY.MM-py3 python /triton_example/export.py
This will save the serialized TorchScript version of the ResNet model in the right directory in the model repository.
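The config.pbtxt in Step 2 refers to the model input as x. One way to look up candidate input names is to inspect the forward() signature with Python's inspect module; torchvision's ResNet declares forward(self, x). The stand-in class below is hypothetical and only mimics that signature so the snippet runs without PyTorch; with the real model you could call forward_input_names(model) on the eager module before compiling it. This assumes, as this tutorial's config does, that Triton's PyTorch backend accepts forward() argument names as input names.

```python
import inspect


class TinyResNetStandIn:
    """Hypothetical stand-in with the same forward() signature as
    torchvision's ResNet, which declares forward(self, x)."""

    def forward(self, x):
        return x


def forward_input_names(module) -> list:
    """Positional argument names of module.forward (self is excluded
    because module.forward is a bound method)."""
    params = inspect.signature(module.forward).parameters.values()
    return [
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]


if __name__ == "__main__":
    print(forward_input_names(TinyResNetStandIn()))  # ['x']
```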
Step 2: Set Up Triton Inference Server
If you are new to the Triton Inference Server and want to learn more, we highly recommend checking out our GitHub repository.
To use Triton, we need to make a model repository. A model repository, as the name suggests, is a repository of the models the Inference server hosts. While Triton can serve models from multiple repositories, in this example, we will discuss the simplest possible form of the model repository.
The structure of this repository should look something like this:
model_repository
|
+-- resnet50
    |
    +-- config.pbtxt
    +-- 1
        |
        +-- model.pt
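If Triton does not load the model, one thing worth checking is whether every file sits where the layout above expects it. The standard-library sketch below creates that layout and reports which required files are still missing. The function names are illustrative, not part of Triton, and the default file name model.pt matches the path used by the export script in Step 1.

```python
import tempfile
from pathlib import Path


def model_file_path(repo: Path, model_name: str, version: int,
                    filename: str = "model.pt") -> Path:
    """<repo>/<model_name>/<version>/<filename>, as in the layout above."""
    return repo / model_name / str(version) / filename


def missing_repository_files(repo: Path, model_name: str, version: int = 1) -> list:
    """Required files that do not exist yet (an empty list means complete)."""
    required = [
        repo / model_name / "config.pbtxt",
        model_file_path(repo, model_name, version),
    ]
    return [p for p in required if not p.is_file()]


def create_skeleton(repo: Path, model_name: str, version: int = 1,
                    config_text: str = "") -> Path:
    """Create the version directory and write config.pbtxt."""
    version_dir = repo / model_name / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    (repo / model_name / "config.pbtxt").write_text(config_text)
    return version_dir


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        repo = Path(tmp) / "model_repository"
        create_skeleton(repo, "resnet50")
        # model.pt stays missing until the export script writes it
        print(missing_repository_files(repo, "resnet50"))
```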
There are two files that Triton requires to serve the model: the model itself and a model configuration file, which is typically provided in config.pbtxt.
For the model we prepared in step 1, the following configuration can be used:
name: "resnet50"
backend: "pytorch"
max_batch_size: 0
input [
  {
    name: "x"
    data_type: TYPE_FP32
    dims: [ 1, 3, 224, 224 ]
  }
]
output [
  {
    name: "output0"
    data_type: TYPE_FP32
    dims: [ 1, 1000 ]
  }
]
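For illustration, the same configuration can be produced from Python data. render_config and render_tensor are hypothetical helpers that emit only the fields used in this example; for anything more involved, writing config.pbtxt by hand (or with protobuf text-format tooling) is the more robust choice.

```python
def render_tensor(t: dict) -> str:
    """One input/output entry in protobuf text format."""
    dims = ", ".join(str(d) for d in t["dims"])
    return (
        "  {\n"
        f'    name: "{t["name"]}"\n'
        f"    data_type: {t['data_type']}\n"
        f"    dims: [ {dims} ]\n"
        "  }"
    )


def render_config(name, backend, max_batch_size, inputs, outputs) -> str:
    """Hypothetical helper: minimal config.pbtxt text for this example."""
    ins = ",\n".join(render_tensor(t) for t in inputs)
    outs = ",\n".join(render_tensor(t) for t in outputs)
    return (
        f'name: "{name}"\n'
        f'backend: "{backend}"\n'
        f"max_batch_size: {max_batch_size}\n"
        f"input [\n{ins}\n]\n"
        f"output [\n{outs}\n]\n"
    )


config = render_config(
    "resnet50", "pytorch", 0,
    inputs=[{"name": "x", "data_type": "TYPE_FP32", "dims": [1, 3, 224, 224]}],
    outputs=[{"name": "output0", "data_type": "TYPE_FP32", "dims": [1, 1000]}],
)

if __name__ == "__main__":
    print(config)
```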
The config.pbtxt file is used to describe the exact model configuration, with details like the names and shapes of the input and output layer(s), datatypes, scheduling and batching details, and more. If you are new to Triton, we highly encourage you to check out this section of our documentation for more details.
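One detail in this configuration is worth spelling out: because max_batch_size is 0, Triton does not add an implicit batch dimension, so dims lists the full tensor shape, leading 1 included. With max_batch_size greater than 0, dims would omit the batch axis and requests would carry shape [batch] + dims. The function below is a simplified sketch of that rule (it ignores features such as reshape) for checking a request shape before sending it.

```python
def request_shape_matches(request_shape, dims, max_batch_size):
    """Simplified check of a request tensor shape against a config entry."""
    shape = list(request_shape)
    if max_batch_size > 0:
        # dims omits the batch axis; the first request axis is the batch size
        if not shape or not 1 <= shape[0] <= max_batch_size:
            return False
        shape = shape[1:]
    # Remaining axes must match dims exactly, except -1 (variable-size) axes
    if len(shape) != len(dims):
        return False
    return all(d == -1 or d == s for s, d in zip(shape, dims))


if __name__ == "__main__":
    # This tutorial's config: max_batch_size 0, full shape in dims
    print(request_shape_matches((1, 3, 224, 224), [1, 3, 224, 224], 0))
    # A batched alternative: dims without the batch axis
    print(request_shape_matches((8, 3, 224, 224), [3, 224, 224], 8))
```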
With the model repository set up, we can proceed to launch the Triton server with the docker command below. Refer to this page for the pull tag for the container.
# Make sure that the TensorRT version in the Triton container
# and the TensorRT version in the environment used to optimize the model
# are the same. In general, containers with the same publishing tag ship the same TensorRT version
docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:YY.MM-py3 tritonserver --model-repository=/triton_example/model_repository
This should start the Triton Inference Server. The next step is building a simple HTTP client to query the server.
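Loading the model can take a moment after the container starts. Triton's HTTP endpoint implements the KServe v2 health API, so GET /v2/health/ready on port 8000 returns HTTP 200 once the server is ready. A minimal standard-library sketch that waits for that is shown here; wait_until_ready is a hypothetical helper, and tritonclient's InferenceServerClient.is_server_ready() offers the same check from the client library.

```python
import time
import urllib.request


def wait_until_ready(base_url="http://localhost:8000", timeout_s=60.0, interval_s=1.0):
    """Poll Triton's readiness endpoint until it answers 200 or time runs out."""
    url = f"{base_url}/v2/health/ready"
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            # URLError/HTTPError, refused connections and timeouts are all
            # OSError subclasses: the server is not (yet) ready
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


if __name__ == "__main__":
    print("ready" if wait_until_ready() else "server did not become ready")
```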
Step 3: Building a Triton Client to Query the Server
Before proceeding, make sure to have a sample image on hand. If you don’t have one, download an example image to test inference. In this section, we will go over a very basic client. For a variety of more fleshed-out examples, refer to the Triton Client Repository.
wget -O img1.jpg "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"
We then need to install the dependencies for building a Python client. These will vary from client to client. For a full list of all languages supported by Triton, please refer to Triton’s client repository.
pip install torchvision
pip install attrdict
pip install nvidia-pyindex
pip install "tritonclient[all]"
Let’s jump into the client. First, we write a small preprocessing function to resize and normalize the query image.
import numpy as np
from torchvision import transforms
from PIL import Image
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype

# preprocessing function
def rn50_preprocess(img_path="/triton_example/img1.jpg"):
    # Convert to RGB so grayscale or RGBA images also yield 3 channels
    img = Image.open(img_path).convert("RGB")
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224), as in config.pbtxt
    return preprocess(img).unsqueeze(0).numpy()

transformed_img = rn50_preprocess()
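To make the preprocessing less of a black box, the sketch below restates its arithmetic in plain Python: Resize(256) scales the shorter edge to 256 pixels while keeping the aspect ratio, CenterCrop(224) cuts out the middle 224 x 224 window, ToTensor maps pixel values from 0..255 to 0..1, and Normalize computes (x - mean) / std per channel. It is an approximation for illustration (torchvision's exact rounding may differ by a pixel), not a replacement for the transforms above.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resized_size(width, height, shorter=256):
    """Approximate Resize(256): shorter edge -> 256, aspect ratio kept."""
    if width <= height:
        return shorter, int(shorter * height / width)
    return int(shorter * width / height), shorter


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """ToTensor scaling (/255) followed by per-channel (x - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


if __name__ == "__main__":
    w, h = resized_size(640, 480)
    print((w, h), center_crop_box(w, h))
    print(normalize_pixel((255, 255, 255)))
```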
Building a client requires three basic steps. First, we set up a connection with the Triton Inference Server.
# Setting up client
client = httpclient.InferenceServerClient(url="localhost:8000")
Second, we specify the names of the input and output layer(s) of our model. These can be obtained during export and should already be specified in your config.pbtxt.
inputs = httpclient.InferInput("x", transformed_img.shape, datatype="FP32")
inputs.set_data_from_numpy(transformed_img, binary_data=True)
outputs = httpclient.InferRequestedOutput("output0", binary_data=True, class_count=1000)
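Under the hood, the HTTP client talks to Triton's KServe v2 inference endpoint, POST /v2/models/<model_name>/infer. The sketch below builds a JSON request body roughly equivalent to the request above if binary_data were False, with the tensor sent as a flat, row-major list. build_infer_request is a hypothetical helper, and the classification parameter reflects our understanding of Triton's classification extension, which is what class_count requests.

```python
import json
import math


def build_infer_request(input_name, shape, data, output_name, class_count):
    """Hypothetical helper: JSON body for POST /v2/models/<model>/infer."""
    expected = math.prod(shape)
    if len(data) != expected:
        raise ValueError(f"shape {shape} needs {expected} values, got {len(data)}")
    return json.dumps({
        "inputs": [{
            "name": input_name,
            "shape": list(shape),
            "datatype": "FP32",
            "data": list(data),  # flattened in row-major order
        }],
        "outputs": [{
            "name": output_name,
            # Ask for the top class_count classes as "<score>:<index>" strings
            "parameters": {"classification": class_count},
        }],
    })


if __name__ == "__main__":
    body = build_infer_request("x", [1, 3, 224, 224], [0.0] * (3 * 224 * 224), "output0", 5)
    print(len(body), "bytes of JSON")
```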
Lastly, we send an inference request to the Triton Inference Server.
# Querying the server
results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs])
inference_output = results.as_numpy('output0')
print(inference_output[:5])
The output should look similar to the following:
[b'12.468750:90' b'11.523438:92' b'9.664062:14' b'8.429688:136'
b'8.234375:11']
The output format here is <confidence_score>:<classification_index>.
To learn how to map these to label names and more, refer to Triton Inference Server’s documentation.
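As a sketch of that post-processing, the code below parses entries like b'12.468750:90' into a score and class index and attaches a name from a labels list. labels is hypothetical here: for example, ImageNet class names loaded one per line from a file, which this tutorial does not provide. As we understand the classification extension, an entry may also carry a third :<label> field when the model configuration supplies a label file. Because torchvision's ResNet50 ends in a plain linear layer, the scores are raw logits; a softmax over all 1000 returned entries (class_count=1000) turns them into probabilities, whereas a softmax over fewer entries only renormalizes among those.

```python
import math


def parse_classification(entry):
    """Split b"<score>:<index>[:<label>]" into (score, index, label or None)."""
    text = entry.decode() if isinstance(entry, bytes) else entry
    parts = text.split(":", 2)
    label = parts[2] if len(parts) == 3 else None
    return float(parts[0]), int(parts[1]), label


def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_with_labels(entries, labels, k=5):
    """(name, probability) for the k highest-scoring entries.

    `labels` is a hypothetical list of class names indexed by class index.
    """
    parsed = [parse_classification(e) for e in entries]
    probs = softmax([score for score, _, _ in parsed])
    ranked = sorted(zip(parsed, probs), key=lambda t: t[0][0], reverse=True)[:k]
    results = []
    for (score, index, label), prob in ranked:
        name = label or (labels[index] if index < len(labels) else str(index))
        results.append((name, prob))
    return results


if __name__ == "__main__":
    sample = [b"12.468750:90", b"11.523438:92", b"9.664062:14"]
    labels = [f"class_{i}" for i in range(1000)]  # placeholder names
    print(top_k_with_labels(sample, labels, k=2))
```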
You can quickly try out this client using:
# Remember to use the same publishing tag for all steps (e.g. 24.08)
docker run -it --net=host -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:YY.MM-py3-sdk bash -c "pip install torchvision && python /triton_example/client.py"