functorch.compile.default_partition

functorch.compile.default_partition(joint_module, _joint_inputs)

Partitions joint_module so that the result closely mirrors the original .forward() and .backward() of the callable. In other words, the resulting forward graph contains the operators that the callable passed to aot_function() executes in its original .forward().

The default partitioner collects the operators that lie between the forward inputs and the forward outputs. This identifies the tensors that must be saved for the backward pass; these saved tensors become additional outputs of the generated forward graph. All remaining operators are placed in the backward graph.
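A minimal, pure-Python sketch of the idea described above follows. It is not the functorch implementation: it models the joint graph as a list of hypothetical `Node` records instead of an `fx.Graph`, and the names `Node`, `partition_sketch`, and the `primals_*`/`tangents_*` naming are illustrative assumptions. It walks back from the forward outputs to find the forward operators, marks forward values read by any other operator as saved, and assigns whatever the backward outputs still need to the backward graph.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an FX node: a value name plus the names it reads."""
    name: str
    args: tuple = ()


def _ancestors(by_name, roots, stop):
    """Names reachable from `roots` via args; nodes in `stop` are kept but not expanded."""
    seen = set()
    stack = list(roots)
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        if name not in stop:
            stack.extend(by_name[name].args)
    return seen


def partition_sketch(nodes, fwd_outputs, bwd_outputs, tangents):
    """Split a topologically ordered toy joint graph into forward and backward parts."""
    by_name = {n.name: n for n in nodes}
    # 1. Forward part: every node needed to compute the forward outputs.
    forward = _ancestors(by_name, fwd_outputs, stop=set())
    if forward & set(tangents):
        raise ValueError("forward outputs must not depend on tangents")
    # 2. Saved values: forward values read by any node outside the forward part.
    outside = [n for n in nodes if n.name not in forward]
    saved = [n.name for n in nodes
             if n.name in forward and any(n.name in u.args for u in outside)]
    # 3. Backward part: what the backward outputs need, starting from saved values and tangents.
    boundary = set(saved) | set(tangents)
    needed = _ancestors(by_name, bwd_outputs, stop=boundary)
    fwd_graph = {"nodes": [n.name for n in nodes if n.name in forward],
                 "outputs": list(fwd_outputs) + saved}  # saved values become extra outputs
    bwd_graph = {"inputs": saved + list(tangents),
                 "nodes": [n.name for n in nodes if n.name in needed and n.name not in boundary],
                 "outputs": list(bwd_outputs)}
    return fwd_graph, bwd_graph


# Toy joint graph for f(x, w) = sin(x * w) and its gradients with respect to x and w.
joint = [
    Node("primals_1"), Node("primals_2"), Node("tangents_1"),
    Node("mul", ("primals_1", "primals_2")),
    Node("sin", ("mul",)),
    Node("cos", ("mul",)),
    Node("mul_1", ("tangents_1", "cos")),
    Node("mul_2", ("mul_1", "primals_2")),   # gradient for x
    Node("mul_3", ("mul_1", "primals_1")),   # gradient for w
]
fwd, bwd = partition_sketch(joint, ["sin"], ["mul_2", "mul_3"], ["tangents_1"])
print(fwd)
print(bwd)
```

Here `cos` is not needed to compute the forward output `sin`, so it lands in the backward graph and reads the saved `mul`. This mirrors eager autograd, where the backward of `sin` computes the cosine from its saved input.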

Warning

This API is experimental and likely to change.

Parameters

joint_module (fx.GraphModule) – The joint forward and backward graph. This is the result of AOT Autograd tracing.

Returns

The generated forward and backward FX graph modules, returned as a tuple (forward, backward).