Source code for torch_tensorrt.dynamo._compiler
from __future__ import annotations
import collections.abc
import logging
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
import torch
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
EngineCapability,
)
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import partitioning
from torch_tensorrt.dynamo._defaults import (
DEBUG,
DEVICE,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
)
from torch_tensorrt.dynamo._DryRunTracker import (
DryRunTracker,
PerSubgraphData,
dryrun_stats_display,
parse_non_trt_nodes,
)
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
convert_module,
repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.utils import (
get_torch_inputs,
parse_complex_tensor_structs,
prepare_inputs,
set_log_level,
to_torch_device,
to_torch_tensorrt_device,
)
import torch_tensorrt
logger = logging.getLogger(__name__)
def compile(
exported_program: ExportedProgram,
inputs: Tuple[Any, ...],
*,
device: Optional[Union[Device, torch.device, str]] = DEVICE,
disable_tf32: bool = DISABLE_TF32,
sparse_weights: bool = SPARSE_WEIGHTS,
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
engine_capability: EngineCapability = ENGINE_CAPABILITY,
refit: bool = REFIT,
debug: bool = DEBUG,
capability: EngineCapability = EngineCapability.default,
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
workspace_size: int = WORKSPACE_SIZE,
dla_sram_size: int = DLA_SRAM_SIZE,
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
calibrator: object = None,
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Collection[Target]] = None,
torch_executed_modules: Optional[List[str]] = None,
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
version_compatible: bool = VERSION_COMPATIBLE,
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = DRYRUN,
hardware_compatible: bool = HARDWARE_COMPATIBLE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
Takes a existing TorchScript module and a set of settings to configure the compiler
and will convert methods to JIT Graphs which call equivalent TensorRT engines
Converts specifically the forward method of a TorchScript Module
Arguments:
        exported_program (torch.export.ExportedProgram): Source module, produced by running torch.export on a ``torch.nn.Module``
inputs (Tuple[Any, ...]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
to select device type. ::
input=[
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
torch_tensorrt.Input(
min_shape=(1, 224, 224, 3),
opt_shape=(1, 512, 512, 3),
max_shape=(1, 1024, 1024, 3),
                        dtype=torch.int32,
                        format=torch.channels_last
), # Dynamic input shape for input #2
                    torch.randn((1, 3, 224, 224)) # Use an example tensor and let torch_tensorrt infer settings
]
Keyword Arguments:
        device (Union(torch_tensorrt.Device, torch.device, str)): Target device for TensorRT engines to run on ::
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
        disable_tf32 (bool): Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
        enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
debug (bool): Enable debuggable engine
        capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe GPU kernels or safe DLA kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
        min_block_size (int): The minimum number of contiguous TensorRT convertible operations in order to run a set of operations in TensorRT
torch_executed_ops (Collection[Target]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True
torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows)
        max_aux_streams (Optional[int]): Maximum streams in the engine
version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)
optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
        use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
        enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of the graph that runs in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
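    Example:
        A minimal usage sketch; ``MyModule`` is a placeholder for any ``torch.nn.Module``, and the input shape and precision below are illustrative rather than required values::
            import torch
            import torch_tensorrt
            model = MyModule().eval().cuda()  # placeholder: any torch.nn.Module
            example_inputs = (torch.randn(1, 3, 224, 224).cuda(),)
            # Export the module, then hand the resulting ExportedProgram to the Dynamo compiler
            exp_program = torch.export.export(model, example_inputs)
            trt_gm = torch_tensorrt.dynamo.compile(
                exp_program,
                inputs=example_inputs,
                enabled_precisions={torch.float32},
            )
            out = trt_gm(*example_inputs)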
"""
if debug:
set_log_level(logger.parent, logging.DEBUG)
if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
# Prepare torch_trt inputs
inputs = prepare_inputs(inputs)
device = to_torch_tensorrt_device(device)
if not isinstance(exported_program, ExportedProgram):
raise AssertionError(
f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
)
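    # Apply ATen decompositions to the exported program so the graph handed to the
    # lowering passes and converters only contains operators they are expected to handle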
exported_program = exported_program.run_decompositions(
get_decompositions(enable_experimental_decompositions)
)
gm = exported_program.module()
logger.debug("Input graph: " + str(gm.graph))
# Apply lowering on the graph module
torch_inputs = get_torch_inputs(inputs, device)
gm = apply_lowering_passes(gm, torch_inputs)
logger.debug("Lowered Input graph: " + str(gm.graph))
enabled_precisions = set(enabled_precisions)
if (
torch.float16 in enabled_precisions
or torch_tensorrt.dtype.half in enabled_precisions
):
precision = torch.float16
elif (
torch.float32 in enabled_precisions
or torch_tensorrt.dtype.float in enabled_precisions
):
precision = torch.float32
elif len(enabled_precisions) == 0:
logger.info(f"No precision specified, defaulting to {PRECISION}")
precision = PRECISION
else:
raise ValueError(
f"Precision {enabled_precisions} not supported in the Dynamo Path"
)
compilation_options = {
"precision": precision,
"debug": debug,
"device": device,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops
if torch_executed_ops is not None
else set(),
"pass_through_build_failures": pass_through_build_failures,
"max_aux_streams": max_aux_streams,
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"refit": refit,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
}
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
return compile_module(gm, inputs, settings)
def compile_module(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[Input],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
"""Compile a traced FX module
Includes: Partitioning + Conversion Phases
Args:
        gm: FX GraphModule to convert
        sample_inputs: Inputs to the module
settings: Compilation settings
Returns:
Compiled FX GraphModule
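    Example:
        Illustrative sketch only; in normal use this function is invoked internally by ``compile`` after lowering, and ``gm``, the ``Input`` spec and the settings below are placeholder values::
            import torch_tensorrt
            from torch_tensorrt.dynamo.conversion import CompilationSettings
            # gm: an already-lowered torch.fx.GraphModule (see ``compile`` above)
            settings = CompilationSettings(min_block_size=1, debug=True)
            trt_gm = compile_module(gm, [torch_tensorrt.Input((1, 3, 224, 224))], settings)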
"""
dryrun_tracker = DryRunTracker()
# Set torch-executed ops
CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
)
dryrun_tracker.total_ops_in_graph = total_ops
dryrun_tracker.supported_ops_in_graph = num_supported_ops
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x)
)
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
sample_inputs, "torch_dtype"
)
dryrun_tracker.compilation_settings = settings
if settings.dryrun and settings.min_block_size > 1:
logger.info(
"It is recommended to run `dryrun` mode with `min_block_size=1`, "
"for the most thorough analysis"
)
# If the number of supported operations is 0 or less than the block size, skip the subgraph
# TODO: Add condition to second expression below when require_full_compilation is added
if num_supported_ops == 0 or (
num_supported_ops < settings.min_block_size and not settings.dryrun
):
logger.warning(
f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
)
return gm
else:
logger.debug(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)
# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
partitioned_module, supported_ops = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
+ "Retrying with global partition.",
exc_info=True,
)
fast_partitioner_failed = True
settings.use_fast_partitioner = False
if not settings.use_fast_partitioner:
partitioned_module, supported_ops = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
# The global partitioner leaves non-TRT nodes as-is
if not settings.use_fast_partitioner:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))
# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
continue
subgraph_data = PerSubgraphData()
subgraph_data.subgraph_name = name
subgraph_data.subgraph_op_count = len(
[
node
for node in submodule.graph.nodes
if node.op in ("call_function", "call_method", "call_module")
]
)
# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module,
submodule,
sample_inputs,
to_torch_device(settings.device),
)
logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
str(name),
[input.shape for input in submodule_inputs],
str(submodule.graph),
)
assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module,
submodule,
submodule_inputs,
to_torch_device(settings.device),
name,
)
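        # Record per-subgraph input shapes and dtypes for the dryrun report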
subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs(
submodule_inputs,
"shape",
lambda x: dict(x) if isinstance(x, dict) else tuple(x),
)
subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs(
submodule_inputs, "torch_dtype"
)
submodule_outputs = submodule(
*get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
)
subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs(
submodule_outputs,
"shape",
lambda x: dict(x) if isinstance(x, dict) else tuple(x),
)
subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs(
submodule_outputs, "dtype"
)
dryrun_tracker.tensorrt_graph_count += 1
dryrun_tracker.per_subgraph_data.append(subgraph_data)
# Create TRT engines from submodule
if not settings.dryrun:
trt_module = convert_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)
trt_modules[name] = trt_module
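    # Run the original module on the sample inputs to record graph-level output shapes and dtypes for the dryrun report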
sample_outputs = gm(
*get_torch_inputs(sample_inputs, to_torch_device(settings.device))
)
if not isinstance(sample_outputs, (list, tuple)):
sample_outputs = [sample_outputs]
dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs(
sample_outputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x)
)
dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs(
sample_outputs, "dtype"
)
# Replace all FX Modules with TRT Modules
for name, trt_module in trt_modules.items():
setattr(partitioned_module, name, trt_module)
# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
settings.use_fast_partitioner = True
dryrun_stats_display(dryrun_tracker, settings.dryrun)
return partitioned_module