{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using the Minifier\n", "We have a pretty convenient test case minifier with this interface:\n", "```\n", "def minifier(fail_f: fx.GraphModule, inps, module_fails):\n", "    \"\"\"\n", "    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.\n", "\n", "    Does 2 main strategies:\n", "    1. Truncates suffix: Removes some suffix from the graph and sets a new output.\n", "    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,\n", "        tries replacing quarter of the graph, etc.\n", "\n", "    >>> failing_function = fx.symbolic_trace(f)\n", "    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))\n", "\n", "    note: module_fails returns True if it fails.\n", "    ...\n", "```\n", "\n", "Specifically, it takes your FX graph and tries to minify it with the following 4 strategies (while checking that the resulting graph still returns True for `module_fails`), until it can't minify it anymore.\n", "\n", "1. Truncates Suffix: Given an FX graph, it tries to remove some suffix from the graph. For example, given this:\n", "\n", "```\n", "def f(a):\n", "    b = a * 2\n", "    c = b + 3\n", "    d = c / 4\n", "    return d\n", "```\n", "It might try truncating the suffix, and get\n", "```\n", "def f(a):\n", "    b = a * 2\n", "    c = b + 3\n", "    return c\n", "```\n", "It tries this in a binary search manner, trying to remove the last 1/2, then 3/4, 1/4 then 7/8, 5/8, 3/8...\n", "\n", "2. [Delta Debugging](https://en.wikipedia.org/wiki/Delta_debugging): Of course, removing the suffix isn't always sufficient to minify a graph. What if the error is caused by the first instruction? So, we take an approach inspired by delta debugging - we try removing intermediate nodes of the graph. Unlike with suffixes, there are still dependencies on the removed nodes. So, instead of removing them entirely, we promote them to inputs. 
For example, given the function above:\n", "\n", "```\n", "def f(a):\n", "    b = a * 2\n", "    c = b + 3\n", "    d = c / 4\n", "    return d\n", "```\n", "We might remove a middle node (say, `c`) and promote it to an input:\n", "```\n", "def f(a, c):\n", "    b = a * 2\n", "    d = c / 4\n", "    return d\n", "```\n", "\n", "Finally, there are 2 auxiliary strategies - eliminating dead code and removing unused inputs. These are somewhat self-explanatory: in the example above, `b` no longer feeds the output once `c` becomes an input, so dead code elimination can drop it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, let's take a look at a toy example. Let's pretend that our graph fails if it has a \"multiply\" in it. Let's create a failing graph." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key\n", " operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)\n", " registered at aten/src/ATen/RegisterSchema.cpp:6\n", " dispatch key: FuncTorchBatched\n", " previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10338\n", " new kernel: registered at /fsx/users/chilli/work/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Started off with 7 nodes\n", "###################\n", "Current size: 7\n", "###################\n", "Strategy: Remove suffix\n", "\n", "SUCCESS: Removed [4:7)\n", "\n", "###################\n", "Current size: 6\n", "###################\n", "Strategy: Delta Debugging\n", "SUCCESS: Removed (0:4] - Went from 2 placeholders to 4\n", "\n", "###################\n", "Current size: 6\n", "###################\n", "Strategy: Remove unused inputs\n", "SUCCESS: Went from 4 inputs to 2 inputs\n", "\n", "###################\n", "Current size: 4\n", "###################\n", "Strategy: Remove suffix\n", "FAIL: Could not remove suffix\n", "Strategy: Delta 
Debugging\n", "FAIL: Could not remove prefix\n", "\n", "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", "inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n", "\n", "\n", "\n", "def forward(self, div, add):\n", " mul = torch.ops.aten.mul(add, div); add = div = None\n", " return (mul,)\n", " \n", "f = torch.jit.script(forward)\n", "with torch.jit.fuser(\"fuser2\"):\n", " for _ in range(5):\n", " f(*inps)\n" ] } ], "source": [ "import torch\n", "import torch.fx as fx\n", "from functorch.compile import minifier\n", "\n", "def failing_f(x, y):\n", " y = torch.ops.aten.div(x, y)\n", " x = torch.ops.aten.add(x, 3)\n", " x = torch.ops.aten.mul(x, y)\n", " return torch.ops.aten.sub(x, y)\n", "\n", "inps = [torch.randn(3), torch.randn(3)]\n", "\n", "def pass_checker(fx_g, inps):\n", " return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))\n", "\n", "min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tada! Our graph is now a minimal example that still fails.\n", "\n", "Since the primary use case of this minifier (for now) is for NVFuser repros, we print out a string for convenience that creates a self-contained repro to run the minified graph with NVFuser.\n", "\n", "Note that in practice, we provide 2 main \"graph checkers\" - `check_nvfuser_subprocess` and `check_nvfuser_correctness_subprocess`. These are used to check for errors and correctness (i.e. do the results match eager) respectively. 
These can be used like\n", "\n", "```\n", "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n", "minifier(failing_graph, inps, check_nvfuser_subprocess)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, assuming you're using AOTAutograd, there's another problem - how do you obtain the FX graph in the first place to pass to the minifier? One possible way is simply to use `print_compile`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "def forward(self, primals_1):\n", " cos = torch.ops.aten.cos(primals_1)\n", " cos_1 = torch.ops.aten.cos(cos)\n", " return [cos_1, primals_1, cos]\n", " \n", "\n", "\n", "\n", "def forward(self, primals_1, cos, tangents_1):\n", " sin = torch.ops.aten.sin(cos); cos = None\n", " neg = torch.ops.aten.neg(sin); sin = None\n", " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", " sin_1 = torch.ops.aten.sin(primals_1); primals_1 = None\n", " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", " return [mul_1]\n", " \n" ] }, { "data": { "text/plain": [ "tensor([0.6062, 0.9982, 0.6474], grad_fn=)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functorch.compile import aot_function\n", "\n", "from functorch.compile import print_compile\n", "# Or...\n", "def print_compile(fx_g, _):\n", " print(fx_g.code)\n", " return fx_g\n", "\n", "def foo(x):\n", " return x.cos().cos()\n", "inp = torch.randn(3, requires_grad=True)\n", "aot_function(foo, print_compile)(inp)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve this, we have another \"compiler\" called `debug_compile`. 
It simply prints out a string that can be copy pasted and run from another file. It leverages FX's `to_folder` feature to serialize the graph to disk, along with any constants.\n", "\n", "You can apply it to either the `fw_compiler` to dump the forwards graph or `bw_compiler` to dump the backwards graph." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "##############################################################\n", "# To minimize FX graph, copy and paste the below and run it #\n", "##############################################################\n", "\n", "import torch\n", "import torch.fx as fx\n", "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n", "\n", "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", "inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n", "from foo import FxModule\n", "mod = FxModule().cuda()\n", "\n", "with torch.jit.fuser(\"fuser2\"):\n", " # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess\n", " minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)\n", "\n" ] }, { "data": { "text/plain": [ "tensor([0.6062, 0.9982, 0.6474], grad_fn=)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functorch.compile import memory_efficient_fusion, debug_compile\n", "\n", "memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, let's copy paste it and see how it works - note that I made a couple minor modifications to run on CPU and use the previous \"graph fails if there's a multiply in it\" checker." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Started off with 10 nodes\n", "###################\n", "Current size: 10\n", "###################\n", "Strategy: Remove suffix\n", "\n", "SUCCESS: Removed [6:10)\n", "\n", "###################\n", "Current size: 8\n", "###################\n", "Strategy: Delta Debugging\n", "SUCCESS: Removed (0:4] - Went from 2 placeholders to 4\n", "\n", "###################\n", "Current size: 8\n", "###################\n", "Strategy: Remove unused inputs\n", "SUCCESS: Went from 4 inputs to 3 inputs\n", "\n", "###################\n", "Current size: 7\n", "###################\n", "Strategy: Remove suffix\n", "\n", "SUCCESS: Removed [4:7)\n", "\n", "###################\n", "Current size: 6\n", "###################\n", "Strategy: Remove unused inputs\n", "SUCCESS: Went from 3 inputs to 2 inputs\n", "\n", "###################\n", "Current size: 5\n", "###################\n", "Strategy: Delta Debugging\n", "SUCCESS: Removed (2:3] - Went from 2 placeholders to 3\n", "\n", "###################\n", "Current size: 5\n", "###################\n", "Strategy: Remove unused inputs\n", "SUCCESS: Went from 3 inputs to 2 inputs\n", "\n", "###################\n", "Current size: 4\n", "###################\n", "Strategy: Remove suffix\n", "FAIL: Could not remove suffix\n", "Strategy: Delta Debugging\n", "FAIL: Could not remove prefix\n", "\n", "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", "inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n", "\n", "\n", "\n", "def forward(self, tangents_1, neg):\n", " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", " return (mul,)\n", " \n", "f = torch.jit.script(forward)\n", "with torch.jit.fuser(\"fuser2\"):\n", " for _ in range(5):\n", " f(*inps)\n" ] }, { "data": { "text/plain": [ "(GraphModule(), [tensor([1., 1., 1.]), 
tensor([-0.5144, -0.5144, -0.5144])])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.fx as fx\n", "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n", "\n", "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", "inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]\n", "from foo import FxModule\n", "mod = FxModule()\n", "\n", "minifier(fx.symbolic_trace(mod), inps, pass_checker)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hopefully that was useful :)" ] } ], "metadata": { "interpreter": { "hash": "a1cf69278e4496ab232105d2fffcc75678d2dcbec1c795483197519eb80161c7" }, "kernelspec": { "display_name": "Python 3.8.12 ('py38')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }