{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n", "# https://pytorch.org/tutorials/beginner/colab\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(beta) Building a Simple CPU Performance Profiler with FX\n", "=========================================================\n", "\n", "**Author**: [James Reed](https://github.com/jamesr66a)\n", "\n", "In this tutorial, we are going to use FX to do the following:\n", "\n", "1) Capture PyTorch Python code in a way that we can inspect and gather\n", " statistics about the structure and execution of the code\n", "2) Build out a small class that will serve as a simple performance\n", " \\\"profiler\\\", collecting runtime statistics about each part of the\n", " model from actual runs.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this tutorial, we are going to use the torchvision ResNet18 model\n", "for demonstration purposes.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n", "import torch.fx\n", "import torchvision.models as models\n", "\n", "rn18 = models.resnet18()\n", "rn18.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have our model, we want to inspect deeper into its\n", "performance. 
That is, for the following invocation, which parts of the\n", "model are taking the longest?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "input = torch.randn(5, 3, 224, 224)\n", "output = rn18(input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A common way of answering that question is to go through the program\n", "source, add code that collects timestamps at various points in the\n", "program, and compare the difference between those timestamps to see how\n", "long the regions between the timestamps take.\n", "\n", "That technique is certainly applicable to PyTorch code; however, it would\n", "be nicer if we didn\\'t have to copy over model code and edit it,\n", "especially code we haven\\'t written (like this torchvision model).\n", "Instead, we are going to use FX to automate this \\\"instrumentation\\\"\n", "process without needing to modify any source.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let\\'s get some imports out of the way (we will be using all of\n", "these later in the code).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import statistics, tabulate, time\n", "from typing import Any, Dict, List\n", "from torch.fx import Interpreter" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
`tabulate` is an external library that is not a dependency of PyTorch.\n", "We will be using it to more easily visualize performance data. Please\n", "make sure you\\'ve installed it from your favorite Python package\n", "source.\n", "\n", "We use Python\\'s `time.time` function to pull wall clock\n", "timestamps and compare them. This is not the most accurate\n", "way to measure performance, and will only give us a first-order\n", "approximation. We use this simple technique only for the\n", "purpose of demonstration in this tutorial.\n
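" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick illustration of this wall-clock technique (a sketch, not\n", "part of the profiler we are about to build), we can time a single\n", "forward pass of the model with `time.time`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Illustrative only: wall-clock timing of one forward pass.\n", "start = time.time()\n", "with torch.no_grad():\n", "    output = rn18(input)\n", "print(f'Forward pass took {time.time() - start:.4f} seconds')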