Source code for torch_tensorrt.fx.trt_module
from typing import Any, List, Optional, Sequence
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from .utils import Frameworks, unified_dtype_converter
class TRTModule(torch.nn.Module):
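    """
    TRTModule wraps a TensorRT engine in a torch.nn.Module so the engine can
    be called like any other PyTorch module, checkpointed via state_dict, and
    pickled. The names passed as input_names/output_names must match the
    binding names the engine was built with.
    """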
def __init__(
self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1
):
super(TRTModule, self).__init__()
self._register_state_dict_hook(TRTModule._on_state_dict)
self.engine = engine
self.input_names = input_names
self.output_names = output_names
self.cuda_graph_batch_size = cuda_graph_batch_size
self.initialized = False
if engine:
self._initialize()
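    # A minimal construction sketch (file path and binding names are
    # hypothetical; assumes a serialized engine built with bindings named
    # "input" and "output"):
    #
    #   logger = trt.Logger()
    #   with open("model.engine", "rb") as f:
    #       engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
    #   mod = TRTModule(engine, input_names=["input"], output_names=["output"])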
def _initialize(self):
self.initialized = True
self.context = self.engine.create_execution_context()
        # Indices of inputs/outputs in the TRT engine bindings, in the same
        # order as they appear in the original PyTorch model.
self.input_binding_indices_in_order: Sequence[int] = [
self.engine.get_binding_index(name) for name in self.input_names
]
self.output_binding_indices_in_order: Sequence[int] = [
self.engine.get_binding_index(name) for name in self.output_names
]
primary_input_outputs = set()
primary_input_outputs.update(self.input_binding_indices_in_order)
primary_input_outputs.update(self.output_binding_indices_in_order)
self.hidden_output_binding_indices_in_order: Sequence[int] = []
self.hidden_output_names: Sequence[str] = []
for i in range(
self.engine.num_bindings // self.engine.num_optimization_profiles
):
if i not in primary_input_outputs:
self.hidden_output_binding_indices_in_order.append(i)
self.hidden_output_names.append(self.engine.get_binding_name(i))
assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == (
len(self.input_names)
+ len(self.output_names)
+ len(self.hidden_output_names)
)
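        # With N optimization profiles, TensorRT reports num_bindings as N
        # copies of the per-profile binding set (bindings for profile k > 0
        # carry a " [profile k]" name suffix), so dividing by
        # num_optimization_profiles recovers the per-profile binding count
        # checked above.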
self.input_dtypes: Sequence[torch.dtype] = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.input_binding_indices_in_order
]
self.input_shapes: Sequence[Sequence[int]] = [
tuple(self.engine.get_binding_shape(idx))
for idx in self.input_binding_indices_in_order
]
self.output_dtypes: Sequence[torch.dtype] = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.output_binding_indices_in_order
]
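        # For explicit-batch engines the output shapes can depend on the input
        # shapes chosen at runtime, so they are left empty here and queried
        # from the execution context in forward(); only implicit-batch engines
        # have fixed per-sample shapes at this point.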
self.output_shapes = [
(
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
)
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes: Sequence[torch.dtype] = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
(
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
)
for idx in self.hidden_output_binding_indices_in_order
]
def _check_initialized(self):
if not self.initialized:
raise RuntimeError("TRTModule is not initialized.")
def _on_state_dict(self, state_dict, prefix, local_metadata):
self._check_initialized()
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
state_dict[prefix + "input_names"] = self.input_names
state_dict[prefix + "output_names"] = self.output_names
state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size
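    # Checkpointing sketch: the hook above stores the engine as serialized
    # bytes, so a TRTModule checkpoints like any nn.Module (file name is
    # hypothetical):
    #
    #   torch.save(mod.state_dict(), "trt_module.pt")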
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
engine_bytes = state_dict[prefix + "engine"]
logger = trt.Logger()
runtime = trt.Runtime(logger)
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
self.input_names = state_dict[prefix + "input_names"]
self.output_names = state_dict[prefix + "output_names"]
self._initialize()
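    # Restore sketch: an empty TRTModule can be rebuilt from such a
    # checkpoint; engine deserialization and _initialize() happen inside
    # load_state_dict (file name is hypothetical):
    #
    #   mod = TRTModule()
    #   mod.load_state_dict(torch.load("trt_module.pt"))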
def __getstate__(self):
state = self.__dict__.copy()
state["engine"] = bytearray(self.engine.serialize())
state.pop("context", None)
return state
def __setstate__(self, state):
logger = trt.Logger()
runtime = trt.Runtime(logger)
state["engine"] = runtime.deserialize_cuda_engine(state["engine"])
self.__dict__.update(state)
if self.engine:
self.context = self.engine.create_execution_context()
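    # __getstate__/__setstate__ also make the module picklable and
    # deep-copyable; the execution context is dropped on save and rebuilt on
    # restore:
    #
    #   import copy
    #   mod_copy = copy.deepcopy(mod)  # gets its own execution context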
def forward(self, *inputs):
with torch.autograd.profiler.record_function("TRTModule:Forward"):
self._check_initialized()
with torch.autograd.profiler.record_function("TRTModule:ProcessInputs"):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expected {len(self.input_names)}, got {len(inputs)}."
                # Only used when the TRT engine was built with an implicit
                # batch dimension.
                batch_size = inputs[0].shape[0]
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
bindings: List[Any] = [None] * (
len(self.input_names)
+ len(self.output_names)
+ len(self.hidden_output_names)
)
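                # bindings is a flat list of device pointers indexed by TRT
                # binding index; it is filled in below and handed to
                # execute_async(_v2) in the runtime section.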
for i, input_name in enumerate(self.input_names):
                    assert inputs[
                        i
                    ].is_cuda, f"{i}th input({input_name}) is not on a CUDA device."
assert (
inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
idx = self.input_binding_indices_in_order[i]
bindings[idx] = contiguous_inputs[i].data_ptr()
if not self.engine.has_implicit_batch_dimension:
self.context.set_binding_shape(
idx, tuple(contiguous_inputs[i].shape)
)
else:
assert inputs[i].size()[1:] == self.input_shapes[i], (
f"Shape mismatch for {i}th input({input_name}). "
f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
)
with torch.autograd.profiler.record_function("TRTModule:ProcessOutputs"):
# create output tensors
outputs: List[torch.Tensor] = []
for i, idx in enumerate(self.output_binding_indices_in_order):
if self.engine.has_implicit_batch_dimension:
shape = (batch_size,) + self.output_shapes[i]
else:
shape = tuple(self.context.get_binding_shape(idx))
output = torch.empty( # type: ignore[call-overload]
size=shape,
dtype=self.output_dtypes[i],
device=torch.cuda.current_device(),
)
outputs.append(output)
bindings[idx] = output.data_ptr()
                # Keep references to the hidden-output buffers so they are not
                # garbage-collected (and their memory reused by the caching
                # allocator) before the engine executes.
                hidden_outputs: List[torch.Tensor] = []
                for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
                    if self.engine.has_implicit_batch_dimension:
                        shape = (batch_size,) + self.hidden_output_shapes[i]
                    else:
                        shape = tuple(self.context.get_binding_shape(idx))
                    output = torch.empty(  # type: ignore[call-overload]
                        size=shape,
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    hidden_outputs.append(output)
                    bindings[idx] = output.data_ptr()
with torch.autograd.profiler.record_function("TRTModule:TensorRTRuntime"):
if self.engine.has_implicit_batch_dimension:
self.context.execute_async(
batch_size, bindings, torch.cuda.current_stream().cuda_stream
)
else:
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)
if len(outputs) == 1:
return outputs[0]
return tuple(outputs)
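    # Invocation sketch (shapes are hypothetical; inputs must already be CUDA
    # tensors with the dtypes the engine was built for):
    #
    #   x = torch.randn(8, 3, 224, 224, device="cuda")
    #   y = mod(x)  # a single tensor for one output, a tuple otherwise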
    def enable_profiling(self, profiler: Optional["trt.IProfiler"] = None):
"""
Enable TensorRT profiling. After calling this function, TensorRT will report
time spent on each layer in stdout for each forward run.
"""
self._check_initialized()
if not self.context.profiler:
self.context.profiler = trt.Profiler() if profiler is None else profiler
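    # Profiling sketch: with the default trt.Profiler, per-layer timings are
    # printed after each forward run; pass a custom trt.IProfiler to collect
    # them programmatically instead:
    #
    #   mod.enable_profiling()
    #   mod(x)
    #   mod.disable_profiling()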
def disable_profiling(self):
"""
Disable TensorRT profiling.
"""
self._check_initialized()
torch.cuda.synchronize()
del self.context
self.context = self.engine.create_execution_context()
    def get_layer_info(self) -> str:
        """
        Get layer info of the engine. Only supported on TRT 8.2 and newer.
        """
inspector = self.engine.create_engine_inspector()
return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
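    # Usage sketch: the returned string is JSON, so it can be parsed directly
    # (the exact schema varies across TensorRT versions):
    #
    #   import json
    #   info = json.loads(mod.get_layer_info())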