(beta) Building a Simple CPU Performance Profiler with FX
Author: James Reed
In this tutorial, we are going to use FX to do the following:
Capture PyTorch Python code in a way that we can inspect and gather statistics about the structure and execution of the code
Build out a small class that will serve as a simple performance “profiler”, collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
Now that we have our model, we want to look more closely at its performance. That is, for the following invocation, which parts of the model take the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the difference between those timestamps to see how long the regions between the timestamps take.
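As a rough illustration of this manual approach, the sketch below hand-instruments a small stand-in model. `TinyNet` and its `timings` attribute are hypothetical names introduced only for this example; they are not part of torchvision or of the ResNet18 model used in this tutorial.

```python
import time

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in model, used only to keep this sketch small."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Filled in by forward(): region name -> elapsed seconds
        self.timings = {}

    def forward(self, x):
        # Timestamps are inserted by hand around each region of interest.
        t0 = time.time()
        x = self.conv(x)
        t1 = time.time()
        x = self.relu(x)
        t2 = time.time()
        x = self.pool(x)
        t3 = time.time()
        self.timings = {"conv": t1 - t0, "relu": t2 - t1, "pool": t3 - t2}
        return x


net = TinyNet().eval()
with torch.no_grad():
    out = net(torch.randn(2, 3, 32, 32))

for name, seconds in net.timings.items():
    print(f"{name}: {seconds:.6f} s")
```

Note that every timed region requires an edit to `forward`, which is exactly the burden we would like to avoid.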
That technique is certainly applicable to PyTorch code, but it would be nicer if we didn’t have to copy the model code and edit it, especially code we haven’t written (like this torchvision model). Instead, we are going to use FX to automate this “instrumentation” process without needing to modify any source.
First, let’s get some imports out of the way (we will be using all of these later in the code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
Note
tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you’ve installed it from your favorite Python package source.
Capturing the Model with Symbolic Tracing
Next, we are going to use FX’s symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [#users=1] = placeholder[target=x]
%conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [#users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [#users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [#users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [#users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [#users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [#users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [#users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [#users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [#users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [#users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [#users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [#users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [#users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [#users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [#users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [#users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [#users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [#users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [#users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [#users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [#users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [#users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [#users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [#users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [#users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [#users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [#users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [#users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [#users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [#users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [#users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [#users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [#users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [#users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [#users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [#users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [#users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [#users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [#users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [#users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [#users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [#users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [#users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [#users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [#users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [#users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [#users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [#users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [#users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method) and the edges (represented as args and kwargs on each node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX’s APIs can be found in the FX documentation: https://pytorch.org/docs/master/fx.html.
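To make the Node structure concrete, here is a minimal sketch that walks a traced graph and tallies nodes by their `op` kind. The small `nn.Sequential` model is a hypothetical stand-in chosen to keep the example short; the same loop should work on `traced_rn18.graph`.

```python
import collections

import torch
import torch.fx
import torch.nn as nn

# Hypothetical stand-in model; any symbolically traceable Module would do.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
traced = torch.fx.symbolic_trace(model)

op_counts = collections.Counter()
for node in traced.graph.nodes:
    # node.op is the kind of node (placeholder, call_module, call_function,
    # call_method, get_attr, or output), node.target names what is called,
    # and node.args holds the incoming edges from earlier nodes.
    print(node.op, node.name, node.target, node.args)
    op_counts[node.op] += 1

print(dict(op_counts))
```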
Creating a Profiling Interpreter
Next, we are going to create a class that inherits from torch.fx.Interpreter.
The GraphModule that symbolic_trace produces is compiled to Python code, and that code is what runs when you call the GraphModule. An alternative way to run a GraphModule is to execute each Node in the Graph one by one. That is the functionality that Interpreter provides: it interprets the graph node by node.
By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, then get statistics about how long the model and each part of the model took during those runs.
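Before adding any timing, it may help to see the override mechanism in isolation. The sketch below assumes a hypothetical `LoggingInterpreter` subclass and a small stand-in `nn.Sequential` model (neither is part of this tutorial's profiler). It records the order in which nodes are executed and checks that interpreting the graph gives the same result as calling the GraphModule directly.

```python
import torch
import torch.fx
import torch.nn as nn
from torch.fx import Interpreter


class LoggingInterpreter(Interpreter):
    """Hypothetical subclass that records the order in which Nodes run."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.visited = []

    def run_node(self, n: torch.fx.Node):
        # Called once per Node; note the name, then defer to the base class.
        self.visited.append(n.name)
        return super().run_node(n)


# Hypothetical stand-in model, kept small for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
gm = torch.fx.symbolic_trace(model)
x = torch.randn(3, 4)

# Path 1: call the GraphModule, which runs its generated forward() code.
compiled_out = gm(x)
# Path 2: have the Interpreter walk the same graph one Node at a time.
interp = LoggingInterpreter(gm)
interpreted_out = interp.run(x)

print(interp.visited)
print(torch.allclose(compiled_out, interpreted_out))
```

The profiler defined next follows the same pattern, wrapping the call to the base class with timestamps instead of a log entry.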
Let’s define our ProfilingInterpreter class:
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
Note
We use Python’s time.time function to pull wall clock timestamps and compare them. This is not the most accurate way to measure performance, and will only give us a first-order approximation. We use this simple technique only for the purpose of demonstration in this tutorial.
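For somewhat steadier numbers, one common approach is to use the monotonic, high-resolution `time.perf_counter` clock, discard a few warm-up calls, and summarize repeated runs with the median. The sketch below illustrates that idea only. `time_callable`, its `warmup` and `repeats` parameters, and the small stand-in model are hypothetical names introduced here, not part of the profiler above.

```python
import statistics
import time

import torch
import torch.nn as nn


def time_callable(fn, warmup=3, repeats=20):
    """Return per-call wall-clock durations in seconds (hypothetical helper)."""
    # Warm-up calls absorb one-time costs such as lazy initialization.
    for _ in range(warmup):
        fn()
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


# Hypothetical stand-in model, kept small for illustration.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()
x = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    durations = time_callable(lambda: model(x))

# The median is less sensitive to occasional slow outlier runs than the mean.
print(f"median: {statistics.median(durations):.6f} s")
```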
Investigating the Performance of ResNet18
We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module conv1 0.00563812 9.00016
call_module maxpool 0.00538921 8.60282
call_module layer1_0_conv1 0.00377941 6.03309
call_module layer1_0_conv2 0.00371599 5.93185
call_module layer1_1_conv2 0.00360036 5.74727
call_module layer1_1_conv1 0.00344419 5.49798
call_module layer2_0_conv2 0.00329566 5.26087
call_module layer2_1_conv1 0.00286651 4.57582
call_module layer4_1_conv2 0.00284338 4.5389
call_module layer4_1_conv1 0.0027442 4.38057
call_module layer4_0_conv2 0.00272536 4.35051
call_module layer2_0_conv1 0.00214124 3.41807
call_module layer2_1_conv2 0.0020833 3.32558
call_module layer3_1_conv2 0.0020051 3.20075
call_module layer3_1_conv1 0.0019815 3.16307
call_module layer3_0_conv2 0.00196505 3.13681
call_module layer4_0_conv1 0.00163722 2.6135
call_module bn1 0.00125742 2.00722
call_module layer3_0_conv1 0.00118566 1.89267
call_module layer2_0_downsample_0 0.00070858 1.13111
call_module layer3_0_downsample_0 0.000405073 0.64662
call_module layer4_0_downsample_0 0.000398397 0.635963
call_module relu 0.000377655 0.602852
call_function add 0.000367641 0.586867
call_function add_1 0.000319242 0.509608
call_module layer1_1_bn2 0.000285149 0.455184
call_module layer1_0_bn2 0.000252247 0.402663
call_module fc 0.00024271 0.387439
call_module layer1_0_bn1 0.000221729 0.353947
call_module layer1_1_bn1 0.000221252 0.353186
call_function add_2 0.000217676 0.347477
call_function add_3 0.000173807 0.277449
call_module layer2_0_downsample_1 0.000154495 0.246621
call_module layer2_1_bn2 0.000136852 0.218458
call_module layer2_0_bn1 0.00013423 0.214271
call_module avgpool 0.000125408 0.20019
call_module layer1_0_relu 0.00011611 0.185347
call_module layer1_0_relu_1 0.00011611 0.185347
call_module layer3_1_bn2 0.000114441 0.182682
call_module layer1_1_relu_1 0.00011301 0.180399
call_module layer2_0_bn2 0.00011158 0.178115
call_module layer4_1_bn2 0.000101328 0.16175
call_module layer1_1_relu 9.799e-05 0.156422
call_module layer2_1_bn1 9.58443e-05 0.152997
call_module layer4_0_downsample_1 9.08375e-05 0.145004
call_module layer3_0_bn2 8.98838e-05 0.143482
call_module layer4_0_bn2 8.86917e-05 0.141579
call_module layer4_1_bn1 8.60691e-05 0.137392
call_module layer3_1_bn1 8.01086e-05 0.127878
call_module layer3_0_downsample_1 7.96318e-05 0.127117
call_function add_7 7.89165e-05 0.125975
call_function add_5 7.7486e-05 0.123691
call_module layer3_0_bn1 7.65324e-05 0.122169
call_module layer4_0_bn1 7.41482e-05 0.118363
call_module layer2_0_relu_1 6.79493e-05 0.108468
call_module layer2_0_relu 6.74725e-05 0.107707
call_function add_4 6.41346e-05 0.102378
call_function add_6 6.36578e-05 0.101617
call_module layer4_1_relu 5.88894e-05 0.0940054
call_module layer2_1_relu_1 5.26905e-05 0.0841101
call_module layer4_0_relu 5.05447e-05 0.0806848
call_module layer4_0_relu_1 4.69685e-05 0.0749759
call_module layer2_1_relu 4.62532e-05 0.0738342
call_module layer4_1_relu_1 4.48227e-05 0.0715506
call_module layer3_0_relu_1 4.24385e-05 0.0677447
call_module layer3_1_relu 4.14848e-05 0.0662224
call_module layer3_0_relu 4.07696e-05 0.0650806
call_module layer3_1_relu_1 4.00543e-05 0.0639389
call_function flatten 3.55244e-05 0.0567077
placeholder x 2.09808e-05 0.0334918
output output 1.38283e-05 0.0220741
There are two things we should call out here:
MaxPool2d takes up a disproportionate amount of time; in the run above, the single maxpool call is second only to conv1. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize this in the Conv-BN Fusion with FX tutorial.
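One way to check observations like these is to group the measured time by module type rather than by individual node. The sketch below is a simplified, hypothetical variant of the profiler (`TypeTimingInterpreter` and `seconds_by_type` are names introduced here), run on a small stand-in model so that it stays self-contained. Passing `rn18` instead should work the same way.

```python
import collections
import time

import torch
import torch.fx
import torch.nn as nn
from torch.fx import Interpreter


class TypeTimingInterpreter(Interpreter):
    """Hypothetical variant that sums elapsed time per module type."""

    def __init__(self, mod: torch.nn.Module):
        super().__init__(torch.fx.symbolic_trace(mod))
        self.seconds_by_type = collections.defaultdict(float)

    def run_node(self, n: torch.fx.Node):
        start = time.perf_counter()
        result = super().run_node(n)
        elapsed = time.perf_counter() - start
        if n.op == "call_module":
            # n.target is the qualified submodule name, e.g. "layer1.0.conv1"
            key = type(self.fetch_attr(n.target)).__name__
        else:
            # Group everything else by its node kind (placeholder, output, ...)
            key = n.op
        self.seconds_by_type[key] += elapsed
        return result


# Hypothetical stand-in model, kept small for illustration.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
).eval()

type_interp = TypeTimingInterpreter(model)
with torch.no_grad():
    type_interp.run(torch.randn(2, 3, 32, 32))

# Print the slowest module types first.
for name, seconds in sorted(
        type_interp.seconds_by_type.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:12s} {seconds:.6f} s")
```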
Conclusion
As we can see, using FX we can easily capture PyTorch programs (even ones we don’t have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we’ve done here. FX opens up an exciting world of possibilities for working with PyTorch programs.
Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.