
functorch.compile.min_cut_rematerialization_partition

functorch.compile.min_cut_rematerialization_partition(joint_module, _joint_inputs, compiler='inductor', recomputable_ops=None, *, num_fwd_outputs)

Partitions the joint graph so that the backward graph recomputes parts of the forward computation instead of saving them. Recomputation trades additional computation for reduced memory bandwidth.

To create the forward and backward graphs, the joint graph is copied and its outputs are manually set to just the original forward outputs or just the backward outputs. The resulting graphs are then run through dead code elimination.
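The copy-then-eliminate step can be sketched on a toy joint graph in plain Python. The Node class, the partition and dead_code_eliminate helpers, and the graph for f(x) = sin(x) * sin(x) are all hypothetical; a real joint graph is an fx.GraphModule, and the real partitioner must also pass saved values from the forward graph to the backward graph, which this sketch skips by letting the backward read the input x directly.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical, minimal stand-in for an FX node."""
    name: str
    op: str
    args: tuple = ()  # names of the nodes this node reads


def dead_code_eliminate(nodes, outputs):
    """Keep only the nodes that the requested outputs transitively depend on."""
    by_name = {n.name: n for n in nodes}
    live = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name in live:
            continue
        live.add(name)
        stack.extend(by_name[name].args)
    # Preserve the original (topological) order of the surviving nodes.
    return [n for n in nodes if n.name in live]


def partition(joint_nodes, fwd_outputs, bwd_outputs):
    """Copy the joint graph twice, keep different outputs, then run DCE on each copy."""
    fwd = dead_code_eliminate(list(joint_nodes), fwd_outputs)
    bwd = dead_code_eliminate(list(joint_nodes), bwd_outputs)
    return fwd, bwd


# Hypothetical joint graph for f(x) = sin(x) * sin(x), with incoming gradient `tangent`.
joint = [
    Node("x", "placeholder"),
    Node("tangent", "placeholder"),
    Node("a", "sin", ("x",)),
    Node("out", "mul", ("a", "a")),               # forward output
    Node("two_a", "add", ("a", "a")),             # d(out)/d(a) = 2a
    Node("grad_a", "mul", ("tangent", "two_a")),
    Node("cos_x", "cos", ("x",)),                 # d(a)/d(x) = cos(x)
    Node("grad_x", "mul", ("grad_a", "cos_x")),   # backward output
]

fwd, bwd = partition(joint, ["out"], ["grad_x"])
print([n.name for n in fwd])  # ['x', 'a', 'out']
print([n.name for n in bwd])  # ['x', 'tangent', 'a', 'two_a', 'grad_a', 'cos_x', 'grad_x']
```

Because a = sin(x) survives dead code elimination in the backward copy, the backward graph recomputes it from x rather than reading a saved copy.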

Warning

This API is experimental and likely to change.

Parameters
  • joint_module (fx.GraphModule) – The joint forward and backward graph. This is the result of AOT Autograd tracing.

  • _joint_inputs – The inputs to the joint graph. This is unused.

  • compiler – Determines the default set of recomputable ops. Currently, two options are supported: 'nvfuser' and 'inductor'.

  • recomputable_ops – An optional set of recomputable ops. If it is not None, this set is used instead of the default set determined by compiler.

  • num_fwd_outputs – The number of outputs from the forward graph (keyword-only).
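The function's name refers to choosing which values to save by way of a minimum cut. The sketch below is a rough, hypothetical illustration of that idea, not the functorch implementation, whose graph construction and cost model differ. It assumes a simplified model: cutting a value's internal edge means "save this value" at a cost equal to a made-up size; graph inputs and ops outside a recomputable set (playing the role of recomputable_ops) are pinned to the forward side; and values the backward needs are connected to the sink. The helper names, op names, and sizes are all assumptions.

```python
from collections import deque

INF = float("inf")


def reachable_after_max_flow(capacity, source, sink):
    """Run Edmonds-Karp max flow, then return the vertices reachable from
    `source` in the residual graph (the source side of a minimum cut)."""
    residual = {u: dict(vs) for u, vs in capacity.items()}
    for u, vs in capacity.items():
        for v in vs:
            residual.setdefault(v, {}).setdefault(u, 0)  # reverse edge
    while True:
        # Breadth-first search for a shortest augmenting path.
        parent = {source: None}
        queue = deque([source])
        while queue and sink not in parent:
            u = queue.popleft()
            for v, cap in residual[u].items():
                if cap > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if sink not in parent:
            return set(parent)
        # Push the bottleneck amount of flow along the path found.
        bottleneck, v = INF, sink
        while parent[v] is not None:
            bottleneck = min(bottleneck, residual[parent[v]][v])
            v = parent[v]
        v = sink
        while parent[v] is not None:
            u = parent[v]
            residual[u][v] -= bottleneck
            residual[v][u] += bottleneck
            v = u


def choose_saved_values(nodes, required_by_backward, recomputable_ops):
    """Hypothetical helper. `nodes` maps name -> (op, size, input_names).
    Returns the names of the values the forward should save for the backward."""
    capacity = {"SOURCE": {}}

    def add_edge(u, v, c):
        capacity.setdefault(u, {})[v] = c

    for name, (op, size, inputs) in nodes.items():
        # Cutting in -> out means "save this value"; the cost is its size.
        add_edge(name + "_in", name + "_out", size)
        if op == "placeholder" or op not in recomputable_ops:
            # Graph inputs and non-recomputable ops are pinned to the forward side.
            add_edge("SOURCE", name + "_in", INF)
        for inp in inputs:
            add_edge(inp + "_out", name + "_in", INF)
        if name in required_by_backward:
            add_edge(name + "_out", "SINK", INF)

    source_side = reachable_after_max_flow(capacity, "SOURCE", "SINK")
    return sorted(
        name for name in nodes
        if name + "_in" in source_side and name + "_out" not in source_side
    )


# A small input x is expanded into a large y; relu's backward needs its output z.
nodes = {
    "x": ("placeholder", 10, ()),
    "y": ("expand", 1000, ("x",)),
    "z": ("relu", 1000, ("y",)),
}
print(choose_saved_values(nodes, {"z"}, {"expand", "relu"}))  # ['x']
print(choose_saved_values(nodes, {"z"}, {"expand"}))          # ['z']
```

With relu treated as recomputable, saving the 10-unit input x and recomputing expand and relu in the backward is cheapest. Removing relu from the recomputable set pins z to the forward, so the 1000-unit value z must be saved instead.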

Returns

The generated forward and backward FX graph modules.
