functorch.compile.memory_efficient_fusion(fn, static_argnums=None)

Wrapper function over aot_function() and aot_module() that performs memory-efficient fusion. It uses the min_cut_rematerialization_partition() partitioner to decide which intermediate values are saved for the backward pass and which are recomputed there, and it uses NVFuser to compile the generated forward and backward graphs.
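
The save-versus-recompute trade-off that the partitioner automates can be illustrated by hand. The sketch below is a toy example in plain Python, not the min-cut partitioner itself and not functorch code: for the scalar function f(x) = exp(sin(x)) * x, one forward/backward pair stores the intermediate exp(sin(x)) for the backward pass, while the other stores only x and recomputes (rematerializes) the intermediate during backward. The function names are hypothetical. Both pairs yield the same gradient; the second keeps fewer values alive between the passes at the cost of extra computation.

```python
import math

# Toy function f(x) = exp(sin(x)) * x, with hand-derived gradient
# f'(x) = exp(sin(x)) * (x * cos(x) + 1).
# Hypothetical names; this only illustrates recomputation, not functorch.

def forward_save(x):
    """Forward pass that stores the intermediate exp(sin(x)) for backward."""
    e = math.exp(math.sin(x))
    return e * x, (x, e)              # output, values saved for backward

def backward_save(saved, grad_out=1.0):
    x, e = saved                      # reuse the stored intermediate
    return grad_out * e * (x * math.cos(x) + 1.0)

def forward_remat(x):
    """Forward pass that stores only the input, so less is kept in memory."""
    e = math.exp(math.sin(x))
    return e * x, (x,)

def backward_remat(saved, grad_out=1.0):
    (x,) = saved
    e = math.exp(math.sin(x))         # recompute (rematerialize) the intermediate
    return grad_out * e * (x * math.cos(x) + 1.0)

y1, saved1 = forward_save(0.5)
y2, saved2 = forward_remat(0.5)
g1 = backward_save(saved1)
g2 = backward_remat(saved2)
print(len(saved1), len(saved2))       # 2 1
print(g1 == g2)                       # True
```

In a real graph the partitioner makes this choice per intermediate; recomputation tends to pay off when the recomputed operations are cheap, for example pointwise operations that a fuser such as NVFuser can fold into the backward kernels.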


This API is experimental and likely to change.

  • fn (Union[Callable, nn.Module]) – A Python function or an nn.Module that takes one or more arguments. Must return one or more Tensors.

  • static_argnums (Optional[Tuple[int]]) – An optional tuple of ints giving the positions of the arguments of fn that should be marked as static.
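
The description above does not spell out what "static" implies. In tracing compilers that use the same convention (for example, JAX's jax.jit), a static argument is treated as a compile-time constant rather than a traced input, so each distinct value gets its own compiled specialization and static values must be hashable. The toy sketch below models that general idea in plain Python under that assumption; compile_with_static_args and its stats counter are hypothetical names, and the code does not reflect functorch's actual caching logic.

```python
from typing import Any, Callable, Dict, Tuple

def compile_with_static_args(fn: Callable, static_argnums: Tuple[int, ...] = ()):
    """Hypothetical stand-in for a tracing compiler (not functorch's API).

    Arguments at positions in static_argnums are baked into each
    specialization, so one specialization is built per distinct value.
    Only non-negative positional indices are supported in this sketch.
    """
    cache: Dict[Tuple[Any, ...], Callable] = {}
    stats = {"compiles": 0}

    def wrapper(*args):
        key = tuple(args[i] for i in static_argnums)  # static values must be hashable
        if key not in cache:
            stats["compiles"] += 1                    # "compile" a new specialization
            static_vals = dict(zip(static_argnums, key))

            def specialized(*dyn_args):
                # Re-insert the baked-in static values at their original positions.
                it = iter(dyn_args)
                total = len(dyn_args) + len(static_vals)
                full = [static_vals[i] if i in static_vals else next(it)
                        for i in range(total)]
                return fn(*full)

            cache[key] = specialized
        # Only the non-static (dynamic) arguments flow into the specialization.
        dynamic = [a for i, a in enumerate(args) if i not in static_argnums]
        return cache[key](*dynamic)

    wrapper.stats = stats
    return wrapper

def scale(x, factor):
    return x * factor

f = compile_with_static_args(scale, static_argnums=(1,))
print(f(3, 2), f(5, 2), f(3, 10))  # 6 10 30
print(f.stats["compiles"])         # 2
```

Reusing the same static value (factor=2) hits the cache, while a new value (factor=10) triggers another specialization; this is why arguments that change often are usually poor candidates for marking as static.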


Returns a Callable or nn.Module that retains the eager behavior of the original fn, but whose forward and backward graphs have gone through recomputation optimizations and have been compiled with NVFuser.