functorch.compile.default_partition
functorch.compile.default_partition(joint_module, _joint_inputs, *, num_fwd_outputs)
Partitions the joint_module so that the result closely mirrors the original .forward() and .backward() of the callable: the resulting forward graph contains exactly the operators that are executed in the original .forward() of the callable passed to aot_function().

The default partitioner collects the operators that lie between the forward inputs and the forward outputs. This identifies the tensors that must be stashed (saved) for the backward pass. These stashed tensors become additional outputs of the generated forward graph. All remaining operators are placed in the backward graph.
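To make the collection step concrete, here is a minimal pure-Python sketch on a toy joint graph. It is not the functorch implementation. The graph is a hypothetical dict that maps each node name to the names of its inputs, the names primals_1 and tangents_1 only imitate AOT Autograd's placeholder naming, and ancestors and toy_default_partition are illustrative helpers. Details the real partitioner handles (symbolic sizes, multi-output operators) are ignored.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy joint graph: node name -> names of its inputs.
# Nodes with no inputs play the role of placeholders (primals_* / tangents_*).
JointGraph = Dict[str, List[str]]


def ancestors(graph: JointGraph, roots: List[str]) -> Set[str]:
    """Return the roots plus every node they transitively depend on."""
    seen: Set[str] = set()
    stack = list(roots)
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(graph[node])
    return seen


def toy_default_partition(
    graph: JointGraph, fwd_outputs: List[str], bwd_outputs: List[str]
) -> Tuple[Set[str], Set[str], List[str]]:
    # Forward graph: the operators between the forward inputs and the forward outputs.
    fwd_nodes = ancestors(graph, fwd_outputs)
    # Backward graph: whatever the gradients need that the forward graph does not compute.
    bwd_nodes = ancestors(graph, bwd_outputs) - fwd_nodes
    # Stashed values: forward nodes read directly by some backward node.
    saved = sorted({i for n in bwd_nodes for i in graph[n] if i in fwd_nodes})
    return fwd_nodes, bwd_nodes, saved


# Joint graph for out = z * z with z = sin(x); the gradient is 2 * z * cos(x) * grad.
joint = {
    "primals_1": [],
    "tangents_1": [],
    "sin": ["primals_1"],
    "mul": ["sin", "sin"],          # forward output
    "mul_1": ["tangents_1", "sin"],
    "add": ["mul_1", "mul_1"],
    "cos": ["primals_1"],
    "mul_2": ["add", "cos"],        # gradient w.r.t. primals_1
}

fwd, bwd, saved = toy_default_partition(joint, ["mul"], ["mul_2"])
# Forward graph outputs: the real outputs followed by the stashed tensors.
print("forward graph outputs:", ["mul"] + saved)
print("backward nodes:", sorted(bwd))
```

In this toy example the sketch stashes primals_1 and sin, while cos(primals_1) lands in the backward graph because no forward output depends on it.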
Warning
This API is experimental and likely to change.
- Parameters
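joint_module (fx.GraphModule) – The joint forward and backward graph, produced by AOT Autograd tracing.
_joint_inputs – The inputs to the joint graph. Unused by this partitioner.
num_fwd_outputs (int) – The number of outputs of the joint graph that belong to the forward pass.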
- Returns
A tuple containing the generated forward and backward FX graph modules.
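A usage sketch, assuming a PyTorch build that ships functorch.compile. It passes default_partition explicitly as the partition_fn of aot_function() (which already uses it by default) and installs compilers that simply record the FX graph they receive, so the partitioned forward and backward graphs can be inspected. The names make_compiler, captured and fn are illustrative only.

```python
import torch
from functorch.compile import aot_function, default_partition

# Illustrative store for the graphs handed to each compiler.
captured = {}


def make_compiler(tag):
    # A "compiler" that records the FX graph it receives and runs it unchanged.
    def compiler(gm: torch.fx.GraphModule, example_inputs):
        captured[tag] = gm
        return gm
    return compiler


def fn(x):
    z = torch.sin(x)
    return z * z


compiled = aot_function(
    fn,
    fw_compiler=make_compiler("forward"),
    bw_compiler=make_compiler("backward"),
    partition_fn=default_partition,
)

x = torch.randn(4, requires_grad=True)
out = compiled(x)          # traces, partitions, and runs the forward graph
out.sum().backward()       # runs the backward graph

print(captured["forward"].code)
print(captured["backward"].code)
```

The printed forward graph returns the stashed tensors alongside z * z, and the backward graph takes them as inputs together with the incoming gradient.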