Source code for torch_tensorrt.fx.fx2trt

import logging
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

import numpy

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx
from torch._ops import OpOverload
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata

from .converter_registry import CONVERTERS
from .input_tensor_spec import InputTensorSpec
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt

_LOGGER: logging.Logger = logging.getLogger(__name__)

    Callable[[torch.fx.GraphModule], None]

[docs]class TRTInterpreterResult(NamedTuple): engine: Any input_names: Sequence[str] output_names: Sequence[str] serialized_cache: bytearray
[docs]class TRTInterpreter(torch.fx.Interpreter): def __init__( self, module: torch.fx.GraphModule, input_specs: List[InputTensorSpec], explicit_batch_dimension: bool = False, explicit_precision: bool = False, logger_level=None, ): super().__init__(module) self.logger = trt.Logger(logger_level or trt.Logger.WARNING) self.builder = trt.Builder(self.logger) flag = 0 if explicit_batch_dimension: EXPLICIT_BATCH = 1 << (int)( trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH ) flag |= EXPLICIT_BATCH if explicit_precision: EXPLICIT_PRECISION = 1 << (int)( trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION ) flag |= EXPLICIT_PRECISION = self.builder.create_network(flag) missing_ops = self.validate_conversion() if missing_ops: warnings.warn( "Interpretation will fail due to missing operations \n" + "\n".join(f"{i}" for i in missing_ops) ) self.optimization_profiles: Optional[List] = None self.input_specs = input_specs self.input_specs_iter = 0 self.validate_input_specs() self._cur_node_name: Optional[str] = None self._input_names: List[str] = [] self._output_names: List[str] = [] self._itensor_to_tensor_meta: Dict[ trt.tensorrt.ITensor, TensorMetadata ] = dict() def validate_input_specs(self): for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: if not assert ( has_batch_dim ), "It's required to specify batch dimension when it's explicit in TensorRT network." dynamic_dims = get_dynamic_dims(shape) if len(dynamic_dims): assert not, ( "Can't have dynamic dim when " f"batch dim is implicit, got {shape}." ) assert len( shape_ranges ), "shape_ranges must be provided when shape has dynamic dim." if self.optimization_profiles: assert len(shape_ranges) == len(self.optimization_profiles), ( "Number of optimization " f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" f" {len(shape_ranges)} provided." ) else: self.optimization_profiles = [ self.builder.create_optimization_profile() for _ in range(len(shape_ranges)) ] for shape_range in shape_ranges: assert ( len(shape_range) == 3 ), f"Expect three elements in shape_range, got {len(shape_range)}" assert all(len(s) == len(shape) for s in shape_range), ( "Expect elements in shape_range" f" {shape_range} have the same number of dimension as the provided shape {len(shape)}" ) for i in range(len(shape)): if i in dynamic_dims: assert all( shape_range[j][i] <= shape_range[j + 1][i] for j in range(2) ), ( "Expect dynamic dim" f" {i} to have incremental value for shapes in shape_range {shape_range}." ) else: assert all(s[i] == shape[i] for s in shape_range), ( f"Expect non dynamic dim {i} to be the same" f" for all shapes in shape_range {shape_range}." ) else: assert ( len(shape_ranges) == 0 ), "shape_ranges are provided for input that doesn't have dynamic dim." def validate_conversion(self): missing_converter = set() for node in self.module.graph.nodes: if node.op == "call_function" and not CONVERTERS.get( missing_converter.add(f"{node.op} {_get_qualified_name(}") elif node.op == "call_method" and not CONVERTERS.get( missing_converter.add(f"{node.op} torch.Tensor.{}") elif node.op == "call_module": submod = self.fetch_attr( submod_type = getattr(submod, "_base_class_origin", type(submod)) if not CONVERTERS.get(submod_type): missing_converter.add(f"{node.op} {torch.typename(submod_type)}") return missing_converter def run( self, max_batch_size=64, max_workspace_size=1 << 25, lower_precision=LowerPrecision.FP16, sparse_weights=False, force_fp32_output=False, strict_type_constraints=False, algorithm_selector=None, timing_cache=None, profiling_verbosity=None, tactic_sources=None, ) -> TRTInterpreterResult: """ Build TensorRT engine with some configs. Args: max_batch_size: set accordingly for maximum batch size you will use. max_workspace_size: set to the maximum size we can afford for temporary buffer lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity force_fp32_output: force output to be fp32 strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. algorithm_selector: set up algorithm selection for certain layer timing_cache: enable timing cache for TensorRT profiling_verbosity: TensorRT logging level Return: TRTInterpreterResult """ TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and # force_fp32_output=False. self.output_fp16 = ( not force_fp32_output and lower_precision == LowerPrecision.FP16 ) if ( lower_precision == LowerPrecision.INT8 and not self.builder.platform_has_fast_int8 ): raise RuntimeError("Current platform doesn't support fast native int8!") if ( lower_precision == LowerPrecision.FP16 and not self.builder.platform_has_fast_fp16 ): warnings.warn("Current platform doesn't support fast native fp16!") self.input_specs_iter = 0 run_module_start_time = super().run() f"TRT INetwork construction elapsed time: { - run_module_start_time}" ) build_engine_start_time = self.builder.max_batch_size = max_batch_size builder_config = self.builder.create_builder_config() builder_config.max_workspace_size = max_workspace_size cache = None if timing_cache: cache_file = numpy.array(timing_cache) cache = builder_config.create_timing_cache(cache_file.tobytes()) else: cache = builder_config.create_timing_cache(b"") builder_config.set_timing_cache(cache, False) if trt.__version__ >= "8.2": builder_config.profiling_verbosity = ( profiling_verbosity if profiling_verbosity else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) if lower_precision == LowerPrecision.FP16: builder_config.set_flag(trt.BuilderFlag.FP16) if lower_precision == LowerPrecision.INT8: builder_config.set_flag(trt.BuilderFlag.INT8) if sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) if self.optimization_profiles: for optimization_profile in self.optimization_profiles: builder_config.add_optimization_profile(optimization_profile) if algorithm_selector: builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) builder_config.algorithm_selector = algorithm_selector if tactic_sources is not None: builder_config.set_tactic_sources(tactic_sources=tactic_sources) engine = self.builder.build_engine(, builder_config) assert engine serialized_cache = ( bytearray(cache.serialize()) if builder_config.get_timing_cache() else bytearray() ) f"Build TRT engine elapsed time: { - build_engine_start_time}" ) return TRTInterpreterResult( engine, self._input_names, self._output_names, serialized_cache ) def run_node(self, n): self._cur_node_name = str(n) # add "_itensor_to_tensor_meta" kwargs = dict(n.kwargs) kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta n.kwargs = kwargs # run the node trt_node = super().run_node(n) # remove "_itensor_to_tensor_meta" kwargs = dict(n.kwargs) del kwargs["_itensor_to_tensor_meta"] n.kwargs = kwargs if isinstance(trt_node, trt.tensorrt.ITensor): self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") return trt_node def placeholder(self, target, args, kwargs): self._input_names.append(target) shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[ self.input_specs_iter ] self.input_specs_iter += 1 if if has_batch_dim: shape = shape[1:] else: for i, shape_range in enumerate(shape_ranges): assert self.optimization_profiles self.optimization_profiles[i].set_shape(target, *shape_range) return name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype) ) def call_module(self, target, args, kwargs): assert isinstance(target, str) submod = self.fetch_attr(target) submod_type = getattr(submod, "_base_class_origin", type(submod)) converter = CONVERTERS.get(submod_type) if not converter: raise RuntimeError( f"Conversion of module of type {submod_type} not currently supported!" ) assert self._cur_node_name is not None return converter(, submod, args, kwargs, self._cur_node_name) def call_function(self, target, args, kwargs): converter = CONVERTERS.get(target) if not converter: raise RuntimeError( f"Conversion of function {torch.typename(target)} not currently supported!" ) assert self._cur_node_name is not None return converter(, target, args, kwargs, self._cur_node_name) def call_method(self, target, args, kwargs): assert isinstance(target, str) converter = CONVERTERS.get(target) if not converter: raise RuntimeError( f"Conversion of method {target} not currently supported!" ) assert self._cur_node_name is not None return converter(, target, args, kwargs, self._cur_node_name) def output(self, target, args, kwargs): assert len(args) == 1 outputs = args[0] if isinstance(args[0], tuple) else (args[0],) if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): raise RuntimeError("TensorRT requires all outputs to be Tensor!") for i, output in enumerate(outputs): if any( op_name in"_") for op_name in ( "eq", "gt", "lt", "or", "xor", "and", "not", "ne", "isinf", "any", ) ): output_bool = True else: output_bool = False name = f"output{i}" = name if output_bool: output.dtype = trt.bool elif self.output_fp16 and output.dtype == trt.float32: output.dtype = trt.float16 self._output_names.append(name)


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources