functorch.compile.min_cut_rematerialization_partition
- functorch.compile.min_cut_rematerialization_partition(joint_module, _joint_inputs, compiler='inductor', *, num_fwd_outputs)
Partitions the joint graph so that the backward pass recomputes parts of the forward pass instead of saving them. Recomputing trades extra computation for reduced memory bandwidth.
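To illustrate the min-cut idea behind this partitioner, here is a simplified pure-Python sketch. It is not functorch's implementation, which works on traced FX nodes and handles many more constraints. Each forward value is split into an `_in`/`_out` pair joined by an edge whose capacity is the cost of saving that value. Graph inputs hang off a source, and values the backward reads feed a sink. A minimum cut then selects the cheapest set of values to save, and everything on the sink side is recomputed. The helper names `max_flow_min_cut` and `choose_saved`, the toy graph, and its sizes are all hypothetical.

```python
from collections import deque

INF = float("inf")


def max_flow_min_cut(cap, source, sink):
    """Edmonds-Karp max flow. Returns (flow value, nodes on the source side of a min cut)."""
    # Build residual capacities, adding zero-capacity reverse edges.
    res = {}
    for u, nbrs in cap.items():
        for v, c in nbrs.items():
            res.setdefault(u, {})
            res.setdefault(v, {})
            res[u][v] = res[u].get(v, 0) + c
            res[v].setdefault(u, 0)
    flow = 0
    while True:
        # Breadth-first search for a shortest augmenting path.
        parent = {source: None}
        queue = deque([source])
        while queue and sink not in parent:
            u = queue.popleft()
            for v, c in res[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if sink not in parent:
            # No augmenting path left: the visited nodes form the source side of the cut.
            return flow, set(parent)
        # Find the bottleneck capacity along the path, then push that much flow.
        bottleneck = INF
        v = sink
        while parent[v] is not None:
            bottleneck = min(bottleneck, res[parent[v]][v])
            v = parent[v]
        v = sink
        while parent[v] is not None:
            u = parent[v]
            res[u][v] -= bottleneck
            res[v][u] += bottleneck
            v = u
        flow += bottleneck


def choose_saved(sizes, deps, inputs, needed_in_bwd):
    """Hypothetical helper: pick forward values to save so the backward can recompute the rest.

    sizes: {value: cost of saving it}; deps: {value: values it is computed from}.
    """
    cap = {}

    def add_edge(u, v, c):
        cap.setdefault(u, {})[v] = c

    for name, size in sizes.items():
        # Cutting this edge means "save this value for the backward".
        add_edge(f"{name}_in", f"{name}_out", size)
        for dep in deps.get(name, ()):
            # Recomputing `name` requires `dep` to be available.
            add_edge(f"{dep}_out", f"{name}_in", INF)
    for name in inputs:
        # Graph inputs are produced on the forward (source) side.
        add_edge("SOURCE", f"{name}_in", INF)
    for name in needed_in_bwd:
        # Values read by the backward must reach the sink side.
        add_edge(f"{name}_out", "SINK", INF)
    cost, source_side = max_flow_min_cut(cap, "SOURCE", "SINK")
    saved = {n for n in sizes if f"{n}_in" in source_side and f"{n}_out" not in source_side}
    return cost, saved


# Toy forward: a = sin(x), b = cos(x), c = a + b; the backward reads a and b.
deps = {"a": ["x"], "b": ["x"], "c": ["a", "b"]}
print(choose_saved({"x": 4, "a": 4, "b": 4, "c": 4}, deps, ["x"], ["a", "b"]))
print(choose_saved({"x": 100, "a": 4, "b": 4, "c": 4}, deps, ["x"], ["a", "b"]))
```

With the first set of sizes, the cut saves only `x` (cost 4), and `a` and `b` are recomputed from it in the backward. Making `x` expensive to save (size 100) flips the choice: saving `a` and `b` directly (cost 8) becomes cheaper.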
To create the forward and backward graphs, we copy the joint graph, set its outputs to only the original forward outputs or only the backward outputs, and then run each resulting graph through dead code elimination.
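The copy-then-prune step can be sketched in plain Python. This sketch represents a graph as a topologically ordered list of `(name, op, args)` tuples. The helper `dead_code_eliminate` and the toy joint graph are hypothetical and do not follow torch.fx's API, although torch.fx does provide a comparable pass, `Graph.eliminate_dead_code`.

```python
def dead_code_eliminate(nodes, outputs):
    """Keep only the nodes that `outputs` transitively depend on.

    nodes: list of (name, op, args) tuples in topological order.
    Returns a new list, so the original (joint) graph is left untouched.
    """
    live = set(outputs)
    kept = []
    # Walk backwards so each user is visited before the values it reads.
    for name, op, args in reversed(nodes):
        if name in live:
            kept.append((name, op, args))
            live.update(args)
    kept.reverse()
    return kept


# Toy joint graph for y = x * w: primals x, w and incoming gradient g.
joint = [
    ("x", "placeholder", ()),
    ("w", "placeholder", ()),
    ("g", "placeholder", ()),
    ("y", "mul", ("x", "w")),
    ("grad_x", "mul", ("g", "w")),
    ("grad_w", "mul", ("g", "x")),
]

fwd = dead_code_eliminate(joint, ["y"])                  # forward outputs only
bwd = dead_code_eliminate(joint, ["grad_x", "grad_w"])  # backward outputs only
print([n for n, _, _ in fwd])  # ['x', 'w', 'y']
print([n for n, _, _ in bwd])  # ['x', 'w', 'g', 'grad_x', 'grad_w']
```

In this toy, the backward still reads `x` and `w` directly from the joint graph's inputs. The partitioner's min-cut step decides which forward values actually cross from the forward graph to the backward graph and which are recomputed.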
Warning
This API is experimental and likely to change.
- Parameters
joint_module (fx.GraphModule) – The joint forward and backward graph. This is the result of AOT Autograd tracing.
_joint_inputs – The inputs to the joint graph. This is unused.
compiler – This option determines the default set of recomputable ops. Currently, there are two options: nvfuser and inductor.
recomputable_ops – This is an optional set of recomputable ops. If this is not None, then this set of ops will be used instead of the default set of ops.
num_fwd_outputs – The number of outputs from the forward graph.
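The following sketch shows how the `num_fwd_outputs`, `compiler`, and `recomputable_ops` parameters interact. It assumes the joint graph's flat outputs list the forward outputs first, followed by the gradients. The helpers `split_joint_outputs` and `resolve_recomputable_ops` are hypothetical. The op sets in `HYPOTHETICAL_DEFAULTS` are placeholders, not functorch's real defaults.

```python
def split_joint_outputs(joint_outputs, num_fwd_outputs):
    """Split the joint graph's flat outputs into forward and backward parts.

    Assumes the forward outputs come first, followed by the gradients.
    """
    if not 0 <= num_fwd_outputs <= len(joint_outputs):
        raise ValueError("num_fwd_outputs out of range")
    return joint_outputs[:num_fwd_outputs], joint_outputs[num_fwd_outputs:]


def resolve_recomputable_ops(compiler, recomputable_ops, defaults):
    """Follow the documented rule: an explicit set overrides the compiler's default set."""
    if recomputable_ops is not None:
        return set(recomputable_ops)
    if compiler not in defaults:
        raise ValueError(f"unknown compiler: {compiler!r}")
    return set(defaults[compiler])


# Placeholder op names only; these are NOT functorch's real default sets.
HYPOTHETICAL_DEFAULTS = {
    "inductor": {"add", "mul", "relu"},
    "nvfuser": {"add", "mul"},
}

fwd_outs, grads = split_joint_outputs(["out", "grad_x", "grad_w"], num_fwd_outputs=1)
print(fwd_outs, grads)  # ['out'] ['grad_x', 'grad_w']
print(sorted(resolve_recomputable_ops("nvfuser", None, HYPOTHETICAL_DEFAULTS)))  # ['add', 'mul']
print(resolve_recomputable_ops("inductor", {"relu"}, HYPOTHETICAL_DEFAULTS))  # {'relu'}
```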
- Returns
Returns the generated forward and backward FX graph modules.