# Source code for torch_tensorrt.dynamo._exporter
import copy
import operator
from typing import Any, Dict, Sequence, Tuple, cast
import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram, ExportGraphSignature
from torch.export.exported_program import (
InputKind,
InputSpec,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch_tensorrt.dynamo import partitioning
[docs]def export(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
*,
ir: str = "torchscript",
) -> ExportedProgram:
"""Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded.
> Note: When ExportedProgram becomes stable, this function will get merged into ``torch_tensorrt.dynamo.compile``
Arguments:
src_gm (torch.fx.GraphModule): Source module, generated by torch.export (The module provided to ``torch_tensorrt.dynamo.compile``)
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
Keyword Arguments:
inputs (Any): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
to select device type. ::
input=[
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
torch_tensorrt.Input(
min_shape=(1, 224, 224, 3),
opt_shape=(1, 512, 512, 3),
max_shape=(1, 1024, 1024, 3),
dtype=torch.int32
format=torch.channel_last
), # Dynamic input shape for input #2
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
ir (str): torchscript | exported_program. Based on the provided ir, the output type would be a torchscript or exported program.
"""
if ir == "torchscript":
return torch.jit.trace(gm, inputs)
elif ir == "exported_program":
patched_module = transform(gm, inputs)
exp_program = create_trt_exp_program(patched_module)
return exp_program
else:
raise ValueError(
f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
)
def transform(
    gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Flatten a compiled Torch-TensorRT module into a graph ready for ``torch.export``.

    Steps, all applied in place to ``gm``:

    1. run the partitioner's shape analysis on ``inputs``; its second result is
       the output-shape map consumed by ``inline_trt_modules``;
    2. replace the TensorRT submodules with engine-execution call nodes;
    3. inline the PyTorch submodules;
    4. lift ``get_attr`` constants (parameters/buffers) to placeholders;
    5. delete unused submodules, eliminate dead code and lint the graph.

    Returns:
        The same ``gm`` object after rewriting.
    """
    _, shape_map = partitioning.run_shape_analysis(gm, inputs)

    inline_trt_modules(gm, shape_map)
    inline_torch_modules(gm)
    # torch.export serialization expects constants to be lifted to graph inputs.
    lift_constant_pass(gm)

    gm.delete_all_unused_submodules()
    graph = gm.graph
    graph.eliminate_dead_code()
    graph.lint()
    return gm
def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
fake_mode = detect_fake_mode(
tuple(
node.meta["val"] for node in trt_gm.graph.nodes if node.op == "placeholder"
)
)
first_user_input = None
for node in trt_gm.graph.nodes:
if node.op == "placeholder":
first_user_input = node
break
for node in trt_gm.graph.nodes:
if node.op == "get_attr":
constant_tensor = getattr(trt_gm, node.target)
with trt_gm.graph.inserting_before(first_user_input):
const_placeholder_node = trt_gm.graph.placeholder(node.target)
const_placeholder_node.meta = copy.deepcopy(node.meta)
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
constant_tensor
)
node.replace_all_uses_with(const_placeholder_node)
trt_gm.graph.erase_node(node)
trt_gm.graph.eliminate_dead_code()
trt_gm.graph.lint()
return trt_gm
def get_duplicate_nodes(
    gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule
) -> Tuple[Sequence[Any], Sequence[Any]]:
    """
    Find submodule placeholders whose names already exist as nodes in ``gm``.

    This handles the case where the subgraph input placeholders are the same as
    ``gm`` placeholders, which happens when the first submodule in the graph is a
    PyTorch submodule.

    Returns:
        ``(submodule_duplicate_inputs, gm_duplicate_inputs)``: two equally long
        lists that are aligned index by index. Element ``i`` of the second list is
        the ``gm`` node with the same name as element ``i`` of the first. Order
        follows the submodule's placeholder order.
    """
    # FX node names are unique within a graph, so a name -> node map is exact.
    gm_nodes_by_name = {node.name: node for node in gm.graph.nodes}
    submodule_duplicate_inputs = []
    gm_duplicate_inputs = []
    for node in submodule.graph.nodes:
        if node.op != "placeholder":
            continue
        gm_node = gm_nodes_by_name.get(node.name)
        if gm_node is not None:
            # Append in pairs so callers can zip the two lists safely.
            submodule_duplicate_inputs.append(node)
            gm_duplicate_inputs.append(gm_node)
    return submodule_duplicate_inputs, gm_duplicate_inputs
def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Inline PyTorch submodules within the parent graph (gm).

    Every ``call_module`` node whose name contains ``"_run_on_gpu"`` is replaced
    by a copy of its submodule's graph. The submodule's placeholders are bound
    positionally to the ``call_module`` node's ``args``. The copied nodes
    therefore read directly from the values that fed the submodule call
    (outputs of previous TRT nodes or placeholders of the main graph), and no
    extra placeholders are introduced into ``gm``. The submodule's parameters and
    buffers are then copied onto ``gm`` via ``copy_submodule_attributes``.

    Args:
        gm: Parent graph module; modified in place.

    Returns:
        The same ``gm``.

    Raises:
        ValueError: If a submodule's placeholder count differs from the number
            of positional arguments passed to it.
    """
    # Clean the graph
    gm.graph.eliminate_dead_code()
    gm.graph.lint()

    # Iterate over a snapshot because nodes are inserted and erased below.
    for gm_node in list(gm.graph.nodes):
        if gm_node.op != "call_module" or "_run_on_gpu" not in gm_node.name:
            continue
        # call_module targets are qualified module paths; get_submodule resolves them.
        submodule = gm.get_submodule(gm_node.target)
        submodule_placeholders = [
            node for node in submodule.graph.nodes if node.op == "placeholder"
        ]
        # NOTE(review): assumes the submodule's placeholders appear in the same order
        # as the call_module node's positional args, and that kwargs are unused
        # (the previous name-matching code made the same positional assumption).
        if len(submodule_placeholders) != len(gm_node.args):
            raise ValueError(
                f"Submodule '{gm_node.target}' has {len(submodule_placeholders)} "
                f"placeholders but is called with {len(gm_node.args)} arguments"
            )
        # Pre-seeding val_map makes graph_copy skip the submodule placeholders and
        # wire their users straight to the parent-graph arguments.
        val_map = dict(zip(submodule_placeholders, gm_node.args))
        with gm.graph.inserting_before(gm_node):
            # Copy all remaining submodule nodes into gm; the return value is the
            # submodule's output expressed in terms of gm nodes.
            submodule_output = gm.graph.graph_copy(submodule.graph, val_map)
        # Replace the pytorch submodule node (call_module) with the inlined subgraph output
        gm_node.replace_all_uses_with(submodule_output)
        # graph_copy copies get_attr nodes but not the tensors they reference.
        copy_submodule_attributes(gm, gm_node.target)
        # Erase the pytorch submodule (call_module) node
        gm.graph.erase_node(gm_node)
    return gm
def copy_submodule_attributes(gm: torch.fx.GraphModule, submod_name: str) -> None:
    """
    Copy the parameters and buffers of submodule ``submod_name`` onto ``gm``.

    ``graph_copy`` copies a submodule's ``get_attr`` nodes into the parent graph
    but not the attributes they refer to. Each ``"<submod_name>.<attr>"`` entry
    found on ``gm`` is therefore re-registered on ``gm`` itself as ``"<attr>"``,
    sharing the same tensor object.

    Args:
        gm: Parent module; modified in place.
        submod_name: Qualified name of the submodule within ``gm``.
    """
    prefix = submod_name + "."
    # Snapshot both listings before registering anything on gm.
    params = [(n, p) for n, p in gm.named_parameters() if n.startswith(prefix)]
    buffers = [(n, b) for n, b in gm.named_buffers() if n.startswith(prefix)]
    # NOTE(review): attributes of nested submodules would keep a dotted name here,
    # which register_parameter/register_buffer reject — confirm submodules are flat.
    for qualified_name, param in params:
        # Strip only the leading prefix (str.replace would drop every occurrence).
        gm.register_parameter(qualified_name[len(prefix):], param)
    for qualified_name, buffer in buffers:
        gm.register_buffer(qualified_name[len(prefix):], buffer)
def create_trt_exp_program(
    gm: torch.fx.GraphModule,
) -> ExportedProgram:
    """Create a new ExportedProgram from a graph module containing TRT engines.

    The graph signature gets one ``USER_INPUT`` spec per placeholder and one
    ``USER_OUTPUT`` spec per graph output, named after the corresponding nodes.
    ``gm.state_dict()`` is used as the program's state dict.

    Args:
        gm: Graph module whose graph becomes the program's graph.

    Returns:
        An ``ExportedProgram`` wrapping ``gm``.
    """
    input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
    assert output_nodes
    graph_outputs = output_nodes[0].args[0]
    # A graph that returns a single value stores a bare Node rather than a
    # tuple/list, which would not be iterable below.
    if isinstance(graph_outputs, torch.fx.Node):
        graph_outputs = (graph_outputs,)

    # NOTE(review): every placeholder, including constants lifted to placeholders,
    # is declared USER_INPUT — confirm this is what consumers of the program expect.
    input_specs = [
        InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target)
        for node in input_nodes
    ]
    output_specs = [
        OutputSpec(OutputKind.USER_OUTPUT, TensorArgument(name=node.name), node.target)
        for node in graph_outputs
    ]
    trt_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    # NOTE(review): the trailing positional arguments ({}, [], [], []) follow one
    # specific torch version's ExportedProgram constructor — TODO confirm when
    # upgrading torch.
    return ExportedProgram(
        gm, gm.graph, trt_graph_signature, gm.state_dict(), {}, [], [], []
    )
def inline_trt_modules(
    gm: torch.fx.GraphModule, outputs_map: Dict[Any, Sequence[Any]]
) -> torch.fx.GraphModule:
    """
    Replace TRT submodules with TRT engine nodes.

    For every child module whose name contains ``"_run_on_acc"``, the graph
    node with the same name is replaced by a call to
    ``torch.ops.tensorrt.execute_engine.default``. The call receives the node's
    original args and the submodule's ``engine``. Output metadata
    (``meta["val"]``) is synthesized from the shapes in ``outputs_map``.

    Args:
        gm: Parent graph module; modified in place.
        outputs_map: Maps a TRT submodule node name to the sequence of its output
            shapes.

    Returns:
        The same ``gm``.
    """
    for name, _ in gm.named_children():
        if "_run_on_acc" not in name:
            continue
        # Get the TRT submodule
        trt_module = getattr(gm, name)
        matching_nodes = [node for node in gm.graph.nodes if node.name == name]
        assert matching_nodes
        trt_module_node = matching_nodes[0]
        # Ensure the trt module node in the main graph (gm) has inputs
        assert trt_module_node.args
        output_shapes = outputs_map[trt_module_node.name]
        num_outputs = len(output_shapes)
        # Validate before the graph is mutated.
        assert num_outputs > 0

        # Insert a call_function node to perform inference on the TRT engine
        with gm.graph.inserting_before(trt_module_node):
            trt_node = gm.graph.call_function(
                torch.ops.tensorrt.execute_engine.default,
                (trt_module_node.args, trt_module.engine),
            )
        # NOTE(review): cast() only informs the type checker. These are real tensors
        # of the default dtype/device with all-ones strides, not FakeTensors — confirm
        # whether true FakeTensors with the engine's output dtypes are required.
        trt_node.meta["val"] = [
            cast(
                FakeTensor,
                torch.empty_strided(tuple(shape), tuple([1] * len(shape))),
            )
            for shape in output_shapes
        ]

        if num_outputs == 1:
            # Insert a getitem node as the output (for export serialization to work)
            with gm.graph.inserting_after(trt_node):
                getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
            # getitem(trt_node, 0) yields a single tensor, so its metadata is the
            # first element, not the whole list.
            getitem_output.meta["val"] = trt_node.meta["val"][0]
            trt_module_node.replace_all_uses_with(getitem_output)
        else:
            # Multiple outputs: the partitioner already added getitem users, so just
            # redirect them to the engine node.
            trt_module_node.replace_all_uses_with(trt_node)
            for user in trt_node.users:
                if user.target is not operator.getitem:
                    continue
                # Use the getitem's own index: user order need not match output order,
                # and unused outputs may have no getitem at all.
                output_idx = user.args[1]
                user.meta["val"] = trt_node.meta["val"][output_idx]

        # Erase the TRT submodule (call_module) node.
        gm.graph.erase_node(trt_module_node)
    return gm