Source code for torch_tensorrt.fx.lower
import dataclasses as dc
import logging
from typing import Any, Callable, Optional, Sequence
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx as fx
import torch.nn as nn
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.passes.splitter_base import SplitResult
from .fx2trt import TRTInterpreter, TRTInterpreterResult
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import PassFunc, validate_inference
from .tools.timing_cache_utils import TimingCacheManager
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting
from .tracer.acc_tracer import acc_tracer
from .trt_module import TRTModule
from .utils import LowerPrecision
logger = logging.getLogger(__name__)
Input = Sequence[Any]
def compile(
module: nn.Module,
    input: Input,
    min_acc_module_size: int = 10,
    max_batch_size: int = 2048,
    max_workspace_size: int = 1 << 25,
    explicit_batch_dimension: bool = False,
    lower_precision: LowerPrecision = LowerPrecision.FP16,
    verbose_log: bool = False,
    timing_cache_prefix: str = "",
    save_timing_cache: bool = False,
    cuda_graph_batch_size: int = -1,
    dynamic_batch: bool = True,
    is_aten: bool = False,
    use_experimental_fx_rt: bool = False,
    correctness_atol: float = 1e-1,
    correctness_rtol: float = 1e-1,
) -> nn.Module:
"""
    Takes in an original module, inputs, and lowering settings; runs the
    lowering workflow to turn the module into a lowered module, also called a
    TRTModule.

    Args:
        module: Original module for lowering.
        input: Input for the module.
        min_acc_module_size: Minimal number of nodes for an accelerated submodule.
        max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set).
        max_workspace_size: Maximum size of workspace given to TensorRT.
        explicit_batch_dimension: Use explicit batch dimension in TensorRT if set to True, otherwise use implicit batch dimension.
        lower_precision: lower_precision config given to TRTModule.
        verbose_log: Enable verbose log for TensorRT if set to True.
        timing_cache_prefix: Timing cache file name for the timing cache used by fx2trt.
        save_timing_cache: Update the timing cache with current timing cache data if set to True.
        cuda_graph_batch_size: CUDA graph batch size, defaults to -1.
        dynamic_batch: Whether the batch dimension (dim=0) is dynamic.
        is_aten: Trace the module with the aten tracer instead of the acc tracer if set to True.
        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
        correctness_atol: Absolute tolerance used when validating the lowered module against the original.
        correctness_rtol: Relative tolerance used when validating the lowered module against the original.
Returns:
A torch.nn.Module lowered by TensorRT.
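
    Example (a minimal sketch; ``MyModel`` and the input shape are illustrative)::

        import torch
        import torch.nn as nn

        class MyModel(nn.Module):
            def forward(self, x):
                return torch.relu(x)

        mod = MyModel().cuda().eval()
        inputs = [torch.randn(1, 3, 224, 224, device="cuda")]
        trt_mod = compile(mod, inputs, lower_precision=LowerPrecision.FP16)
        out = trt_mod(*inputs)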
"""
if use_experimental_fx_rt and not explicit_batch_dimension:
raise ValueError(
"The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
)
lower_setting = LowerSetting(
max_batch_size=max_batch_size,
min_acc_module_size=min_acc_module_size,
max_workspace_size=max_workspace_size,
explicit_batch_dimension=explicit_batch_dimension,
lower_precision=lower_precision,
verbose_log=verbose_log,
timing_cache_prefix=timing_cache_prefix,
save_timing_cache=save_timing_cache,
cuda_graph_batch_size=cuda_graph_batch_size,
dynamic_batch=dynamic_batch,
is_aten=is_aten,
use_experimental_rt=use_experimental_fx_rt,
correctness_atol=correctness_atol,
correctness_rtol=correctness_rtol,
)
lowerer = Lowerer.create(lower_setting=lower_setting)
return lowerer(module, input)
@dc.dataclass
class LowerTrtInterpreter:
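    """Builds a TRT engine from a split submodule via `TRTInterpreter`.

    Wraps `TRTInterpreter` with the options carried by `lower_setting` and
    manages the per-split timing cache through `timing_cache_manager`.
    """
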
lower_setting: LowerSetting
timing_cache_manager: TimingCacheManager
@classmethod
def create(cls, lower_setting):
timing_cache_manager = TimingCacheManager(
lower_setting.timing_cache_prefix, lower_setting.save_timing_cache
)
return LowerTrtInterpreter(lower_setting, timing_cache_manager)
    def __call__(
        self, mod: fx.GraphModule, input: Input, split_name: str
    ) -> TRTInterpreterResult:
assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
logger.info(
f"split_name={split_name}, input_specs={self.lower_setting.input_specs}"
)
# Prepare algorithm selector and timing_cache for TRTInterpreter
algo_selector = None
if self.lower_setting.algo_selector:
algo_selector = self.lower_setting.algo_selector(f"{split_name}.json")
cache_data = None
if self.timing_cache_manager:
try:
cache_data = self.timing_cache_manager.get_timing_cache_trt(split_name)
logger.info("Timing cache is used!")
except Exception as e:
logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}")
cache_data = None
interpreter = TRTInterpreter(
mod,
input_specs=self.lower_setting.input_specs,
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
explicit_precision=self.lower_setting.explicit_precision,
logger_level=(
trt.Logger.VERBOSE
if self.lower_setting.verbose_log
else trt.Logger.WARNING
),
)
interp_result: TRTInterpreterResult = interpreter.run(
max_batch_size=self.lower_setting.max_batch_size,
max_workspace_size=self.lower_setting.max_workspace_size,
lower_precision=self.lower_setting.lower_precision,
strict_type_constraints=self.lower_setting.strict_type_constraints,
algorithm_selector=algo_selector,
timing_cache=cache_data,
profiling_verbosity=(
trt.ProfilingVerbosity.DETAILED
if self.lower_setting.verbose_profile
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
),
tactic_sources=self.lower_setting.tactic_sources,
)
# Update timing cache file if needed
timing_cache = interp_result.serialized_cache
if timing_cache and self.timing_cache_manager:
self.timing_cache_manager.update_timing_cache(split_name, timing_cache)
return interp_result
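
# A minimal usage sketch (illustrative, not part of the lowering flow): it
# assumes `lower_setting.input_specs` has already been populated and that `gm`
# is a traced `fx.GraphModule` submodule produced by the splitter.
#
#     interp = LowerTrtInterpreter.create(lower_setting)
#     interp_result = interp(gm, inputs, "_run_on_acc_0")
#     serialized_engine = interp_result.engine.serialize()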
def default_split_function(
model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
) -> SplitResult:
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
splitter.node_support_preview()
return splitter.generate_split_results()
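
# A minimal usage sketch (names illustrative): the returned `SplitResult`
# carries the split graph; submodules named like `_run_on_acc_*` are
# TRT-eligible, while the rest fall back to regular CUDA execution.
#
#     split_result = default_split_function(traced_gm, inputs, lower_setting)
#     split_mod = split_result.split_module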
def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter:
return LowerTrtInterpreter.create(lower_setting)
def default_lower_pass(
create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
def lower_pass(
mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str
) -> nn.Module:
"""
Create a module transformation pass which lowers an `fx.GraphModule` into a
`TRTModule`
"""
interpreter = create_trt_interpreter(lower_setting)
interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
if lower_setting.use_experimental_rt:
import io
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
with io.BytesIO() as engine_bytes:
engine_bytes.write(interp_res.engine.serialize())
engine_str = engine_bytes.getvalue()
trt_module = TorchTensorRTModule(
engine_str,
name=module_name,
input_binding_names=interp_res.input_names,
output_binding_names=interp_res.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
# cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do
)
return trt_module
else:
trt_module = TRTModule(
engine=interp_res.engine,
input_names=interp_res.input_names,
output_names=interp_res.output_names,
cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
)
return trt_module
return lower_pass
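
# A minimal wiring sketch (illustrative): build the default lower pass and
# apply it to a single split submodule by name.
#
#     lower_fn = default_lower_pass(create_lower_trt_interpreter)
#     trt_submod = lower_fn(gm, inputs, lower_setting, "_run_on_acc_0")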
@dc.dataclass(frozen=True)
class Lowerer:
"""Lowers a module using fx2trt.
This is a composable class to facilitate fx2trt. A normal fx2trt process
composes of the following passes to transform an `fx.GraphModule`:
1. trace - use torch.fx to trace the module so we can get the graph
representation of the model.
2. split - the graph module is split into several submodules,
running either via TensorRT, or via regular CUDA.
For each split that need to run via TRT, the following passes are
invoked:
3. `TRTInterpreter` - build the TRT engine for the submodule that
can be supported through `TRTInterpreter`.
4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`.
5. The converted submodule is then set back onto the top-level module
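
    Example (a minimal sketch; `mod` and `inputs` are illustrative)::

        lowerer = Lowerer.create(lower_setting=LowerSetting())
        lowered_mod = lowerer(mod, inputs)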
"""
lower_pass_manager_builder: LowerPassManagerBuilder
@classmethod
def create(
cls,
lower_setting: LowerSetting,
interpreter_builder: Callable = create_lower_trt_interpreter,
split_func: Callable = default_split_function,
) -> "Lowerer":
"""Instantiate a `Lowerer` instance."""
if not lower_setting.is_aten:
return cls(
lower_pass_manager_builder=LowerPassManagerBuilder(
lower_setting=lower_setting,
trace_func=lambda module, inputs: acc_tracer.trace(
module,
inputs, # type: ignore[arg-type]
ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
leaf_module_list=lower_setting.leaf_module_list,
),
split_func=split_func,
lower_func=default_lower_pass(interpreter_builder),
)
)
        # aten tracer path
else:
return cls(
lower_pass_manager_builder=LowerPassManagerBuilder(
lower_setting=lower_setting,
trace_func=lambda module, inputs: aten_tracer.opt_trace(
module, inputs
),
split_func=split_func,
lower_func=default_lower_pass(interpreter_builder),
)
)
def __call__(
self,
module: nn.Module,
inputs: Input,
additional_inputs: Optional[Input] = None,
fp16_conversion_fn: Optional[Callable[[Input], Input]] = None,
) -> nn.Module:
lower_setting = self.lower_pass_manager_builder.lower_setting
atol = lower_setting.correctness_atol
rtol = lower_setting.correctness_rtol
@validate_inference(
atol=atol,
rtol=rtol,
)
def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
module.eval()
if (
self.lower_pass_manager_builder.lower_setting.lower_precision
== LowerPrecision.FP16
):
module.half()
# A custom conversion function can be passed to the lowerer to
# handle inputs with custom types. By default, just handle
# tensors and NoneType.
if fp16_conversion_fn is None:
conversion_fn = lambda x: (
x.half() if x is not None and x.dtype == torch.float32 else x
)
else:
conversion_fn = fp16_conversion_fn
inputs = tuple(conversion_fn(x) for x in inputs)
if lower_setting.is_aten:
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
inputs, additional_inputs
)
else:
pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
inputs, additional_inputs
)
lower_result = pm(module)
return lower_result
return do_lower(module, inputs)
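
# A hedged sketch of supplying a custom conversion function when FP16 inputs
# mix tensors with other types (the lambda below is illustrative):
#
#     conv = lambda x: (
#         x.half() if isinstance(x, torch.Tensor) and x.dtype == torch.float32 else x
#     )
#     lowered = lowerer(mod, inputs, fp16_conversion_fn=conv)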