functorch.compile.aot_module
- functorch.compile.aot_module(mod, *args, **kwargs)
Traces the forward and backward graph of mod using the torch dispatch tracing mechanism. It is a wrapper function that uses aot_function() underneath to perform tracing and compilation. aot_module() lifts the parameters and buffers of the nn.Module as inputs to a new callable, which is then compiled through aot_function().
Warning
This API is experimental and likely to change.
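To illustrate what "lifting the parameters and buffers as inputs" means, here is a conceptual sketch built on torch.func.functional_call (available in PyTorch 2.0 and later). It is not aot_module()'s internal code. It only shows a module's state being passed explicitly to a plain function, which is the kind of callable a function-level compiler such as aot_function() can trace. The names params_and_buffers and lifted_forward are hypothetical, chosen for this example.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# A small module with parameters (Linear, BatchNorm affine weights) and buffers (BatchNorm running stats).
# eval() makes BatchNorm use its running stats, so the output is deterministic.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3)).eval()

# Hypothetical helper state: all parameters and buffers, keyed by qualified name (e.g. "0.weight").
params_and_buffers = {
    **dict(model.named_parameters()),
    **dict(model.named_buffers()),
}

def lifted_forward(state, x):
    # The module state is an explicit input here instead of being read from module attributes.
    return functional_call(model, state, (x,))

x = torch.randn(2, 4)
out_lifted = lifted_forward(params_and_buffers, x)
out_eager = model(x)
```

Both calls should produce the same tensor. The lifted version simply makes the dependency on the module's state visible in the function signature.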
- Parameters
mod (Callable) – An nn.Module instance.
args – Positional arguments to be passed to aot_function().
kwargs – Keyword arguments to be passed to aot_function().
- Returns
Returns an nn.Module that retains the eager behavior of the original mod, but with its forward and backward graphs compiled.
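A minimal usage sketch, assuming a PyTorch 2.x installation where functorch.compile is importable. The keyword arguments fw_compiler and bw_compiler are aot_function() parameters, forwarded here through kwargs. The helper print_graph_compiler is hypothetical. It prints the traced FX graph and returns it unchanged, so the "compiled" module should behave exactly like eager mode.

```python
import torch
import torch.nn as nn
from functorch.compile import aot_module

def print_graph_compiler(fx_module, example_inputs):
    # Hypothetical compiler: receives a torch.fx.GraphModule plus example inputs,
    # prints the generated code, and returns the graph module itself as the callable.
    print(fx_module.code)
    return fx_module

model = nn.Linear(4, 3)

# Extra keyword arguments are passed through to aot_function().
compiled = aot_module(model, fw_compiler=print_graph_compiler, bw_compiler=print_graph_compiler)

x = torch.randn(2, 4, requires_grad=True)
out = compiled(x)       # runs (and prints) the traced forward graph
out.sum().backward()    # runs (and prints) the traced backward graph
```

Gradients should accumulate on the original module's parameters (for example model.weight.grad), because the parameters are supplied to the compiled callable as inputs rather than copied.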