functorch.compile.memory_efficient_fusion

functorch.compile.memory_efficient_fusion(fn, static_argnums=None, **kwargs)

A wrapper over aot_function() and aot_module() that performs memory-efficient fusion. It uses the min_cut_rematerialization_partition() partitioner to decide which intermediate values the backward pass should recompute rather than store, and it uses NVFuser to compile the resulting forward and backward graphs.
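The wrapper pattern described above can be sketched in plain Python. This is a minimal, hypothetical illustration, not the functorch implementation: it assumes the wrapper starts from a set of default settings (a forward compiler, a backward compiler and a partitioner), lets **kwargs override them, and then dispatches to the module path or the function path. Every name here (`placeholder_fw_compiler`, `placeholder_bw_compiler`, `placeholder_partition_fn`, `Module`, `sketch_aot_function`, `sketch_aot_module`, `sketch_memory_efficient_fusion`) is an invented stand-in, and no tracing, partitioning or compilation actually happens; for simplicity both paths return a plain callable.

```python
from typing import Any, Callable, Dict


# Hypothetical placeholders: in the real API the defaults would be
# NVFuser-backed compilers and min_cut_rematerialization_partition().
def placeholder_fw_compiler(graph: Any) -> Any:
    return graph


def placeholder_bw_compiler(graph: Any) -> Any:
    return graph


def placeholder_partition_fn(graph: Any) -> Any:
    return graph


class Module:
    """Minimal stand-in for nn.Module: calling it runs forward()."""

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        return self.forward(*args)


def sketch_aot_function(fn: Callable, **config: Any) -> Callable:
    # A real aot_function() would trace fn, partition the joint graph and
    # compile both halves; this stand-in only runs fn eagerly and records
    # the settings it was given.
    def wrapped(*args: Any) -> Any:
        return fn(*args)

    wrapped.config = config
    wrapped.kind = "function"
    return wrapped


def sketch_aot_module(mod: Module, **config: Any) -> Callable:
    wrapped = sketch_aot_function(mod.forward, **config)
    wrapped.kind = "module"
    return wrapped


def sketch_memory_efficient_fusion(fn: Any, **kwargs: Any) -> Callable:
    # Start from the defaults, then let caller-supplied kwargs replace them.
    config: Dict[str, Any] = {
        "fw_compiler": placeholder_fw_compiler,
        "bw_compiler": placeholder_bw_compiler,
        "partition_fn": placeholder_partition_fn,
    }
    config.update(kwargs)
    # Modules take the module path; plain callables take the function path.
    if isinstance(fn, Module):
        return sketch_aot_module(fn, **config)
    return sketch_aot_function(fn, **config)


def f(x: int, y: int) -> int:
    return x * y + 1


compiled = sketch_memory_efficient_fusion(f)
print(compiled(2, 3), compiled.kind)  # 7 function
```

Merging overrides into a dictionary of defaults keeps the common case (no arguments beyond fn) short while still letting a caller swap out a single piece, such as the backward compiler, without restating the rest.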

Warning

This API is experimental and likely to change.

Parameters
  • fn (Union[Callable, nn.Module]) – A Python function or an nn.Module that takes one or more arguments. Must return one or more Tensors.

  • static_argnums (Optional[Tuple[int]]) – An optional tuple of ints giving the positions of the function arguments to mark as static.

  • **kwargs – Any other overrides to the default settings, passed through to aot_function() or aot_module().
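The page does not spell out what marking an argument as static does. The sketch below assumes, as in other tracing APIs that take a static_argnums parameter, that static arguments are treated as compile-time constants, so each distinct combination of static values gets its own compiled entry in a cache. It is a hypothetical illustration in plain Python: `sketch_static_argnums` and `compile_for` are invented names, and "compiling" simply returns the original function.

```python
from typing import Any, Callable, Dict, Sequence, Tuple


def sketch_static_argnums(fn: Callable, static_argnums: Sequence[int] = ()) -> Callable:
    """Hypothetical wrapper that treats the listed positional args as static."""
    static = tuple(sorted(static_argnums))
    # One "compiled" entry per distinct combination of static values;
    # static values therefore have to be hashable.
    cache: Dict[Tuple[Any, ...], Callable] = {}

    def compile_for(static_values: Tuple[Any, ...]) -> Callable:
        # Stand-in for tracing and compiling: a real compiler would bake
        # static_values into the graph as constants at this point.
        return fn

    def wrapped(*args: Any) -> Any:
        key = tuple(args[i] for i in static)
        if key not in cache:
            cache[key] = compile_for(key)  # first time we see these values
        return cache[key](*args)

    wrapped.cache = cache
    return wrapped


def power(x: float, n: int) -> float:
    return x ** n


p = sketch_static_argnums(power, static_argnums=(1,))
print(p(2, 3), p(3, 3), p(2, 2), len(p.cache))  # 8 27 4 2
```

Under this assumption, changing a non-static argument (x above) reuses an existing entry, while each new value of a static argument (n) adds one, which is why static arguments suit values that rarely change.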

Returns

A Callable or nn.Module that retains the eager behavior of the original fn, but whose forward and backward graphs have gone through recomputation optimizations and have been compiled with NVFuser.