Custom Backends
===============

Overview
--------

``torch.compile`` provides a straightforward way for users to define custom
backends.

A backend function has the contract
``(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable``.

TorchDynamo, the graph tracing component of ``torch.compile``, calls backend
functions after tracing an FX graph. A backend function is expected to return
a compiled function that is equivalent to the traced FX graph. The returned
callable should have the same contract as the ``forward`` function of the
original ``torch.fx.GraphModule`` passed into the backend:
``(*args: torch.Tensor) -> List[torch.Tensor]``.

For TorchDynamo to call your backend, pass your backend function as the
``backend`` kwarg to ``torch.compile``. For example,

.. code-block:: python

    import torch

    def my_custom_backend(gm, example_inputs):
        return gm.forward

    def f(...):
        ...

    f_opt = torch.compile(f, backend=my_custom_backend)

    @torch.compile(backend=my_custom_backend)
    def g(...):
        ...

See below for more examples.

Registering Custom Backends
---------------------------

You can register your backend using the ``register_backend`` decorator, for
example,

.. code-block:: python

    from torch._dynamo.optimizations import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
        ...

If your backend lives in another Python package, you can also register it
through the package's entry points, a mechanism that lets one package register
a plugin for another.

.. hint::

    You can learn more about ``entry_points`` in the
    `python packaging documentation `__.

To register your backend through ``entry_points``, add your backend function
to the ``torch_dynamo_backends`` entry point group in the ``setup.py`` file of
your package, like this:

.. code-block:: python

    ...
    setup(
        ...
        entry_points={
            'torch_dynamo_backends': [
                'my_compiler = your_module.submodule:my_compiler',
            ]
        },
        ...
    )

Replace the ``my_compiler`` before ``=`` with the name of your backend, and
replace the part after ``=`` with the module and function name of your backend
function. The entry point is added to your Python environment when the package
is installed. When you call ``torch.compile(model, backend="my_compiler")``,
PyTorch first searches for a backend named ``my_compiler`` that has been
registered with ``register_backend``. If none is found, it continues searching
among the backends registered via ``entry_points``.

Registration serves two purposes:

* You can pass a string containing your backend function's name to
  ``torch.compile`` instead of the function itself, for example,
  ``torch.compile(model, backend="my_compiler")``.
* It is required for use with the `minifier `__. Any generated code from the
  minifier must call your code that registers your backend function, typically
  through an ``import`` statement.

Custom Backends after AOTAutograd
---------------------------------

It is possible to define custom backends that are called by AOTAutograd rather
than TorchDynamo. This is useful for two main reasons:

* Users can define backends that support model training, as AOTAutograd can
  generate the backward graph for compilation.
* AOTAutograd produces FX graphs consisting of `canonical ATen ops `__. As a
  result, custom backends only need to support the canonical ATen opset, which
  is significantly smaller than the entire torch/ATen opset.

Wrap your backend with ``torch._dynamo.optimizations.training.aot_autograd``
and use ``torch.compile`` with the ``backend`` kwarg as before. Backend
functions wrapped by ``aot_autograd`` should have the same contract as before.

Backend functions are passed to ``aot_autograd`` through the ``fw_compiler``
(forward compiler) or ``bw_compiler`` (backward compiler) kwargs. If
``bw_compiler`` is not specified, the backward compile function defaults to
the forward compile function.

One caveat is that AOTAutograd requires compiled functions returned by
backends to be "boxed". This can be done by wrapping the compiled function
with ``functorch.compile.make_boxed_func``.

For example,

.. code-block:: python

    from torch._dynamo.optimizations.training import aot_autograd
    from functorch.compile import make_boxed_func

    def my_compiler(gm, example_inputs):
        return make_boxed_func(gm.forward)

    my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler

    model_opt = torch.compile(model, backend=my_backend)

Examples
--------

Debugging Backend
^^^^^^^^^^^^^^^^^

To better understand what happens during compilation, you can create a custom
compiler (referred to as a backend in this section) that pretty-prints the FX
``GraphModule`` extracted from Dynamo’s bytecode analysis and returns a
``forward()`` callable.

For example:

.. code-block:: python

    from typing import List
    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def fn(x, y):
        a = torch.cos(x)
        b = torch.sin(y)
        return a + b

    fn(torch.randn(10), torch.randn(10))

Running the above example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name    target                                         args        kwargs
    -------------  ------  ---------------------------------------------  ----------  --------
    placeholder    x       x                                              ()          {}
    placeholder    y       y                                              ()          {}
    call_function  cos     <built-in method cos of type object at 0x...>  (x,)        {}
    call_function  sin     <built-in method sin of type object at 0x...>  (y,)        {}
    call_function  add     <built-in function add>                        (cos, sin)  {}
    output         output  output                                         ((add,),)   {}

This works for ``torch.nn.Module`` as well, as shown below:

.. code-block:: python

    from typing import List
    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.cos(x))

    mod = MockModule()
    optimized_mod = torch.compile(mod, backend=my_compiler)
    optimized_mod(torch.randn(10))

Let’s take a look at one more example with control flow:

.. code-block:: python

    from typing import List
    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

Running this example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name     target                                         args              kwargs
    -------------  -------  ---------------------------------------------  ----------------  --------
    placeholder    a        a                                              ()                {}
    placeholder    b        b                                              ()                {}
    call_function  abs_1    <built-in method abs of type object at 0x...>  (a,)              {}
    call_function  add      <built-in function add>                        (abs_1, 1)        {}
    call_function  truediv  <built-in function truediv>                    (a, add)          {}
    call_method    sum_1    sum                                            (b,)              {}
    call_function  lt       <built-in function lt>                         (sum_1, 0)        {}
    output         output   output                                         ((truediv, lt),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args         kwargs
    -------------  ------  -----------------------  -----------  --------
    placeholder    b       b                        ()           {}
    placeholder    x       x                        ()           {}
    call_function  mul     <built-in function mul>  (b, -1)      {}
    call_function  mul_1   <built-in function mul>  (x, mul)     {}
    output         output  output                   ((mul_1,),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args       kwargs
    -------------  ------  -----------------------  ---------  --------
    placeholder    b       b                        ()         {}
    placeholder    x       x                        ()         {}
    call_function  mul     <built-in function mul>  (x, b)     {}
    output         output  output                   ((mul,),)  {}

The order of the last two graphs is nondeterministic; it depends on which
branch the just-in-time compiler encounters first.

Speedy Backend
^^^^^^^^^^^^^^

Integrating a custom backend that offers better performance is also easy.
Here we integrate a real one with `optimize_for_inference `__:

.. code-block:: python

    from typing import List
    import torch

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        # Script the traced graph, then apply TorchScript's inference optimizations.
        scripted = torch.jit.script(gm)
        return torch.jit.optimize_for_inference(scripted)

You should then be able to optimize any existing code with:

.. code-block:: python

    @torch.compile(backend=optimize_for_inference_compiler)
    def code_to_accelerate():
        ...

Composable Backends
^^^^^^^^^^^^^^^^^^^

TorchDynamo includes many backends, which are listed in `backends.py `__ or
returned by ``torch._dynamo.list_backends()``. You can combine these backends
with the following code:

.. code-block:: python

    from typing import List
    import torch
    from torch._dynamo.optimizations import BACKENDS

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        try:
            trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
            if trt_compiled is not None:
                return trt_compiled
        except Exception:
            pass
        # first backend failed, try something else...
        try:
            inductor_compiled = BACKENDS["inductor"](gm, example_inputs)
            if inductor_compiled is not None:
                return inductor_compiled
        except Exception:
            pass
        # every backend failed or declined: run the traced graph as-is
        return gm.forward

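
The fallback pattern above generalizes to any number of backends. The
following is a minimal sketch rather than a TorchDynamo API:
``chain_backends`` is a hypothetical helper name, and the sketch assumes only
the backend contract from the overview, plus the convention used above that a
backend may return ``None`` to decline a graph.

.. code-block:: python

    from typing import Any, Callable, List, Optional

    # A backend takes (gm, example_inputs) and returns a callable,
    # or None to decline the graph (assumed convention, as in the example above).
    Backend = Callable[[Any, List[Any]], Optional[Callable]]


    def chain_backends(*backends: Backend) -> Backend:
        """Return a backend that tries each given backend in order.

        Hypothetical helper for illustration; not part of TorchDynamo.
        """

        def combined(gm, example_inputs):
            for backend in backends:
                try:
                    compiled = backend(gm, example_inputs)
                except Exception:
                    continue  # this backend failed, try the next one
                if compiled is not None:
                    return compiled
            # Every backend failed or declined: run the traced graph as-is.
            return gm.forward

        return combined

Under these assumptions, the example above could be written as
``my_compiler = chain_backends(BACKENDS["tensorrt"], BACKENDS["inductor"])``.
Like the original, this sketch silently swallows backend exceptions; logging
them may be preferable while debugging.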