Shortcuts

Source code for torch_tensorrt.fx.lower

import dataclasses as dc
import logging
from typing import Any, Callable, Optional, Sequence

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx as fx
import torch.nn as nn
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.passes.splitter_base import SplitResult

from .fx2trt import TRTInterpreter, TRTInterpreterResult
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import PassFunc, validate_inference
from .tools.timing_cache_utils import TimingCacheManager
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting
from .tracer.acc_tracer import acc_tracer
from .trt_module import TRTModule
from .utils import LowerPrecision

# Logger for this module; its name follows the module's import path.
logger = logging.getLogger(name=__name__)

# Alias used to annotate the input sequences passed through the lowering APIs.
Input = Sequence[Any]


def compile(
    module: nn.Module,
    input,
    min_acc_module_size: int = 10,
    max_batch_size: int = 2048,
    max_workspace_size=1 << 25,
    explicit_batch_dimension=False,
    lower_precision=LowerPrecision.FP16,
    verbose_log=False,
    timing_cache_prefix="",
    save_timing_cache=False,
    cuda_graph_batch_size=-1,
    dynamic_batch=True,
    is_aten=False,
    use_experimental_fx_rt=False,
    correctness_atol=1e-1,
    correctness_rtol=1e-1,
) -> nn.Module:
    """
    Takes in original module, input and lowering setting, run lowering workflow to turn module
    into lowered module, or so called TRTModule.

    Args:
        module: Original module for lowering.
        input: Input for module.
        min_acc_module_size: Minimal number of nodes for an accelerated submodule.
        max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set).
        max_workspace_size: Maximum size of workspace given to TensorRT.
        explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True,
            otherwise use implicit batch dimension.
        lower_precision: lower_precision config given to TRTModule.
        verbose_log: Enable verbose log for TensorRT if set True.
        timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
        save_timing_cache: Update timing cache with current timing cache data if set to True.
        cuda_graph_batch_size: Cuda graph batch size, default to be -1.
        dynamic_batch: batch dimension (dim=0) is dynamic.
        is_aten: Trace the module with the aten tracer instead of the acc tracer.
        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python
            and TorchScript based execution (including in C++). Requires
            ``explicit_batch_dimension=True``.
        correctness_atol: Absolute tolerance forwarded to ``validate_inference`` when the
            lowered module is checked.
        correctness_rtol: Relative tolerance forwarded to ``validate_inference`` when the
            lowered module is checked.

    Returns:
        A torch.nn.Module lowered by TensorRT.

    Raises:
        ValueError: if ``use_experimental_fx_rt`` is set without ``explicit_batch_dimension``.
    """
    # Validate before building any settings so the caller gets a clear, early error.
    if use_experimental_fx_rt and not explicit_batch_dimension:
        raise ValueError(
            "The experimental unified runtime only supports explicit batch. "
            "Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
        )

    lower_setting = LowerSetting(
        max_batch_size=max_batch_size,
        min_acc_module_size=min_acc_module_size,
        max_workspace_size=max_workspace_size,
        explicit_batch_dimension=explicit_batch_dimension,
        lower_precision=lower_precision,
        verbose_log=verbose_log,
        timing_cache_prefix=timing_cache_prefix,
        save_timing_cache=save_timing_cache,
        cuda_graph_batch_size=cuda_graph_batch_size,
        dynamic_batch=dynamic_batch,
        is_aten=is_aten,
        use_experimental_rt=use_experimental_fx_rt,
        correctness_atol=correctness_atol,
        correctness_rtol=correctness_rtol,
    )
    lowerer = Lowerer.create(lower_setting=lower_setting)
    return lowerer(module, input)
@dc.dataclass
class LowerTrtInterpreter:
    """Builds a TensorRT engine for one (split) submodule via `TRTInterpreter`.

    Attributes:
        lower_setting: settings used for the build (input specs, precision,
            workspace size, algorithm selector, tactic sources, ...).
        timing_cache_manager: loads and updates timing cache data per split name.
    """

    lower_setting: LowerSetting
    timing_cache_manager: TimingCacheManager

    @classmethod
    def create(cls, lower_setting):
        """Create an interpreter whose timing-cache manager is configured from `lower_setting`."""
        timing_cache_manager = TimingCacheManager(
            lower_setting.timing_cache_prefix, lower_setting.save_timing_cache
        )
        # Use `cls` (not the hard-coded class name) so subclasses calling this
        # factory get instances of the subclass.
        return cls(lower_setting, timing_cache_manager)

    def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
        """Convert `mod` into a TensorRT engine.

        Args:
            mod: the submodule to convert.
            input: accepted for interface compatibility but not read here; the
                engine inputs come from `lower_setting.input_specs`.
            split_name: name of the split; used as the timing-cache key and as
                the stem of the algorithm-selector file (`<split_name>.json`).

        Returns:
            The `TRTInterpreterResult` produced by `TRTInterpreter.run`.

        Raises:
            AssertionError: if `lower_setting.input_specs` is empty.
        """
        assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
        logger.info(
            "split_name=%s, input_specs=%s", split_name, self.lower_setting.input_specs
        )

        # Prepare algorithm selector and timing_cache for TRTInterpreter
        algo_selector = None
        if self.lower_setting.algo_selector:
            algo_selector = self.lower_setting.algo_selector(f"{split_name}.json")

        cache_data = None
        if self.timing_cache_manager:
            try:
                cache_data = self.timing_cache_manager.get_timing_cache_trt(split_name)
            except Exception as e:
                # Best effort: on failure the engine is built without a timing cache.
                logger.warning("Cannot load timing cache for %s: %s", split_name, e)
                cache_data = None
            else:
                # Only report cache usage when there is actually data to pass on.
                if cache_data is not None:
                    logger.info("Timing cache is used!")

        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
            logger_level=(
                trt.Logger.VERBOSE if self.lower_setting.verbose_log else trt.Logger.WARNING
            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
            profiling_verbosity=(
                trt.ProfilingVerbosity.DETAILED
                if self.lower_setting.verbose_profile
                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
        if timing_cache and self.timing_cache_manager:
            self.timing_cache_manager.update_timing_cache(split_name, timing_cache)

        return interp_result


def default_split_function(
    model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
) -> SplitResult:
    """Split `model` with a `TRTSplitter` configured from `lower_setting`.

    The splitter honours the batch-dimension mode, minimum accelerated-module
    size and runtime choice from `lower_setting`. `node_support_preview` is
    called before splitting (presumably to report per-node TRT support).
    """
    splitter_setting = TRTSplitterSetting()
    splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
    splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
    splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
    splitter.node_support_preview()
    return splitter.generate_split_results()


def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter:
    """Default interpreter builder: `LowerTrtInterpreter.create(lower_setting)`."""
    return LowerTrtInterpreter.create(lower_setting)


def default_lower_pass(
    create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
    """Return a pass that converts one split submodule into a TensorRT-backed module.

    Args:
        create_trt_interpreter: factory producing the interpreter used to build
            the engine for each submodule.
    """

    def lower_pass(
        mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str
    ) -> nn.Module:
        """
        Create a module transformation pass which lowers an `fx.GraphModule` into a
        `TRTModule` (or a `TorchTensorRTModule` when `use_experimental_rt` is set).
        """
        interpreter = create_trt_interpreter(lower_setting)
        interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)

        if lower_setting.use_experimental_rt:
            # Imported lazily so the experimental runtime is only required when used.
            from torch_tensorrt._Device import Device
            from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

            # TorchTensorRTModule is not given a CUDA-graph batch size, so make
            # a non-default setting visible instead of silently dropping it.
            if lower_setting.cuda_graph_batch_size != -1:
                logger.warning(
                    "cuda_graph_batch_size=%s is ignored by the experimental runtime",
                    lower_setting.cuda_graph_batch_size,
                )

            # The serialized engine is a bytes-like buffer; copy it into `bytes`.
            engine_str = bytes(interp_res.engine.serialize())
            return TorchTensorRTModule(
                engine_str,
                name=module_name,
                input_binding_names=interp_res.input_names,
                output_binding_names=interp_res.output_names,
                target_device=Device(f"cuda:{torch.cuda.current_device()}"),
            )

        return TRTModule(
            engine=interp_res.engine,
            input_names=interp_res.input_names,
            output_names=interp_res.output_names,
            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
        )

    return lower_pass


@dc.dataclass(frozen=True)
class Lowerer:
    """Lowers a module using fx2trt.

    This is a composable class to facilitate fx2trt. A normal fx2trt process
    composes of the following passes to transform an `fx.GraphModule`:

        1. trace - use torch.fx to trace the module so we can get the graph
           representation of the model.
        2. split - the graph module is split into several submodules,
           running either via TensorRT, or via regular CUDA.

    For each split that need to run via TRT, the following passes are
    invoked:

        3. `TRTInterpreter` - build the TRT engine for the submodule that
           can be supported through `TRTInterpreter`.
        4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`.
        5. The converted submodule is then set back onto the top-level module
    """

    lower_pass_manager_builder: LowerPassManagerBuilder

    @classmethod
    def create(
        cls,
        lower_setting: LowerSetting,
        interpreter_builder: Callable = create_lower_trt_interpreter,
        split_func: Callable = default_split_function,
    ) -> "Lowerer":
        """Instantiate a `Lowerer` instance.

        The tracer is chosen from `lower_setting.is_aten`: the aten tracer
        (`aten_tracer.opt_trace`) when set, otherwise `acc_tracer.trace`.
        Everything else in the pass-manager builder is shared.
        """
        if lower_setting.is_aten:
            # proxytensor_trace
            def trace_func(module, inputs):
                return aten_tracer.opt_trace(module, inputs)

        else:

            def trace_func(module, inputs):
                return acc_tracer.trace(
                    module,
                    inputs,  # type: ignore[arg-type]
                    ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
                    leaf_module_list=lower_setting.leaf_module_list,
                )

        return cls(
            lower_pass_manager_builder=LowerPassManagerBuilder(
                lower_setting=lower_setting,
                trace_func=trace_func,
                split_func=split_func,
                lower_func=default_lower_pass(interpreter_builder),
            )
        )

    def __call__(
        self,
        module: nn.Module,
        inputs: Input,
        additional_inputs: Optional[Input] = None,
        fp16_conversion_fn: Optional[Callable[[Input], Input]] = None,
    ) -> nn.Module:
        """Run the lowering pipeline on `module` and return the lowered module.

        The module is put in eval mode. With FP16 lowering precision, the module
        is converted with `.half()` and each input is converted by
        `fp16_conversion_fn`. If that is not given, float32 tensors become half
        and other values (including None) pass through unchanged.
        `additional_inputs` is forwarded as-is to the pipeline builder. The
        lowering is wrapped in `validate_inference` with the configured
        `correctness_atol` / `correctness_rtol` (presumably comparing lowered
        and original outputs).
        """
        lower_setting = self.lower_pass_manager_builder.lower_setting

        @validate_inference(
            atol=lower_setting.correctness_atol,
            rtol=lower_setting.correctness_rtol,
        )
        def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
            module.eval()
            if lower_setting.lower_precision == LowerPrecision.FP16:
                module.half()

                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
                    conversion_fn = lambda x: (
                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)

            if lower_setting.is_aten:
                pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
                    inputs, additional_inputs
                )
            else:
                pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
                    inputs, additional_inputs
                )
            return pm(module)

        return do_lower(module, inputs)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources