functorch.compile.memory_efficient_fusion
- functorch.compile.memory_efficient_fusion(fn, **kwargs)
Wrapper function over aot_function() and aot_module() that performs memory-efficient fusion. It uses the min_cut_rematerialization_partition() partitioner to decide which intermediate values the backward pass should recompute rather than save, and it uses NVFuser to compile the generated forward and backward graphs.
Warning
This API is experimental and likely to change.
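The following is a minimal sketch of the kind of aot_function() call this wrapper builds for a plain Python function. It assumes a PyTorch build that ships functorch.compile. It substitutes the nop compiler, which returns each traced graph unchanged, for NVFuser so that it runs on CPU. The real wrapper may also set other options, such as decompositions.

```python
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop


def f(x):
    # A chain of pointwise ops: cheap to recompute, so the min-cut
    # partitioner may prefer recomputing intermediates over saving them.
    return torch.sin(x).cos() * 2


# nop stands in for the NVFuser-backed compilers so this runs on CPU.
compiled_f = aot_function(
    f,
    fw_compiler=nop,
    bw_compiler=nop,
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.randn(8, requires_grad=True)
out = compiled_f(x)
out.sum().backward()  # runs the partitioned backward graph
```

Because nop performs no compilation, this sketch only exercises the partitioning step. The outputs and gradients should match eager execution of f.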
- Parameters
fn (Union[Callable, nn.Module]) – A Python function or an nn.Module that takes one or more arguments. Must return one or more Tensors.
**kwargs – Overrides for the default settings that the wrapper passes to aot_function() or aot_module().
- Returns
A Callable or nn.Module that retains the eager behavior of the original fn, but whose forward and backward graphs have been optimized for recomputation and compiled with NVFuser.
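The sketch below shows the nn.Module path and a **kwargs override together. It assumes that fw_compiler and bw_compiler, which are parameters of aot_function() and aot_module(), are among the settings **kwargs can override. It swaps in nop so the example runs on CPU without NVFuser. With these overrides, the graphs are still partitioned for recomputation but are not fused.

```python
import torch
from torch import nn
from functorch.compile import memory_efficient_fusion, nop

model = nn.Sequential(nn.Linear(4, 4), nn.Tanh())

# Passing fw_compiler/bw_compiler through **kwargs replaces the default
# NVFuser-backed compilers; nop keeps the traced graphs as-is (CPU-friendly).
fused_model = memory_efficient_fusion(model, fw_compiler=nop, bw_compiler=nop)

x = torch.randn(2, 4)
out = fused_model(x)
out.sum().backward()  # gradients accumulate on the original model's parameters
```

Because an nn.Module was passed in, the result is itself an nn.Module. It can be dropped into existing training code in place of the original model.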