functorch.compile.aot_function
functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, hasher_type='StaticShapeHasher', static_argnums=None)

Traces the forward and backward graphs of fn using the torch dispatch mechanism, and then compiles the generated forward and backward graphs through fw_compiler and bw_compiler.

aot_function() traces the forward and backward graphs ahead of time and generates a joint forward and backward graph. partition_fn is then used to separate the joint graph into forward and backward graphs. The partitioner function can also be used to perform optimizations such as recomputation. The decompositions dictionary can be set to decompose operators into sequences of core or simpler operators supported by the backend compilers.

aot_function() uses a compilation cache, keyed on input tensor properties, to detect when recompilation is needed. By default, its behavior is static, i.e., it recompiles if the shape of any input tensor changes. static_argnums allows the user to mark arguments of the original fn as static. This is useful when an argument is a non-tensor, e.g., an int or bool. A change in the actual value of a static argument causes recompilation.

Warning: This API is experimental and likely to change.
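The effect of the decompositions dictionary described above can be illustrated without PyTorch. The following is a hypothetical toy model, not functorch code: ops are plain strings, values are floats, and the `Tracer` class, its `call` method, and the convention that a decomposition function receives the tracer as its first argument are all assumptions of this sketch. It shows that once a decomposition for `silu` is supplied, only the simpler ops `sigmoid` and `mul` reach the recorded trace, so a backend would only need to support those.

```python
import math
from typing import Callable, Dict, List

# Hypothetical toy "backend": primitive ops implemented on plain floats.
PRIMITIVES: Dict[str, Callable[..., float]] = {
    "silu": lambda x: x / (1.0 + math.exp(-x)),
    "sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
    "mul": lambda x, y: x * y,
}


class Tracer:
    """Hypothetical tracer that records which ops actually execute."""

    def __init__(self, decompositions: Dict[str, Callable[..., float]]):
        self.decompositions = decompositions
        self.trace: List[str] = []

    def call(self, op: str, *args: float) -> float:
        if op in self.decompositions:
            # Expand the larger op into simpler ops; the decomposition
            # calls back into the tracer, so those ops get recorded instead.
            return self.decompositions[op](self, *args)
        self.trace.append(op)
        return PRIMITIVES[op](*args)


def silu_decomposition(t: Tracer, x: float) -> float:
    # silu(x) = x * sigmoid(x), expressed with simpler ops.
    return t.call("mul", x, t.call("sigmoid", x))


plain = Tracer(decompositions={})
decomposed = Tracer(decompositions={"silu": silu_decomposition})
a = plain.call("silu", 2.0)
b = decomposed.call("silu", 2.0)
print(plain.trace)       # ['silu']
print(decomposed.trace)  # ['sigmoid', 'mul']
print(math.isclose(a, b))  # True
```

Both tracers compute the same value; only the set of ops that reaches the "backend" differs.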
Parameters:
fn (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
fw_compiler (Callable) – A Python function that accepts an FX graph with ATen ops and its input args, and returns a Callable that is semantically equivalent to the input FX graph.
bw_compiler (Optional[Callable]) – A Python function that accepts an FX graph with ATen ops and its input args, and returns a Callable that is semantically equivalent to the input FX graph. Default: None (when None, it defaults to fw_compiler).
partition_fn (Callable) – A Python function that takes a joint forward and backward graph and partitions it into separate forward and backward graphs.
decompositions (Dict) – A dictionary that defines the decomposition of larger ATen ops into simpler or core ATen ops.
static_argnums (Optional[Tuple[int]]) – An optional tuple of ints marking which arguments of fn are static.
Returns:
A Callable that retains the eager behavior of the original fn, but with the forward and backward graphs compiled via fw_compiler and bw_compiler.
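The claim that the returned callable keeps eager behavior can be checked directly. This sketch assumes a PyTorch build where `functorch.compile.aot_function` is importable; because newer releases may not accept every argument in the signature above (for example static_argnums), it passes only fn, fw_compiler, and bw_compiler. The `make_recording_compiler` helper is hypothetical. Depending on the version, the backward graph may be compiled during the first forward call or only when backward first runs, so the sketch runs a backward pass before inspecting which compilers were invoked.

```python
import torch
from functorch.compile import aot_function


def make_recording_compiler(name, log):
    """Hypothetical helper: returns a compiler that records its name and
    hands back the FX graph module unchanged (so it runs as plain PyTorch)."""
    def compiler(fx_module, example_inputs):
        log.append(name)
        return fx_module
    return compiler


def fn(x):
    return x.sin().cos()


calls = []
aot_fn = aot_function(
    fn,
    fw_compiler=make_recording_compiler("fw", calls),
    bw_compiler=make_recording_compiler("bw", calls),
)

x = torch.randn(4, 5, requires_grad=True)
out = aot_fn(x)
out.sum().backward()

# Reference run in eager mode on an independent copy of the input.
x_ref = x.detach().clone().requires_grad_(True)
out_ref = fn(x_ref)
out_ref.sum().backward()

print(torch.allclose(out, out_ref), torch.allclose(x.grad, x_ref.grad))
print(calls)
```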
A simple example usage of aot_function() follows. It prints the forward and backward graphs of the function fn:

>>> import torch
>>> from functorch.compile import aot_function
>>> fn = lambda x: x.sin().cos()
>>> def print_compile_fn(fx_module, args):
...     print(fx_module)
...     return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x).sum().backward()
static_argnums is used to mark non-tensor arguments as static. In the following example, the dropout probability p is an argument of the original function and is marked static:

>>> def fn(input, bias, residual, p: float):
...     a = torch.add(input, bias)
...     b = torch.nn.functional.dropout(a, p, training=True)
...     c = b + residual
...     return c
>>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,))
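The recompilation rules described for the compilation cache (recompile when any input shape changes, or when the value of an argument listed in static_argnums changes) can be modeled with a small cache. This is a hypothetical pure-Python model, not functorch's StaticShapeHasher implementation: `ShapeKeyedCache`, its `compile_fn` callback, and passing non-static arguments as bare shape tuples are all assumptions made for illustration. The calls mirror the dropout example's fn(input, bias, residual, p) with static_argnums=(3,).

```python
from typing import Any, Callable, Dict, Sequence, Tuple


class ShapeKeyedCache:
    """Hypothetical model of a static-shape compilation cache.

    Non-static arguments are passed as shape tuples and keyed by shape only;
    arguments listed in static_argnums are keyed by their actual value.
    """

    def __init__(self, compile_fn: Callable[[Tuple], Any],
                 static_argnums: Sequence[int] = ()):
        self.compile_fn = compile_fn
        self.static_argnums = set(static_argnums)
        self.cache: Dict[Tuple, Any] = {}
        self.num_compiles = 0

    def _key(self, args: Sequence[Any]) -> Tuple:
        parts = []
        for i, arg in enumerate(args):
            if i in self.static_argnums:
                parts.append(("static", arg))        # value of a static arg matters
            else:
                parts.append(("shape", tuple(arg)))  # only the shape matters
        return tuple(parts)

    def lookup(self, args: Sequence[Any]) -> Any:
        key = self._key(args)
        if key not in self.cache:
            # Cache miss: "compile" a new artifact for this key.
            self.num_compiles += 1
            self.cache[key] = self.compile_fn(key)
        return self.cache[key]


# Arguments: (input shape, bias shape, residual shape, p), p marked static.
cache = ShapeKeyedCache(compile_fn=lambda key: f"compiled for {key}",
                        static_argnums=(3,))
cache.lookup([(4, 5), (5,), (4, 5), 0.1])  # miss: first compilation
cache.lookup([(4, 5), (5,), (4, 5), 0.1])  # hit: same shapes, same p
cache.lookup([(8, 5), (5,), (8, 5), 0.1])  # miss: input shapes changed
cache.lookup([(4, 5), (5,), (4, 5), 0.2])  # miss: static p changed
print(cache.num_compiles)  # 3
```

Keying static arguments by value is what makes a new p trigger a recompilation even though no tensor shape changed.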