functorch.compile.aot_function¶
functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, hasher_type=None, static_argnums=None)[source]¶

Traces the forward and backward graph of fn using the torch dispatch mechanism, and then compiles the generated forward and backward graphs through fw_compiler and bw_compiler.

aot_function() traces the forward and backward graph ahead of time, and generates a joint forward and backward graph. partition_fn is then used to separate out the forward and backward graphs. The partitioner function can be used to perform optimizations such as recomputation. The decompositions dictionary can be set to decompose larger operators into a sequence of core or simpler operators supported by the backend compilers (see the examples below).

aot_function() uses a compilation cache, based on input tensor properties, to detect when recompilation is needed.

Warning

This API is experimental and likely to change.
- Parameters
fn (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
fw_compiler (Callable) – A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that is semantically equivalent to the input Fx graph.
bw_compiler (Optional[Callable]) – A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that is semantically equivalent to the input Fx graph. Default: None (when None, it defaults to the fw_compiler).
partition_fn (Callable) – A Python function that takes a joint forward and backward graph, and partitions it into separate forward and backward graphs.
decompositions (Dict) – A dictionary to define the decomposition of larger Aten ops into simpler or core Aten ops.
- Returns
Returns a Callable that retains the eager behavior of the original fn, but with the forward and backward graphs compiled via fw_compiler and bw_compiler.
A simple example usage of aot_function() is as follows. This example will print the forward and backward graphs of the function fn.

>>> fn = lambda x : x.sin().cos()
>>> def print_compile_fn(fx_module, args):
>>>     print(fx_module)
>>>     return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)
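
The forward and backward graphs can also be handled by different compilers. The sketch below is a minimal, hypothetical variation of the example above: print_fw and print_bw are not library functions, and each simply returns the received fx_module, which works because a torch.fx.GraphModule is itself a Callable. Since the graphs are generated ahead of time, both compilers are invoked on the first call to aot_fn; the later call to backward() then runs the compiled backward graph.

>>> import torch
>>> from functorch.compile import aot_function
>>> def print_fw(fx_module, args):
>>>     print("forward graph:")
>>>     print(fx_module.code)
>>>     return fx_module
>>> def print_bw(fx_module, args):
>>>     print("backward graph:")
>>>     print(fx_module.code)
>>>     return fx_module
>>> fn = lambda x: x.sin().cos()
>>> aot_fn = aot_function(fn, fw_compiler=print_fw, bw_compiler=print_bw)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> out = aot_fn(x)
>>> out.sum().backward()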
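
To illustrate partition_fn and decompositions, the following sketch assumes that min_cut_rematerialization_partition can be imported from functorch.compile (it is documented alongside aot_function(), and the min-cut partitioner typically requires the networkx package) and that the decomposition table is built with PyTorch's internal torch._decomp.get_decompositions helper; ops without a registered decomposition are simply left out of the returned table, and any dictionary mapping Aten ops to Python implementations can be passed instead.

>>> import torch
>>> from torch._decomp import get_decompositions
>>> from functorch.compile import aot_function, min_cut_rematerialization_partition
>>> def print_compile_fn(fx_module, args):
>>>     print(fx_module)
>>>     return fx_module
>>> # Request a decomposition of aten.silu (if one is registered) so the
>>> # compilers only see simpler ops such as sigmoid and mul.
>>> decompositions = get_decompositions([torch.ops.aten.silu])
>>> fn = lambda x: torch.nn.functional.silu(x).cos()
>>> aot_fn = aot_function(
>>>     fn,
>>>     fw_compiler=print_compile_fn,
>>>     partition_fn=min_cut_rematerialization_partition,
>>>     decompositions=decompositions,
>>> )
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x).sum().backward()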