
functorch.compile.min_cut_rematerialization_partition

functorch.compile.min_cut_rematerialization_partition(joint_module, _joint_inputs)

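Partitions the joint graph so that the backward graph recomputes parts of the forward computation instead of saving every intermediate. Recomputation trades extra computation for reduced memory bandwidth, because fewer intermediates have to be written out in the forward pass and read back in the backward pass.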
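As a minimal, framework-free illustration of this trade-off, consider f(x) = sin(2x) on plain Python floats rather than PyTorch tensors. All function names below are hypothetical. One strategy saves the intermediate a = 2x for the backward pass. The other keeps only the input x, which has to stay alive anyway, and recomputes a during the backward pass at the cost of one extra multiplication.

```python
import math
from typing import Dict, Tuple

# Toy scalar model of f(x) = sin(2 * x) on plain floats (not PyTorch tensors).
# All function names here are hypothetical, for illustration only.


def forward_save(x: float) -> Tuple[float, Dict[str, float]]:
    a = 2.0 * x
    return math.sin(a), {"a": a}  # store the intermediate a for backward


def backward_save(saved: Dict[str, float], grad_out: float) -> float:
    return grad_out * math.cos(saved["a"]) * 2.0  # reuse the stored a


def forward_recompute(x: float) -> Tuple[float, Dict[str, float]]:
    return math.sin(2.0 * x), {"x": x}  # keep only the input, no intermediates


def backward_recompute(saved: Dict[str, float], grad_out: float) -> float:
    a = 2.0 * saved["x"]  # recompute the forward intermediate (one extra multiply)
    return grad_out * math.cos(a) * 2.0


y1, saved1 = forward_save(0.3)
y2, saved2 = forward_recompute(0.3)
g1 = backward_save(saved1, 1.0)
g2 = backward_recompute(saved2, 1.0)
```

Both variants produce the same output and the same gradient, 2·cos(2x)·grad_out. The recomputing variant never stores a. On real hardware this avoids a memory write and a later read, which is worthwhile when the recomputed op is cheap compared with moving its result through memory.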

To create the forward and backward graphs, we copy the joint graph and manually set the copy's outputs to just the original forward outputs or just the backward outputs, respectively. We then run each resulting graph through dead code elimination.
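The copy, set-outputs, and dead-code-elimination steps can be sketched on a hypothetical toy graph for f(x) = sin(2x). The graph is a plain dict mapping each node name to an op and its input names. It is not a torch.fx Graph, and `dead_code_eliminate` and `partition` are illustrative helpers, not functorch APIs.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy joint graph for f(x) = sin(2 * x): node name -> (op, inputs).
# Insertion order is a valid topological order. This is NOT torch.fx.
Graph = Dict[str, Tuple[str, List[str]]]

JOINT: Graph = {
    "x": ("input", []),          # primal input
    "t": ("input", []),          # incoming gradient (tangent)
    "a": ("mul2", ["x"]),        # forward: a = 2 * x
    "out": ("sin", ["a"]),       # forward output: sin(a)
    "ca": ("cos", ["a"]),        # backward: d sin(a) / da
    "g1": ("mul", ["t", "ca"]),  # backward: t * cos(a)
    "gx": ("mul2", ["g1"]),      # backward output: gradient w.r.t. x
}


def dead_code_eliminate(graph: Graph, outputs: List[str]) -> Graph:
    """Keep only nodes that the chosen outputs transitively depend on."""
    live: Set[str] = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name not in live:
            live.add(name)
            stack.extend(graph[name][1])
    # Preserve the original topological order of the surviving nodes.
    return {name: node for name, node in graph.items() if name in live}


def partition(
    graph: Graph, fwd_outputs: List[str], bwd_outputs: List[str]
) -> Tuple[Graph, Graph]:
    # Copy the joint graph, point each copy at one set of outputs, then run DCE.
    fwd = dead_code_eliminate(dict(graph), fwd_outputs)
    bwd = dead_code_eliminate(dict(graph), bwd_outputs)
    return fwd, bwd


fwd, bwd = partition(JOINT, ["out"], ["gx"])
print(list(fwd))  # ['x', 'a', 'out']
print(list(bwd))  # ['x', 't', 'a', 'ca', 'g1', 'gx']
```

Because the backward copy's only output is `gx`, the node `a` stays alive in the backward graph, so the backward recomputes `2 * x` from the input instead of receiving it from the forward. This naive split recomputes every forward value the backward needs. The real partitioner instead uses a min cut to decide which intermediates to save and which to recompute. In torch.fx, the corresponding pruning step is `Graph.eliminate_dead_code()`.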

Warning

This API is experimental and likely to change.

Parameters

joint_module (fx.GraphModule) – The joint forward and backward graph. This is the result of AOT Autograd tracing.

Returns

The generated forward and backward FX graph modules.