functorch.compile.min_cut_rematerialization_partition
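functorch.compile.min_cut_rematerialization_partition(joint_module, _joint_inputs, compiler='nvfuser')
Partitions the joint graph so that the backward pass recomputes some forward values instead of saving all of them. Recomputation trades extra computation for lower memory bandwidth, because fewer intermediate values have to be written out during the forward pass and read back during the backward pass.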
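The function name points to a minimum-cut formulation for deciding which forward values to save and which to recompute. This page does not describe the cost model, so the following is only a toy sketch under stated assumptions. Each forward value is split into an `in` half and an `out` half joined by an edge whose capacity is a hypothetical size of that value, so cutting that edge means "save this value". Graph inputs hang off a source, values the backward reads lead to a sink, and the minimum cut gives the cheapest set of saved values from which everything else the backward needs can be recomputed. The helpers `max_flow_min_cut` and `choose_saved`, the sizes, and the assumption that every forward op is recomputable are illustrative and are not part of functorch.

```python
from collections import deque

INF = float("inf")


def max_flow_min_cut(capacity, source, sink):
    """Edmonds-Karp max flow.

    Returns the flow value and the set of nodes reachable from `source`
    in the final residual graph (the source side of a minimum cut).
    """
    residual = {u: dict(edges) for u, edges in capacity.items()}
    for u, edges in capacity.items():
        for v in edges:
            # Reverse edges start with zero residual capacity.
            residual.setdefault(v, {}).setdefault(u, 0)
    flow = 0
    while True:
        # Breadth-first search for a shortest augmenting path.
        parent = {source: None}
        queue = deque([source])
        while queue and sink not in parent:
            u = queue.popleft()
            for v, cap in residual[u].items():
                if cap > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if sink not in parent:
            return flow, set(parent)
        # Bottleneck capacity along the path found.
        bottleneck = INF
        v = sink
        while parent[v] is not None:
            bottleneck = min(bottleneck, residual[parent[v]][v])
            v = parent[v]
        # Push the bottleneck flow along the path.
        v = sink
        while parent[v] is not None:
            u = parent[v]
            residual[u][v] -= bottleneck
            residual[v][u] += bottleneck
            v = u
        flow += bottleneck


def choose_saved(sizes, deps, inputs, needed_by_backward):
    """Return the cheapest set of forward values to save (toy model).

    sizes: value name -> hypothetical cost of saving it
    deps: value name -> names of the values it is computed from
    inputs: graph inputs (connected to the source)
    needed_by_backward: values the backward pass reads (connected to the sink)
    """
    cap = {"SOURCE": {}}
    for name, size in sizes.items():
        # Cutting name_in -> name_out means "save this value".
        cap.setdefault((name, "in"), {})[(name, "out")] = size
        for dep in deps.get(name, ()):
            # Data dependencies have infinite capacity, so they are never cut.
            cap.setdefault((dep, "out"), {})[(name, "in")] = INF
    for name in inputs:
        cap["SOURCE"][(name, "in")] = INF
    for name in needed_by_backward:
        cap.setdefault((name, "out"), {})["SINK"] = INF
    _, source_side = max_flow_min_cut(cap, "SOURCE", "SINK")
    # Saved values are the node edges crossing from the source side to the sink side.
    return {
        name for name in sizes
        if (name, "in") in source_side and (name, "out") not in source_side
    }


if __name__ == "__main__":
    # Hypothetical sizes: `big` is a large intermediate computed from input `x`.
    sizes = {"x": 10, "big": 100, "small": 1}
    deps = {"big": ["x"], "small": ["big"]}
    print(choose_saved(sizes, deps, ["x"], ["big", "small"]))  # {'x'}
    print(choose_saved(sizes, deps, ["x"], ["small"]))         # {'small'}
```

When the backward needs the expensive intermediate `big`, the sketch prefers saving the cheaper input `x` and recomputing `big` from it; when the backward only needs the tiny `small`, saving `small` directly is cheapest.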
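To create the forward and backward graphs, we copy the joint graph and manually set the copy's outputs to only the original forward outputs (for the forward graph) or only the backward outputs (for the backward graph). We then run each resulting graph through dead-code elimination, which removes every node those outputs do not depend on.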
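A minimal sketch of this copy-then-prune step, using a hypothetical toy graph representation (a dict mapping each node name, in topological order, to an `(op, args)` tuple) instead of `torch.fx`. The joint graph is hand-written for f(x) = sin(x)², with placeholder nodes standing in for the primal input and the incoming gradient (`tangent`). The helper names `dead_code_eliminate` and `split_joint_graph` are illustrative only. Because the pruned backward graph keeps the forward node `s`, it recomputes that value rather than reading a saved copy.

```python
import copy

# Toy joint graph for f(x) = sin(x) ** 2, listed in topological order.
# Each entry is node name -> (op, argument node names). Hypothetical format.
JOINT = {
    "x": ("placeholder", ()),              # primal input
    "tangent": ("placeholder", ()),        # incoming gradient of `out`
    "s": ("sin", ("x",)),
    "out": ("mul", ("s", "s")),            # forward output
    "two_s": ("mul2", ("s",)),             # backward: d(s*s)/ds = 2*s
    "c": ("cos", ("x",)),                  # backward: ds/dx = cos(x)
    "t1": ("mul", ("two_s", "c")),
    "grad_x": ("mul", ("tangent", "t1")),  # backward output
}


def dead_code_eliminate(nodes, outputs):
    """Keep only the nodes that `outputs` transitively depend on."""
    live = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name not in live:
            live.add(name)
            stack.extend(nodes[name][1])
    # Filtering the original dict preserves topological order.
    return {name: node for name, node in nodes.items() if name in live}


def split_joint_graph(joint, fwd_outputs, bwd_outputs):
    """Copy the joint graph once per side, then prune each copy to its outputs."""
    fwd = dead_code_eliminate(copy.deepcopy(joint), fwd_outputs)
    bwd = dead_code_eliminate(copy.deepcopy(joint), bwd_outputs)
    return fwd, bwd


if __name__ == "__main__":
    fwd, bwd = split_joint_graph(JOINT, ["out"], ["grad_x"])
    print(list(fwd))  # ['x', 's', 'out']
    print(list(bwd))  # ['x', 'tangent', 's', 'two_s', 'c', 't1', 'grad_x']
```

In `torch.fx` itself, the analogous pruning step is `Graph.eliminate_dead_code()`.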
Warning
This API is experimental and likely to change.
Parameters
joint_module (fx.GraphModule) – The joint forward and backward graph. This is the result of AOT Autograd tracing.
Returns
The generated forward and backward FX graph modules.