Lowering Phase¶
The lowering phase is made up of passes, which are operations that map a graph from a high-level representation to a lower-level one. Each pass does something specific, for instance inlining method calls. The idea is to significantly reduce what the conversion phase needs to handle when actually mapping to TensorRT. We aim for closer to 1->1 op conversion instead of looking for applicable subgraphs, which limits the number of converters and reduces the scope of each converter.
You can see the effects of each pass by setting the log level to Level::kGraph.
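For example, in the Python API the equivalent log level can be set before compiling. This is a minimal sketch, assuming the torch_tensorrt.logging module and Level.Graph enum names from the Python bindings (they may differ between releases):

import torch
import torch.nn as nn
import torch_tensorrt

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Raise the log level so each lowering pass logs the graph it produced.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)

model = torch.jit.script(TinyModel().eval())
trt_model = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 8))])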
Passes Used¶
EliminateCommonSubexpression¶
Removes common subexpressions in the graph.
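A minimal sketch of what this pass does, using the internal TorchScript binding torch._C._jit_pass_cse (an implementation detail that can change between PyTorch versions):

import torch

@torch.jit.script
def f(x, y):
    a = x * y + y   # the subexpression x * y appears twice
    b = x * y - y
    return a, b

graph = f.graph
torch._C._jit_pass_cse(graph)   # internal binding, subject to change
print(graph)                    # x * y is now computed once and reused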
Eliminate Dead Code¶
Dead code elimination removes nodes whose outputs are never used. It will check if a node has side effects and not delete it if it does.
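A minimal sketch of the side-effect rule, again using an internal TorchScript binding (torch._C._jit_pass_dce) that may change between PyTorch versions:

import torch

@torch.jit.script
def f(x):
    unused = x * 2      # dead: the result is never consumed, so it is removed
    print("running f")  # prim::Print has a side effect, so it is kept
    return x + 1

graph = f.graph
torch._C._jit_pass_dce(graph)   # internal binding, subject to change
print(graph)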
Eliminate Exception Or Pass Pattern¶
A common pattern in scripted modules is a dimension guard, which will throw an exception if the input dimension is not what was expected.
%1013 : bool = aten::ne(%1012, %24) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:11
 = prim::If(%1013) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:8
  block0():
    = prim::RaiseException(%23) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:249:12
    -> ()
  block1():
    -> ()
Since we are resolving all of this at compile time and there are no exceptions in the TensorRT graph, we just remove it.
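The IR above comes from the dimension check in torch.nn.modules.batchnorm; the hypothetical module below is only an illustration of the kind of scripted code that produces this pattern:

import torch
import torch.nn as nn

class GuardedModule(nn.Module):
    # Illustrative stand-in for guards like BatchNorm2d's input-dimension check.
    def forward(self, x):
        if x.dim() != 4:
            raise ValueError("expected 4D input")   # becomes prim::RaiseException
        return x * 2

scripted = torch.jit.script(GuardedModule())
print(scripted.graph)   # shows the prim::If / prim::RaiseException pattern above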
Eliminate Redundant Guards¶
Eliminates redundant guards for ops whose outputs are fully determined by their inputs, i.e. if the inputs to such ops are guarded, we are allowed to remove a guard on the ops’ outputs.
Freeze Module¶
Freezes attributes and inlines constants and modules. Propagates constants in the graph.
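A minimal sketch using the public torch.jit.freeze API, which invokes the same freezing machinery:

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight

scripted = torch.jit.script(M().eval())   # freezing requires eval mode
frozen = torch.jit.freeze(scripted)       # attributes become inlined constants
print(frozen.graph)                       # no prim::GetAttr on self.weight remains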
Fuse AddMM Branches¶
A common pattern in scripted modules is that tensors of different dimensions use different constructions to implement linear layers. We fuse these variants into a single form that will get caught by the Unpack AddMM pass.
%ret : Tensor = prim::If(%622)
  block0():
    %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
    -> (%ret.1)
  block1():
    %output.1 : Tensor = aten::matmul(%x9.1, %3677)
    %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
    -> (%output0.1)
We fuse this set of blocks into a graph like this:
%ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
Fuse Linear¶
Matches the aten::linear pattern and fuses it into a single aten::linear node. This pass fuses the addmm or matmul + add sequences generated by the JIT back into aten::linear.
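A minimal sketch, assuming the internal binding torch._C._jit_pass_fuse_linear (used by PyTorch's quantization tooling; its name and the exact patterns it matches may vary by version):

import torch

@torch.jit.script
def decomposed_linear(x, w, b):
    # aten::addmm written out by hand, the kind of pattern this pass targets.
    return torch.addmm(b, x, w.t())

graph = decomposed_linear.graph
torch._C._jit_pass_fuse_linear(graph)   # internal binding, subject to change
print(graph)                            # the addmm should now appear as aten::linear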
Fuse Flatten Linear¶
TensorRT implicitly flattens input layers into fully connected layers when they are higher than 1D. So when there is an aten::flatten -> aten::linear pattern, we remove the aten::flatten.
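A hypothetical module that exhibits the pattern (the pass itself is internal to Torch-TensorRT and runs during compilation):

import torch
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16 * 4 * 4, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)   # aten::flatten feeding aten::linear
        return self.fc(x)

scripted = torch.jit.script(Classifier().eval())
print(scripted.graph)   # the flatten feeding the linear is what this pass removes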
Lower Graph¶
Given the graph of a method whose first argument is %self, lower it to a graph where all attribute accesses are replaced with explicit inputs to the graph (rather than the results of prim::GetAttr executed on %self). Returns a tuple (graph, parameters) where the last module.parameters.size() inputs to the graph are the trainable parameters used in this method. The remaining inputs are the true inputs to the function.
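A minimal sketch, assuming the internal binding torch._C._jit_pass_lower_graph, whose name and signature have changed between PyTorch releases:

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.w

scripted = torch.jit.script(M())
# Internal binding (name/signature assumed); returns the lowered graph plus params.
lowered_graph, params = torch._C._jit_pass_lower_graph(scripted.forward.graph, scripted._c)
print(lowered_graph)   # %self is gone; the weight arrives as an extra graph input
print(len(params))     # the trailing graph inputs correspond to these parameters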
Lower Tuples¶
LowerSimpleTuples: Removes tuples where TupleConstruct and TupleUnpack are matched, but leaves tuples in place across if statements, loops, and as inputs/outputs.
LowerAllTuples: Removes _all_ tuples and raises an error if some cannot be removed. This is used by ONNX to ensure there are no tuples before conversion, but it will not work on graphs whose inputs contain tuples.
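A minimal sketch of LowerAllTuples via the internal binding torch._C._jit_pass_lower_all_tuples (subject to change between PyTorch versions):

import torch
from typing import Tuple

@torch.jit.script
def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    pair = (x + 1, x - 1)   # prim::TupleConstruct
    a, b = pair             # prim::TupleUnpack
    return a * 2, b * 2

graph = f.graph
torch._C._jit_pass_lower_all_tuples(graph)   # internal binding, subject to change
print(graph)   # the tuple construct/unpack pairs are gone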
Module Fallback¶
Module fallback consists of two lowering passes that must be run as a pair. The first pass is run before freezing to place delimiters in the graph around modules that should run in PyTorch. The second pass marks nodes between these delimiters after freezing to signify they should run in PyTorch.
NotateModuleForFallback: Places delimiting nodes around module calls pre-freezing to signify where in the graph nodes should run in PyTorch.
MarkNodesForFallback: Looks for delimiters and then marks all nodes between the delimiters to tell partitioning to run them in PyTorch.
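Which modules fall back is controlled from the compile settings. A minimal sketch, assuming the torch_executed_modules option and module-name format of the torch_tensorrt Python API:

import torch
import torch.nn as nn
import torch_tensorrt

class Backbone(nn.Module):
    def forward(self, x):
        return torch.relu(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.head = nn.Linear(8, 4)

    def forward(self, x):
        return self.head(self.backbone(x))

model = torch.jit.script(Model().eval())
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 8))],
    torch_executed_modules=["__main__.Backbone"],  # keep Backbone running in PyTorch
)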
Peephole Optimize¶
The intent of this optimization pass is to catch all of the small, easy-to-catch peephole optimizations you might be interested in doing. Right now, it does:
- Eliminate no-op ‘expand’ nodes
- Simplify x.t().t() to x
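A minimal sketch of the x.t().t() case, using internal TorchScript bindings (torch._C._jit_pass_peephole and torch._C._jit_pass_dce, both subject to change):

import torch

@torch.jit.script
def f(x):
    return x.t().t()   # a double transpose is a no-op

graph = f.graph
torch._C._jit_pass_peephole(graph)   # internal binding, subject to change
torch._C._jit_pass_dce(graph)        # clean up the now-dead transpose nodes
print(graph)                         # should reduce to returning x directly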
Remove Contiguous¶
Removes contiguous operators since we are using TensorRT and memory is already contiguous.
Remove Dropout¶
Removes dropout operators since we are doing inference.
Remove To¶
Removes aten::to operators that do casting, since TensorRT manages it itself. It is important that this is one of the last passes run so that other passes have a chance to move required cast operators out of the main namespace.
Unpack AddMM¶
Unpacks aten::addmm into aten::matmul and aten::add_ (with an additional trt::const op to freeze the bias in the TensorRT graph). This lets us reuse the aten::matmul and aten::add_ converters instead of needing a dedicated converter.
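A quick numeric sanity check of the decomposition (the trt::const freezing of the bias is internal to Torch-TensorRT and not shown):

import torch

bias = torch.randn(3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)

fused = torch.addmm(bias, mat1, mat2)           # what the scripted graph contains
unpacked = torch.matmul(mat1, mat2).add_(bias)  # what the pass rewrites it into
print(torch.allclose(fused, unpacked))          # True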
Unpack LogSoftmax¶
Unpacks aten::logsoftmax into aten::softmax and aten::log. This lets us reuse the aten::softmax and aten::log converters instead of needing a dedicated converter.
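A quick numeric sanity check of the decomposition:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5)

fused = F.log_softmax(x, dim=1)            # aten::log_softmax in the graph
unpacked = torch.log(F.softmax(x, dim=1))  # softmax followed by log
print(torch.allclose(fused, unpacked, atol=1e-6))   # True, up to float error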
Unroll Loops¶
Unrolls the operations of compatible loops (e.g. those that are sufficiently short) so that you only have to go through the loop once.
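A conceptual sketch of the rewrite; the actual pass operates on the IR, not on Python source:

import torch

@torch.jit.script
def accumulate(x):
    # A short loop with a fixed trip count, eligible for unrolling.
    for _ in range(3):
        x = x + 1
    return x

# Equivalent straight-line form that unrolling produces in the graph:
@torch.jit.script
def accumulate_unrolled(x):
    x = x + 1
    x = x + 1
    x = x + 1
    return x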
Replace Tile with Repeat¶
Replaces aten::tile operators with equivalent aten::repeat operators.
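A quick numeric sanity check of the equivalence behind this rewrite:

import torch

x = torch.tensor([1, 2, 3])

tiled = torch.tile(x, (2,))   # aten::tile in the graph
repeated = x.repeat(2)        # the aten::repeat form it is rewritten to
print(torch.equal(tiled, repeated))   # True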