
functorch.compile.aot_function

functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, hasher_type=None, static_argnums=None)

Traces the forward and backward graphs of fn using the torch dispatch mechanism, then compiles the generated forward and backward graphs with fw_compiler and bw_compiler.
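The sketch below illustrates this compiler contract, assuming a PyTorch 2.x install in which functorch.compile ships with torch. The helper make_recording_compiler and the compile_log list are hypothetical: each compiler receives an FX GraphModule plus example inputs, records which graph it saw, and returns the graph unchanged, so the compiled function should match eager autograd.

```python
import torch
from functorch.compile import aot_function

# Hypothetical log of (compiler tag, number of graph inputs) pairs.
compile_log = []


def make_recording_compiler(tag):
    """Build a compiler that logs the FX graph it receives and runs it as-is."""
    def compiler(fx_module, example_inputs):
        # fx_module is a torch.fx.GraphModule whose nodes are ATen ops.
        compile_log.append((tag, len(example_inputs)))
        # Returning the GraphModule itself keeps the traced semantics.
        return fx_module
    return compiler


def fn(x, y):
    return (x * y).sin().sum()


aot_fn = aot_function(
    fn,
    fw_compiler=make_recording_compiler("forward"),
    bw_compiler=make_recording_compiler("backward"),
)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
out = aot_fn(x, y)
# The backward graph may only be compiled when backward() first runs.
out.backward()

# Reference result and gradients from plain eager autograd on detached copies.
x_ref = x.detach().clone().requires_grad_(True)
y_ref = y.detach().clone().requires_grad_(True)
ref = fn(x_ref, y_ref)
ref.backward()

print([tag for tag, _ in compile_log])
print(torch.allclose(out, ref), torch.allclose(x.grad, x_ref.grad))
```

Returning the GraphModule itself mirrors the example at the end of this page; recent releases may warn that the compiler does not take boxed arguments and suggest wrapping it with functorch.compile.make_boxed_func.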

aot_function() traces the forward and backward graphs ahead of time and generates a joint forward and backward graph. partition_fn then splits this joint graph into separate forward and backward graphs; the partitioner can also perform optimizations such as recomputation. The decompositions dictionary can be set to decompose operators into sequences of core or simpler operators supported by the backend compilers.
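To look at the partitioning step in isolation, the following sketch passes a custom partition_fn that delegates to functorch.compile.default_partition and records how many ATen op nodes the joint, forward, and backward graphs contain. The names logging_partition, count_ops, identity_compiler, and partition_log are hypothetical; extra positional and keyword arguments are forwarded unchanged on the assumption that the partitioner signature varies between releases.

```python
import torch
from functorch.compile import aot_function, default_partition

# Hypothetical record of graph sizes seen by the partitioner.
partition_log = {}


def count_ops(gm):
    """Count call_function nodes (the ATen ops) in an FX GraphModule."""
    return sum(1 for node in gm.graph.nodes if node.op == "call_function")


def logging_partition(joint_module, joint_inputs, *args, **kwargs):
    """Delegate to default_partition, recording joint/forward/backward op counts."""
    # Measure the joint graph before default_partition extracts subgraphs from it.
    partition_log["joint"] = count_ops(joint_module)
    fw_module, bw_module = default_partition(joint_module, joint_inputs, *args, **kwargs)
    partition_log["forward"] = count_ops(fw_module)
    partition_log["backward"] = count_ops(bw_module)
    return fw_module, bw_module


def identity_compiler(fx_module, example_inputs):
    # Run each partitioned graph unchanged.
    return fx_module


fn = lambda x: x.sin().cos()
aot_fn = aot_function(fn, identity_compiler, partition_fn=logging_partition)

x = torch.randn(4, 5, requires_grad=True)
aot_fn(x).sum().backward()
print(partition_log)
```

A recomputation-oriented partitioner, functorch.compile.min_cut_rematerialization_partition, can be passed the same way; it is not exercised here because it may depend on optional packages such as networkx.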

aot_function() uses a compilation cache, based on input tensor properties, to detect when recompilation is needed.

Warning

This API is experimental and likely to change.

Parameters
  • fn (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.

  • fw_compiler (Callable) – A Python function that accepts an FX graph of ATen ops and its input args, and returns a Callable that is semantically equivalent to the input FX graph.

  • bw_compiler (Optional[Callable]) – A Python function that accepts an FX graph of ATen ops and its input args, and returns a Callable that is semantically equivalent to the input FX graph. Default: None (when None, fw_compiler is used).

  • partition_fn (Callable) – A Python function that takes a joint forward and backward graph and partitions it into separate forward and backward graphs.

  • decompositions (Dict) – A dictionary that defines how larger ATen ops are decomposed into simpler or core ATen ops.
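A decompositions dictionary maps ATen operator overloads to Python functions that re-express them with simpler ops. The sketch below assumes a hypothetical decomposition of aten.silu into x * sigmoid(x) and checks, via the node targets of the forward graph, that silu was replaced. silu_decomposition, capture_forward, passthrough, and the logging lists are illustrative names, not part of the API.

```python
import torch
import torch.nn.functional as F
from functorch.compile import aot_function

aten = torch.ops.aten

# Hypothetical logs: decomposition calls and op targets in the forward graph.
decomp_calls = []
forward_targets = []


def silu_decomposition(x):
    """Express aten.silu as mul(x, sigmoid(x)) using simpler ATen ops."""
    decomp_calls.append(1)
    return x * torch.sigmoid(x)


def capture_forward(fx_module, example_inputs):
    # Record which ATen ops ended up in the forward graph.
    forward_targets.extend(
        node.target for node in fx_module.graph.nodes if node.op == "call_function"
    )
    return fx_module


def passthrough(fx_module, example_inputs):
    return fx_module


fn = lambda x: F.silu(x)
aot_fn = aot_function(
    fn,
    fw_compiler=capture_forward,
    bw_compiler=passthrough,
    decompositions={aten.silu.default: silu_decomposition},
)

x = torch.randn(8, requires_grad=True)
out = aot_fn(x)
out.sum().backward()
print(forward_targets)
```

Only the forward op is decomposed here: autograd still records silu, so the backward graph is expected to use silu's own backward formula unless a decomposition for that op is also supplied.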

Returns

Returns a Callable that retains the eager behavior of the original fn, but with the forward and backward graphs compiled via fw_compiler and bw_compiler.

A simple example of aot_function() usage is as follows. This example prints the forward and backward graphs of the function fn.

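>>> import torch
>>> from functorch.compile import aot_function
>>> fn = lambda x: x.sin().cos()
>>> def print_compile_fn(fx_module, args):
...     print(fx_module)  # print the traced FX graph of ATen ops
...     return fx_module  # run the graph unchanged
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> out = aot_fn(x)
>>> out.sum().backward()  # the backward graph may only be compiled when backward() first runs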
