functorch.compile.memory_efficient_fusion
functorch.compile.memory_efficient_fusion(fn, static_argnums=None)

Wrapper function over aot_function() (for plain functions) and aot_module() (for nn.Module instances) that performs memory-efficient fusion. It uses the min_cut_rematerialization_partition() partitioner to decide which forward-pass intermediates are saved for the backward pass and which are recomputed, trading extra computation for lower memory use. It uses NVFuser to compile the generated forward and backward graphs.

Warning
This API is experimental and likely to change.
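To illustrate the save-versus-recompute trade-off that the partitioner manages, here is a toy pure-Python sketch for y = exp(sin(x)). It is not functorch's algorithm and uses no PyTorch APIs; all function names are hypothetical. Strategy A stores both x and y for the backward pass, while strategy B stores only x and recomputes y during backward. Both yield the same gradient, but B keeps one value alive instead of two. Roughly speaking, min_cut_rematerialization_partition() makes this kind of choice automatically across a whole graph, preferring to recompute cheap, fusible operations.

```python
import math


def f_forward(x: float) -> float:
    """Forward pass of the toy function y = exp(sin(x))."""
    return math.exp(math.sin(x))


def save_intermediates(x: float):
    # Strategy A: stash everything backward needs, i.e. x (for cos(x))
    # and the forward output y = exp(sin(x)).
    y = f_forward(x)
    return y, {"x": x, "y": y}


def backward_from_saved(saved: dict, grad_out: float = 1.0) -> float:
    # dy/dx = exp(sin(x)) * cos(x), reading y from the saved values.
    return grad_out * saved["y"] * math.cos(saved["x"])


def save_inputs_only(x: float):
    # Strategy B (rematerialization): stash only the input x.
    y = f_forward(x)
    return y, {"x": x}


def backward_with_recompute(saved: dict, grad_out: float = 1.0) -> float:
    # Recompute y = exp(sin(x)) from x instead of reading it from memory.
    x = saved["x"]
    y = f_forward(x)
    return grad_out * y * math.cos(x)


x = 0.7
_, saved_a = save_intermediates(x)
_, saved_b = save_inputs_only(x)
print(len(saved_a), len(saved_b))  # 2 1
print(math.isclose(backward_from_saved(saved_a), backward_with_recompute(saved_b)))  # True
```

The recomputed backward performs one extra sin and exp, which is the price paid for storing fewer values between the forward and backward passes.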
- Parameters
fn (Union[Callable, nn.Module]) – A Python function or an nn.Module that takes one or more arguments. Must return one or more Tensors.
static_argnums (Optional[Tuple[int]]) – An optional tuple of ints marking which arguments of the function are static.
- Returns
Returns a Callable or nn.Module that retains the eager behavior of the original fn, but whose forward and backward graphs have gone through recomputation optimizations and have been compiled with NVFuser.