functorch.compile.aot_function

functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, hasher_type='StaticShapeHasher', static_argnums=None)

Traces the forward and backward graphs of fn using the torch dispatch mechanism, and then compiles the generated forward and backward graphs through fw_compiler and bw_compiler.

aot_function() traces the forward and backward graphs ahead of time and generates a joint forward and backward graph. partition_fn is then used to split this joint graph into separate forward and backward graphs. The partitioner function can also be used to perform optimizations such as recomputation. The decompositions dictionary can be set to decompose operators into a sequence of core or simpler operators supported by the backend compilers.
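To make the partitioning step more concrete, here is a toy sketch in plain Python. It assumes a joint graph represented as a list of nodes, each tagged as forward or backward; the `Node` class and `partition_joint_graph` function are hypothetical and are not part of functorch. The sketch only illustrates the basic idea that forward values consumed by the backward pass must be saved between the two graphs (a save-everything strategy; a recomputation-based partitioner would instead recompute some of those values in the backward graph).

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    # Hypothetical node of a toy joint graph (not an FX node).
    name: str
    inputs: List[str] = field(default_factory=list)
    is_backward: bool = False


def partition_joint_graph(nodes: List[Node]) -> Tuple[List[str], List[str], List[str]]:
    """Split a toy joint graph into forward nodes, backward nodes, and the
    forward values that the backward nodes read (these must be saved)."""
    forward = [n.name for n in nodes if not n.is_backward]
    backward = [n.name for n in nodes if n.is_backward]
    forward_set = set(forward)
    saved: List[str] = []
    for n in nodes:
        if n.is_backward:
            for inp in n.inputs:
                if inp in forward_set and inp not in saved:
                    saved.append(inp)
    return forward, backward, saved


# Joint graph for y = cos(sin(x)) together with its gradient.
joint = [
    Node("x"),
    Node("a", ["x"]),                                     # a = sin(x)
    Node("y", ["a"]),                                     # y = cos(a)
    Node("grad_y", is_backward=True),                     # incoming gradient
    Node("grad_a", ["grad_y", "a"], is_backward=True),    # grad_a = -sin(a) * grad_y
    Node("grad_x", ["grad_a", "x"], is_backward=True),    # grad_x = cos(x) * grad_a
]
fw, bw, saved = partition_joint_graph(joint)
print(fw)     # ['x', 'a', 'y']
print(bw)     # ['grad_y', 'grad_a', 'grad_x']
print(saved)  # ['a', 'x']
```

Here `a` and `x` end up as the values the forward graph has to hand over to the backward graph.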

aot_function() uses a compilation cache, based on input tensor properties, to detect when recompilation is needed. By default, its behavior is static, i.e., it recompiles if the shape of any input tensor changes.
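A minimal sketch of such a static-shape cache follows, assuming inputs are described only by their shapes (tuples of ints). The `StaticShapeCache` class and its `compile_fn` argument are hypothetical stand-ins; the real cache (selected via the `hasher_type` parameter in the signature) may key on more tensor properties than shape alone.

```python
from typing import Callable, Dict, Sequence, Tuple

Shape = Tuple[int, ...]


class StaticShapeCache:
    """Hypothetical cache: compile once per distinct tuple of input shapes."""

    def __init__(self, compile_fn: Callable[[Tuple[Shape, ...]], Callable]):
        self.compile_fn = compile_fn
        self.cache: Dict[Tuple[Shape, ...], Callable] = {}
        self.num_compiles = 0

    def get(self, shapes: Sequence[Shape]) -> Callable:
        key = tuple(tuple(s) for s in shapes)
        if key not in self.cache:
            # Cache miss: an unseen combination of shapes forces a recompilation.
            self.num_compiles += 1
            self.cache[key] = self.compile_fn(key)
        return self.cache[key]


# A stand-in "compiler" that just remembers which shapes it was specialized for.
cache = StaticShapeCache(lambda key: (lambda: key))
cache.get([(4, 5)])
cache.get([(4, 5)])   # same shape: cache hit, no recompilation
cache.get([(8, 5)])   # new shape: recompiles
print(cache.num_compiles)  # 2
```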

static_argnums allows the user to mark arguments of the original fn as static. This is useful when an argument is a non-tensor, e.g., an int or a bool. A change in the actual value of a static argument causes recompilation.
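The following sketch shows one way marking arguments as static could feed into a cache key: static arguments contribute their actual values, while the remaining (tensor) arguments contribute only their shapes. Tensors are modeled here as plain shape tuples so the code runs without PyTorch, and `make_cache_key` is a hypothetical helper, not a functorch function.

```python
from typing import Any, Sequence, Tuple


def make_cache_key(args: Sequence[Any], static_argnums: Tuple[int, ...]) -> Tuple:
    """Hypothetical helper: build a cache key from a call's positional arguments.

    Static arguments contribute their actual value (so they must be hashable);
    every other argument is assumed to be a tensor and contributes only its
    shape, modeled here as a plain tuple of ints.
    """
    key = []
    for i, arg in enumerate(args):
        if i in static_argnums:
            key.append(("static", arg))
        else:
            key.append(("shape", tuple(arg)))
    return tuple(key)


# Arguments: (input shape, bias shape, residual shape, float hyperparameter)
k1 = make_cache_key([(4, 5), (5,), (4, 5), 0.5], static_argnums=(3,))
k2 = make_cache_key([(4, 5), (5,), (4, 5), 0.5], static_argnums=(3,))
k3 = make_cache_key([(4, 5), (5,), (4, 5), 0.1], static_argnums=(3,))
print(k1 == k2)  # True: same shapes and same static value, so a cache hit
print(k1 == k3)  # False: the static value changed, so recompile
```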

Warning

This API is experimental and likely to change.

Parameters
  • fn (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.

  • fw_compiler (Callable) – A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that is semantically equivalent to the input Fx graph.

  • bw_compiler (Optional[Callable]) – A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that is semantically equivalent to the input Fx graph. Default: None (when None, it defaults to fw_compiler)

  • partition_fn (Callable) – A Python function that takes a joint forward and backward graph, and partitions it into separate forward and backward graphs.

  • decompositions (Dict) – A dictionary to define the decomposition of larger Aten ops into simpler or core Aten ops.

  • static_argnums (Optional[Tuple[int]]) – An optional tuple of ints giving the positions of the function's arguments that should be treated as static.
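As a rough illustration of what a decomposition table does, the sketch below maps a "larger" operation to an equivalent composition of simpler ones. In functorch the dictionary keys are typically Aten operator overloads and the values are Python functions; this sketch instead uses plain Python functions over floats so it runs without PyTorch, and the names `add`, `mul`, `addcmul`, and `dispatch` are hypothetical.

```python
from typing import Callable, Dict


# Core ops that a hypothetical backend supports natively.
def add(a: float, b: float) -> float:
    return a + b


def mul(a: float, b: float) -> float:
    return a * b


# A "larger" op: x + value * t1 * t2 (modeled loosely on torch.addcmul).
def addcmul(x: float, t1: float, t2: float, value: float = 1.0) -> float:
    return x + value * t1 * t2


# Decomposition table: larger op -> equivalent function built only from core ops.
decompositions: Dict[Callable, Callable] = {
    addcmul: lambda x, t1, t2, value=1.0: add(x, mul(mul(value, t1), t2)),
}


def dispatch(op: Callable, table: Dict[Callable, Callable]) -> Callable:
    """Return the decomposed version of op if the table has one, else op itself."""
    return table.get(op, op)


direct = addcmul(1.0, 2.0, 3.0, 0.5)
decomposed = dispatch(addcmul, decompositions)(1.0, 2.0, 3.0, 0.5)
print(direct, decomposed)  # 4.0 4.0
```

Ops without an entry in the table (here `add` and `mul`) are passed through unchanged, mirroring the idea that only the operators a backend cannot handle need to be decomposed.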

Returns

Returns a Callable that retains the eager behavior of the original fn, but with forward and backward graphs compiled via fw_compiler and bw_compiler.

A simple example usage of aot_function() follows. This example prints the forward and backward graphs of the function fn.

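>>> import torch
>>> from functorch.compile import aot_function
>>> fn = lambda x : x.sin().cos()
>>> def print_compile_fn(fx_module, args):
...     print(fx_module)
...     return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> out = aot_fn(x)
>>> out.sum().backward()  # runs the compiled backward graph as well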

static_argnums is used to mark non-tensor arguments as static. In the following example, the dropout probability is passed as an argument to the original function.

>>> def fn(input, bias, residual, p: float):
...     a = torch.add(input, bias)
...     b = torch.nn.functional.dropout(a, p, training=True)
...     c = b + residual
...     return c
>>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,))