functorch.compile.min_cut_rematerialization_partition

functorch.compile.min_cut_rematerialization_partition(joint_module, _joint_inputs, compiler='nvfuser')

Partitions the joint graph so that the backward graph recomputes forward values instead of saving them. Recomputation trades additional computation for lower memory-bandwidth usage.
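To illustrate the trade-off, here is a minimal hand-written sketch, not the partitioner itself. It assumes a toy scalar function y = sin(cos(x)), and the four function names are hypothetical. The "save" variant keeps the intermediate cos(x) alive for the backward pass. The "recompute" variant keeps only x and recomputes the intermediate during backward, so it stores less but performs more work.

```python
import math

# Hypothetical toy function: y = sin(cos(x)), with a = cos(x) as the intermediate.


def forward_save(x):
    a = math.cos(x)
    y = math.sin(a)
    saved = (x, a)  # two values kept alive until the backward pass
    return y, saved


def backward_save(saved, grad_y):
    x, a = saved
    # Chain rule: dy/dx = cos(a) * (-sin(x))
    return grad_y * math.cos(a) * -math.sin(x)


def forward_recompute(x):
    a = math.cos(x)
    y = math.sin(a)
    saved = (x,)  # only the input is kept; the intermediate a is dropped
    return y, saved


def backward_recompute(saved, grad_y):
    (x,) = saved
    a = math.cos(x)  # recompute the forward intermediate instead of loading it
    return grad_y * math.cos(a) * -math.sin(x)


if __name__ == "__main__":
    y, saved = forward_recompute(0.5)
    print(y, backward_recompute(saved, 1.0))
```

Both variants produce identical gradients. In a real graph the saved values are large tensors, so dropping them reduces memory traffic. The actual partitioner decides which values to save and which to recompute. This sketch hard-codes that choice.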

To create the forward and backward graphs, we copy the joint graph, manually set its outputs to just the original forward or backward outputs, and then run the resulting graphs through dead code elimination.
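The copy, set outputs, and eliminate dead code steps can be sketched on a toy graph. This sketch does not use torch.fx; `Node`, `Graph`, `dead_code_elimination`, and `extract_subgraph` are all hypothetical stand-ins. The joint graph below encodes y = sin(cos(x)) and its gradient. Because `a` is needed by the backward outputs, dead code elimination leaves its computation in the backward graph, where it is recomputed from `x`.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    inputs: list  # names of the nodes this node reads


@dataclass
class Graph:
    nodes: list  # kept in topological order
    outputs: list = field(default_factory=list)


def dead_code_elimination(graph):
    """Keep only the nodes that the graph outputs transitively depend on."""
    by_name = {n.name: n for n in graph.nodes}
    live, stack = set(), list(graph.outputs)
    while stack:
        name = stack.pop()
        if name not in live:
            live.add(name)
            stack.extend(by_name[name].inputs)
    graph.nodes = [n for n in graph.nodes if n.name in live]
    return graph


def extract_subgraph(joint, outputs):
    g = copy.deepcopy(joint)  # never mutate the joint graph itself
    g.outputs = list(outputs)  # manually set the outputs
    return dead_code_elimination(g)


# Joint graph for y = sin(cos(x)) and dy/dx = grad_y * cos(a) * -sin(x).
joint = Graph(nodes=[
    Node("x", "placeholder", []),
    Node("grad_y", "placeholder", []),
    Node("a", "cos", ["x"]),
    Node("y", "sin", ["a"]),
    Node("cos_a", "cos", ["a"]),
    Node("sin_x", "sin", ["x"]),
    Node("grad_x", "mul_neg", ["grad_y", "cos_a", "sin_x"]),
])

fwd = extract_subgraph(joint, ["y"])
bwd = extract_subgraph(joint, ["grad_x"])
```

In this toy, `fwd` keeps `x, a, y` and `bwd` keeps `x, grad_y, a, cos_a, sin_x, grad_x`, so the backward graph recomputes `a`. For real FX graphs, `torch.fx.Graph.eliminate_dead_code()` performs the elimination step.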

Warning

This API is experimental and likely to change.

Parameters

joint_module (fx.GraphModule) – The joint forward and backward graph. This is the result of AOT Autograd tracing.

Returns

The generated forward and backward FX graph modules.
