# torch_tensorrt.dynamo._exporter
import base64
import copy
import operator
from typing import Any, Dict, Optional, Sequence, Tuple, cast
import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram, ExportGraphSignature
from torch.export.exported_program import (
CustomObjArgument,
InputKind,
InputSpec,
ModuleCallEntry,
ModuleCallSignature,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX
def export(
gm: torch.fx.GraphModule,
cross_compile_flag: Optional[bool] = False,
) -> ExportedProgram:
"""Export the result of TensorRT compilation into the desired output format.
Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
inputs (torch.Tensor): Torch input tensors
cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
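
    Example (a minimal sketch; ``MyModel`` and the input shapes are placeholders you would supply)::

        import torch
        import torch_tensorrt

        model = MyModel().eval().cuda()
        inputs = (torch.randn(1, 3, 224, 224).cuda(),)
        trt_gm = torch_tensorrt.dynamo.compile(torch.export.export(model, inputs), inputs)
        exp_program = export(trt_gm)
        torch.export.save(exp_program, "trt_model.ep")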
"""
patched_module = transform(gm, cross_compile_flag)
exp_program = create_trt_exp_program(patched_module)
return exp_program
def transform(
gm: torch.fx.GraphModule,
cross_compile_flag: Optional[bool] = False,
) -> torch.fx.GraphModule:
"""
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
Inlining collapses submodules into nodes which is necessary for torch.export
serialization.
Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
        cross_compile_flag (bool): Flag indicating whether cross-compilation is enabled
    Returns:
        An inlined torch.fx.GraphModule
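
    Example (illustrative; ``trt_gm`` is a module produced by ``torch_tensorrt.dynamo.compile``)::

        inlined_gm = transform(trt_gm)
        # inlined_gm no longer contains call_module nodes; TRT engines now appear
        # as tensorrt.execute_engine call_function nodes.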
"""
    # Make a copy of the graph since this function transforms the input graph and changes its attributes.
    # The transformed graph is meant to be consumed by `create_trt_exp_program`.
gm = copy.deepcopy(gm)
# Inline TensorRT submodules
inline_trt_modules(gm, cross_compile_flag)
    # Inline PyTorch submodules
inline_torch_modules(gm)
# Clean the graph
gm.delete_all_unused_submodules()
gm.graph.eliminate_dead_code()
gm.graph.lint()
return gm
def lift(
gm: torch.fx.GraphModule, graph_signature: Any
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature, Dict[str, Any], Dict[str, Any]]:
"""
    Given an unlifted fx.GraphModule, lift all parameters and buffers into placeholders.
Arguments:
gm (torch.fx.GraphModule): Unlifted GraphModule which contains parameters and buffers as get_attr nodes.
graph_signature (torch.export.ExportGraphSignature): Instance of ExportGraphSignature class created for the output ExportedProgram.
After lifting, this graph_signature will be modified with the parameters and buffers added appropriately.
Returns:
        A lifted fx.GraphModule, the modified graph_signature, a new state_dict, and a dict of constants
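
    Example (schematic; node and attribute names are illustrative)::

        # Before lifting, a weight appears as a get_attr node:
        #     %conv_weight : get_attr[target=conv.weight]
        # After lifting, it becomes a placeholder registered in the graph
        # signature (here as InputKind.PARAMETER) and is looked up in the
        # state_dict when the ExportedProgram is created:
        #     %conv_weight : placeholder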
"""
    # Get the state_dict of the graph_module. This is different from exported_program.state_dict:
    # exported_program.state_dict contains parameters and buffers, whereas a graph_module's state_dict
    # has all parameters registered as plain torch.Tensors.
state_dict = gm.state_dict()
constants = {}
fake_mode = detect_fake_mode(
tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
)
assert fake_mode is not None
# This map stores the names of outputs (old to new)
# This is necessary to track because the output names can be changed when
# we convert graph constants to placeholder inputs below.
output_names = {}
for output_spec in graph_signature.output_specs:
output_names[output_spec.arg.name] = output_spec.arg.name
# Locate the user input to insert new placeholders before them
first_user_input = None
for node in gm.graph.nodes:
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
first_user_input = node
break
    # Initially, user_inputs are the only entries in graph_signature.input_specs, hence non_user_input_idx = 0
    # The final input_specs should be of the form [params, buffers, constant_tensors, custom_obj, user_inputs]
non_user_input_idx = 0
for node in gm.graph.nodes:
if node.op == "get_attr":
lift_val = None
input_kind = None
if node.target not in state_dict:
constants[node.target] = getattr(gm, node.target)
input_kind = InputKind.CUSTOM_OBJ
lift_val = constants[node.target]
else:
lift_val = state_dict[node.target]
input_kind = InputKind.CONSTANT_TENSOR
                # state_dict stores these parameters/buffers as plain torch.Tensors. We re-register
                # parameters as torch.nn.Parameter; buffers remain torch.Tensors.
for name, _ in gm.named_parameters():
if node.target == name:
input_kind = InputKind.PARAMETER
state_dict[name] = torch.nn.Parameter(state_dict[name])
break
for name, _ in gm.named_buffers():
if node.target == name:
input_kind = InputKind.BUFFER
break
assert lift_val is not None and input_kind is not None
# Replace get_attr nodes with placeholder nodes and copy metadata.
with gm.graph.inserting_before(first_user_input):
                # Ensure the name doesn't contain periods, since periods denote submodule nesting
const_placeholder_name = node.target.replace(".", "_")
const_placeholder_node = gm.graph.placeholder(const_placeholder_name)
# Copy the node meta into this new placeholder node
const_placeholder_node.meta = node.meta
if isinstance(lift_val, torch.Tensor):
const_placeholder_node.meta["val"] = cast(
FakeTensor,
torch.empty_strided(
tuple(lift_val.shape),
tuple([1] * len(lift_val.shape)),
),
)
node.replace_all_uses_with(const_placeholder_node)
gm.graph.erase_node(node)
                # Check whether the const_placeholder being added is one of the output nodes.
                # This can happen when the graph contains just a single static arange op.
                # https://github.com/pytorch/TensorRT/issues/3189
if const_placeholder_name in output_names:
output_names[const_placeholder_name] = const_placeholder_node.name
# Add these parameters/buffers/constants to the existing graph signature
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
if input_kind == InputKind.CUSTOM_OBJ:
input_spec_arg = CustomObjArgument(
name=const_placeholder_node.name, class_fqn=""
)
graph_signature.input_specs.insert(
non_user_input_idx,
InputSpec(
kind=input_kind,
arg=input_spec_arg,
target=node.target,
),
)
non_user_input_idx += 1
    # Update output_specs with the modified names. These change only when get_attr nodes
    # (weights) are also outputs of the graph
for output_spec in graph_signature.output_specs:
output_spec.arg.name = output_names[output_spec.arg.name]
gm.graph.eliminate_dead_code()
gm.graph.lint()
return gm, graph_signature, state_dict, constants
def get_duplicate_nodes(
gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule
) -> Tuple[Sequence[Any], Sequence[Any]]:
"""
    Check for duplicate nodes that arise when copying the submodule graph into gm.
    This handles the case where the subgraph's input placeholders have the same names as
    gm's placeholders, which happens when the first submodule in the graph is
    a PyTorch submodule
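
    Example (schematic; names are illustrative)::

        # gm placeholder names:        [x]
        # submodule placeholder names: [x, y]
        # -> submodule_duplicate_inputs contains the submodule's x,
        #    gm_duplicate_inputs contains gm's x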
"""
submodule_placeholder_inputs = [
node for node in submodule.graph.nodes if node.op == "placeholder"
]
submodule_input_node_names = [node.name for node in submodule_placeholder_inputs]
gm_node_names = [node.name for node in gm.graph.nodes]
submodule_duplicate_inputs = [
node for node in submodule_placeholder_inputs if node.name in gm_node_names
]
gm_duplicate_inputs = [
node for node in gm.graph.nodes if node.name in submodule_input_node_names
]
return submodule_duplicate_inputs, gm_duplicate_inputs
def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
    Inline PyTorch submodules into the parent graph (gm). Each `call_module` node
    is replaced by the nodes of the corresponding submodule's graph.
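
    Example (schematic; node names are illustrative)::

        # Before: %out = call_module[target=_run_on_gpu_1](%x)
        # After:  the nodes of _run_on_gpu_1's graph are copied into gm and all
        #         uses of %out are rewired to the copied output node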
"""
# Clean the graph
gm.graph.eliminate_dead_code()
gm.graph.lint()
for gm_node in gm.graph.nodes:
if gm_node.op == "call_module" and "_run_on_gpu" in gm_node.name:
submodule = getattr(gm, gm_node.name)
with gm.graph.inserting_before(gm_node):
# Get inputs of submodule node which are most likely outputs of a previous TRT node
# or a placeholder of the main graph
submodule_inputs = gm_node.args
submodule_duplicate_inputs, gm_duplicate_inputs = get_duplicate_nodes(
gm, submodule
)
assert len(submodule_duplicate_inputs) == len(gm_duplicate_inputs)
# Avoid creating new copies of duplicate inputs by creating a mapping
val_map = {}
for i in range(len(submodule_duplicate_inputs)):
val_map[submodule_duplicate_inputs[i]] = gm_duplicate_inputs[i]
# Copy all nodes in the submodule into gm and
# store the output node of this submodule which is now present in gm
submodule_output = gm.graph.graph_copy(submodule.graph, val_map)
# Get their references (since we copied) in the parent graph (gm)
if len(submodule_duplicate_inputs) == 0:
submodule_placeholder_input_names = [
node.name
for node in submodule.graph.nodes
if node.op == "placeholder"
]
gm_added_placeholder_inputs = [
node
for node in gm.graph.nodes
if node.name in submodule_placeholder_input_names
]
assert len(submodule_inputs) == len(gm_added_placeholder_inputs)
# Replace the added placeholder inputs with original inputs to this submodule node
for idx in range(len(gm_added_placeholder_inputs)):
gm_added_placeholder_inputs[idx].replace_all_uses_with(
submodule_inputs[idx]
)
# Erase the placeholder input nodes in the gm
for idx in range(len(gm_added_placeholder_inputs)):
gm.graph.erase_node(gm_added_placeholder_inputs[idx])
# Replace the pytorch submodule node (call_module) with the inlined subgraph output
gm_node.replace_all_uses_with(submodule_output)
# copy the attributes of the submodule into gm (graph_copy doesn't do this)
copy_submodule_attributes(gm, submodule, gm_node.name)
# Erase the pytorch submodule (call_module) node
gm.graph.erase_node(gm_node)
return gm
def copy_submodule_attributes(
gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule, submodule_name: str
) -> None:
"""
    The submodule parameters are available in the parent gm's state_dict, but they have
    the submodule name as a prefix in their keys. For example, gm.state_dict() would have
    keys like _run_on_gpu_0.conv.weight. Since we graph-copied the submodule into gm, we should
    also copy its parameters and buffers into gm without the submodule name as a prefix.
    _assign_attr does exactly that: for example, it creates a conv module, adds a weight
    attribute to it, and adds this conv module as an attribute of the parent gm.
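
    Example (illustrative key mapping)::

        # parent gm.state_dict() before: {"_run_on_gpu_0.conv.weight": tensor}
        # after copying, the same tensor is also reachable without the prefix:
        # gm.conv.weight  (key "conv.weight")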
"""
from torch.export.unflatten import _assign_attr, _AttrKind
for key, value in submodule.named_parameters():
_assign_attr(value, gm, key, _AttrKind.PARAMETER)
for key, value in submodule.named_buffers():
_assign_attr(value, gm, key, _AttrKind.BUFFER)
def create_trt_exp_program(
gm: torch.fx.GraphModule,
) -> ExportedProgram:
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
and constructs an Exported Program object with the new IO node names and state_dict
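
    Example (illustrative; ``inlined_gm`` is the output of ``transform``)::

        exp_program = create_trt_exp_program(inlined_gm)
        torch.export.save(exp_program, "trt_model.ep")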
"""
input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
assert output_nodes
output_nodes = output_nodes[0].args[0]
input_specs = [
InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target)
for node in input_nodes
]
output_specs = [
OutputSpec(OutputKind.USER_OUTPUT, TensorArgument(name=node.name), node.target)
for node in output_nodes
]
trt_graph_signature = ExportGraphSignature(
input_specs=input_specs, output_specs=output_specs
)
module_call_graph = [
ModuleCallEntry(
"",
ModuleCallSignature(
inputs=[],
outputs=[],
in_spec=gm.graph._codegen.pytree_info.in_spec,
out_spec=gm.graph._codegen.pytree_info.out_spec,
),
)
]
# Lift parameters/buffers/constants in the graph
# torch.export serialization expects them to be lifted
gm, trt_graph_signature, state_dict, constants = lift(gm, trt_graph_signature)
trt_exp_program = ExportedProgram(
root=gm,
graph=gm.graph,
graph_signature=trt_graph_signature,
state_dict=state_dict,
range_constraints={},
module_call_graph=module_call_graph,
constants=constants,
)
return trt_exp_program
def inline_trt_modules(
gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
) -> torch.fx.GraphModule:
"""
    Replace TRT submodules (call_module nodes) with TRT engine nodes (execute_engine call_function nodes).
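
    Example (schematic; node names are illustrative)::

        # Before: %out = call_module[target=_run_on_acc_0](%x)
        # After:  %engine = get_attr[target=_run_on_acc_0_engine]
        #         %out    = call_function[target=tensorrt.execute_engine]((%x,), %engine)
        #         %out_0  = call_function[target=operator.getitem](%out, 0)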
"""
for name, _ in gm.named_children():
if "_run_on_acc" not in name:
continue
# Get the TRT submodule
trt_module = getattr(gm, name)
# Ensure the trt module node in the main graph (gm) has inputs
trt_module_node = [node for node in gm.graph.nodes if node.name == name]
assert trt_module_node
trt_module_node = trt_module_node[0]
assert trt_module_node.args
if "val" not in trt_module_node.meta:
raise ValueError(
f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step."
)
num_outputs = len(trt_module_node.meta["val"])
# Insert a call_function node to perform inference on TRT engine
with gm.graph.inserting_before(trt_module_node):
if not cross_compile_flag:
# for the normal workflow: use the execute_engine node
engine_name = f"{name}_engine"
setattr(gm, engine_name, trt_module.engine)
engine_node = gm.graph.get_attr(engine_name)
trt_node = gm.graph.call_function(
torch.ops.tensorrt.execute_engine.default,
(trt_module_node.args, engine_node),
)
# meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties)
# Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
# for custom object nodes, it should be CustomObjArgument
engine_node.meta["val"] = CustomObjArgument(
name=engine_node.name, class_fqn=""
)
else:
# for the cross compile for windows workflow: use the no_op_placeholder node
engine_info = trt_module._pack_engine_info()
engine_bytes = engine_info[ENGINE_IDX]
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
                # Insert the no-op placeholder node into the graph; it should be replaced with the actual execute_engine node when the program is loaded on Windows
trt_node = gm.graph.call_function(
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
(trt_module_node.args, *engine_info),
)
# set trt_node.meta with trt_module_node.meta
assert num_outputs > 0
trt_node.meta["val"] = trt_module_node.meta["val"]
if num_outputs == 1:
# Insert getitem nodes as outputs (for export serialization to work)
with gm.graph.inserting_after(trt_node):
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
getitem_output.meta["val"] = trt_node.meta["val"]
trt_module_node.replace_all_uses_with(getitem_output)
else:
# Multiple outputs case:
# Replace uses of submodule with the trt_node.
# getitem nodes are already added inherently by the partitioner
trt_module_node.replace_all_uses_with(trt_node)
getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
getitem_node.meta["val"] = trt_node.meta["val"][idx]
# Erase the TRT submodule (call_module) node.
gm.graph.erase_node(trt_module_node)
return gm
def replace_execute_engine_no_op_node(
exp_program: ExportedProgram,
) -> ExportedProgram:
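    """
    Replace no_op_placeholder_for_execute_engine nodes (inserted by ``inline_trt_modules``
    during cross-compilation) with actual tensorrt.execute_engine nodes. Intended to run
    when the cross-compiled program is loaded on Windows.
    """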
gm = exp_program.graph_module
no_op_placeholder_nodes = []
for node in gm.graph.nodes:
if "no_op_placeholder_for_execute_engine" in node.name:
no_op_placeholder_nodes.append(node)
assert len(no_op_placeholder_nodes) > 0
for no_op_placeholder_node in no_op_placeholder_nodes:
if "val" not in no_op_placeholder_node.meta:
raise ValueError(f"metadata info is missing for the node: {node.name}")
with gm.graph.inserting_before(no_op_placeholder_node):
packed_engine_info = list(no_op_placeholder_node.args[1:])
engine_bytes = packed_engine_info[ENGINE_IDX]
engine_name = packed_engine_info[NAME_IDX]
packed_engine_info[ENGINE_IDX] = base64.b64decode(
engine_bytes.encode("utf-8")
)
trt_engine = torch.classes.tensorrt.Engine(tuple(packed_engine_info))
setattr(gm, engine_name, trt_engine)
engine_node = gm.graph.get_attr(engine_name)
trt_node = gm.graph.call_function(
torch.ops.tensorrt.execute_engine.default,
(no_op_placeholder_node.args[0], engine_node),
)
trt_node.meta["val"] = no_op_placeholder_node.meta["val"]
engine_node.meta["val"] = CustomObjArgument(
name=engine_node.name, class_fqn=""
)
if len(no_op_placeholder_node.meta["val"]) == 1:
with gm.graph.inserting_after(trt_node):
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
getitem_output.meta["val"] = trt_node.meta["val"]
no_op_placeholder_node.replace_all_uses_with(getitem_output)
else:
no_op_placeholder_node.replace_all_uses_with(trt_node)
getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
getitem_node.meta["val"] = trt_node.meta["val"][idx]
gm.graph.erase_node(no_op_placeholder_node)
gm.delete_all_unused_submodules()
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
return exp_program