import torch
import torch.nn as nn
from torch import Tensor
from functorch import make_fx
from torch.fx import immutable_collections
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch.nn.utils import _stateless
from functorch._C import CompileCache
from .decompositions import register_decomposition
from .partitioners import default_partition
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional

    lambda x: (list(x), None),
    lambda x, c: immutable_collections.immutable_list(x),
    lambda x: (list(x.values()), list(x.keys())),
    lambda x, c: immutable_collections.immutable_dict(
        {key: value for key, value in zip(c, x)}

# TODO - move this to PyTorch core. This overrides the pytree implementation for
# dict to maintain parity with DeepMind's ``tree`` package.
# The override below flattens a dict's values in sorted-key order, so the leaf
# order does not depend on insertion order. NOTE(review): presumably this
# matters because aot_function may flatten inputs with ``tree`` (cache key)
# and with pytree (spec) and both orders must agree - confirm.
# ``Context`` is the auxiliary data pytree stores next to the flattened leaves
# (for dicts here: the sorted key list).
Context = Any

def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    keys = list(sorted(d.keys()))
    values = [d[key] for key in keys]
    return values, keys

def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    return {key: value for key, value in zip(context, values)}

# Install the sorted-key flatten/unflatten pair as pytree's handler for dict.
pytree._register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
)

# Short alias for the ATen operator namespace used by the decompositions.
aten = torch.ops.aten

def create_joint_forward_backward(fn):
    """Wrap ``fn`` into a function that runs its forward and backward together.

    Args:
        fn: callable taking the primals positionally and returning a sequence
            of outputs (its length must equal the number of tangents).

    Returns:
        ``joint_forward_backward(primals, tangents) -> (outs, grads)``. It calls
        ``fn(*primals)`` and back-propagates ``tangents`` through the outputs
        that are Tensors requiring grad. ``grads`` has one entry per primal:
        the gradient for primals that are Tensors requiring grad (``None`` if
        such a primal is unused), and ``None`` for every other primal.
    """

    def joint_forward_backward(
        primals: List[Any], tangents: List[Any]
    ) -> Tuple[List[Any], List[Any]]:
        # Call the forward pass
        outs = fn(*primals)

        # Mask of which primals participate in differentiation; used at the
        # end to re-expand the gradients with ``None`` placeholders.
        inputs_needs_grads = [
            isinstance(p, Tensor) and p.requires_grad for p in primals
        ]
        grad_primals = [
            p for p, needs_grad in zip(primals, inputs_needs_grads) if needs_grad
        ]

        # Pair each differentiable output with its incoming tangent.
        assert len(tangents) == len(outs)
        needed_outs = []
        needed_tangents = []
        for out, tangent in zip(outs, tangents):
            if isinstance(out, Tensor) and out.requires_grad:
                needed_outs.append(out)
                needed_tangents.append(tangent)

        # Call the backwards pass
        backward_out = []
        if grad_primals:
            backward_out = torch.autograd.grad(
                needed_outs,
                grad_primals,
                grad_outputs=needed_tangents,
                # Primals that do not influence any output get ``None``.
                allow_unused=True,
            )
        backward_out_iter = iter(backward_out)
        return outs, [
            next(backward_out_iter) if needs_grad else None
            for needs_grad in inputs_needs_grads
        ]

    return joint_forward_backward

def normalize_as_list(x):
    """Coerce ``x`` into a list.

    A list is returned unchanged (same object), a tuple is copied into a new
    list, and any other value is wrapped as a single-element list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    return [x]

# Decompositions that are always applied while tracing; user-supplied
# ``decompositions`` are merged on top of (and override) these.
aot_autograd_decompositions = {}


@register_decomposition(aten.rsub, aot_autograd_decompositions)
def rsub(a, b, alpha=1):
    """Decompose ``aten.rsub`` into ``aten.sub``.

    ``rsub(a, b, alpha)`` computes ``b - alpha * a``. ``b`` may be a Tensor or
    a Python scalar, so ``a`` stays the first argument of ``aten.sub``.
    The previous version ignored ``alpha``.
    """
    if alpha == 1:
        # Common case: avoid emitting an extra multiply into the traced graph.
        return -aten.sub(a, b)
    return -aten.sub(aten.mul(a, alpha), b)

@register_decomposition(aten._reshape_alias, aot_autograd_decompositions)
def _reshape_alias(x, shape, strides):
    """Decompose ``aten._reshape_alias`` into a plain ``aten.view``.

    ``strides`` is accepted to match the op's signature but is not used; the
    result has whatever strides ``aten.view`` produces for ``shape``.
    """
    viewed = aten.view(x, shape)
    return viewed

def create_aot_autograd_function(
    flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
):
    """
    Traces the forward and backward graphs of :attr:`flat_fn` to generate a
    joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
    the tracing mechanism to understand the graph capturing details.

    The joint graph is then passed through :attr:`partition_fn` to isolate the
    forward and backward portions, which are then respectively compiled via the
    provided :attr:`fw_compiler` and :attr:`bw_compiler`.

    The resulting compiled forward and backward graphs are then wrapped up in a
    ``torch.autograd.Function`` object.

    Tracing and compilation happen on the first call to ``forward``; later
    calls reuse the compiled graphs.

    Args:
        flat_fn: callable taking flat positional args. If it returns a list or
            tuple, each element counts as one output; any other value counts
            as a single output.
        fw_compiler: called as ``fw_compiler(fw_module, example_args)``; must
            return a callable that runs the forward module.
        bw_compiler: called as ``bw_compiler(bw_module, example_args)``; must
            return a callable that runs the backward module.
        partition_fn: called as ``partition_fn(joint_graph, joint_inputs)``;
            must return ``(fw_module, bw_module)``.
        decompositions: extra decompositions merged over (and overriding)
            ``aot_autograd_decompositions``.
        grad_state: value passed to ``torch.set_grad_enabled`` while running
            ``flat_fn`` eagerly and while tracing.

    Returns:
        The ``CompiledFunction`` class (a ``torch.autograd.Function``
        subclass); call its ``apply`` to run it.
    """
    joint_forward_backward = create_joint_forward_backward(flat_fn)

    compiled_fw = None
    compiled_bw = None
    num_outs = None

    class CompiledFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *flat_tensor_args):
            nonlocal compiled_fw, compiled_bw, num_outs
            if compiled_fw is None:
                # Run eagerly once to obtain example outputs. Detached copies
                # serve as the example tangents for tracing the joint graph.
                with torch.set_grad_enabled(grad_state):
                    out = flat_fn(*flat_tensor_args)
                out = pytree.tree_map(
                    lambda x: x.detach() if isinstance(x, Tensor) else x, out
                )

                if isinstance(out, (list, tuple)):
                    num_outs = len(out)
                else:
                    num_outs = 1

                joint_inputs = (flat_tensor_args, out)
                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                with torch.set_grad_enabled(grad_state):
                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                        *joint_inputs
                    )
                fw_module, bw_module = partition_fn(fx_g, joint_inputs)

                compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

                # The compiled forward returns the user outputs first, followed
                # by the values the backward needs. The backward takes the saved
                # values first, then the incoming grads; the forward outputs
                # stand in as example grads here.
                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
                compiled_bw = bw_compiler(bw_module, bw_args)
            else:
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
            # Keep the saved values for backward; only the user outputs are returned.
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        def backward(ctx, *flat_args):
            # TODO: incoming grads are passed through unchanged; whether they
            # should be made contiguous (t.contiguous()) is still open.
            contiguous_args = list(flat_args)
            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
            return tuple(out)

    return CompiledFunction

class _CompileCache(CompileCache):
    """Python-side subclass of the C++ ``CompileCache``; adds no behavior.

    NOTE(review): ``aot_function`` instantiates ``CompileCache`` directly, not
    this subclass.
    """
# The optional C++-based ``tree`` package reduces the flattening overhead by
# about 50%, so it is preferred when it is installed.
try:
    import tree

    HAS_TREE = True
except ImportError:
    HAS_TREE = False

# Global compilation cache shared by all aot_function wrappers; it is created
# lazily on first use and reset to None by clear_compile_cache().
compile_cache = None

# Inspired by autodidax (thanks!)
class PytreeThunk:
    """Remember the output ``TreeSpec`` of a traced function and rebuild outputs from it.

    ``set`` records the spec (it must be identical on every call). ``unflatten``
    turns a flat sequence of outputs back into the original structure, with
    fast paths for two common shapes.
    """

    spec = None
    # These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
    is_simple = None  # if the output spec is a flat tuple of leaves, skip unflattening.
    is_really_simple = None  # if the output spec is a LeafSpec

    def set(self, spec):
        """Record ``spec`` and precompute the fast-path flags."""
        assert self.spec is None or self.spec == spec
        self.spec = spec
        # ``spec`` is a TreeSpec, so its container type is ``spec.type``.
        # Checking ``type(spec)`` against tuple/list could never succeed.
        # Lists are excluded so list outputs are still rebuilt as lists.
        if spec.type is tuple and all(
            isinstance(child, pytree.LeafSpec) for child in spec.children_specs
        ):
            self.is_simple = True
        if isinstance(spec, pytree.LeafSpec):
            self.is_really_simple = True

    def unflatten(self, x):
        """Rebuild the recorded structure from the flat sequence ``x``."""
        if self.is_really_simple:
            return x[0]
        if self.is_simple and type(x) is tuple:
            # A flat tuple of leaves already is the unflattened value.
            return x
        return pytree.tree_unflatten(x, self.spec)

def filter_tensor_and_static_args(args, static_argnums):
    """
    Separate out the tensor and static args. Also, for the static args, store
    the hash.

    Args:
        args: the positional arguments of the call.
        static_argnums: positions in ``args`` that are static.

    Returns:
        ``(tensor_args, static_args, static_args_hashed)``. The first two lists
        keep the original relative order. ``static_args_hashed[i]`` is
        ``hash(static_args[i])``.

    Raises:
        TypeError: if a static argument is unhashable.
    """
    static_positions = set(static_argnums)
    tensor_args = []
    static_args = []
    static_args_hashed = []
    for idx, arg in enumerate(args):
        if idx in static_positions:
            static_args.append(arg)
            static_args_hashed.append(hash(arg))
        else:
            tensor_args.append(arg)
    return tensor_args, static_args, static_args_hashed

def rearrange(tensor_args, static_args, static_argnums):
    """
    Generate the args as per the original spec. static_argnums is sorted.

    Inverse of ``filter_tensor_and_static_args``. It interleaves
    ``tensor_args`` and ``static_args`` so that ``static_args[i]`` lands at
    position ``static_argnums[i]``. Once either list is exhausted, the rest of
    the other is appended in order.

    Raises:
        AssertionError: if ``static_args`` and ``static_argnums`` differ in length.
    """
    assert len(static_args) == len(static_argnums)
    args = []
    tensor_index = 0
    static_index = 0
    index = 0
    while tensor_index < len(tensor_args) and static_index < len(static_args):
        if index == static_argnums[static_index]:
            args.append(static_args[static_index])
            static_index += 1
        else:
            args.append(tensor_args[tensor_index])
            tensor_index += 1
        index += 1

    # At most one of these slices is non-empty.
    args.extend(tensor_args[tensor_index:])
    args.extend(static_args[static_index:])
    return args

# Output leaf types accepted from the user function without a registered pytree.
KNOWN_TYPES = [torch.Tensor, int, str, float, bool]


def aot_function(
    fn: Callable,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    hasher_type: str = "StaticShapeHasher",
    static_argnums: Optional[Tuple[int]] = None,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
    mechanism, and then compiles the generated forward and backward graphs
    through :attr:`fw_compiler` and :attr:`bw_compiler`.

    :func:`aot_function` traces the forward and backward graph ahead of time,
    and generates a joint forward and backward graph.  :attr:`partition_fn` is
    then used to separate out forward and backward graphs. The partitioner
    function can be used to perform optimizations such as recomputation. One
    can set `decompositions` dictionary to decompose the operators into a
    sequence of core or simpler operators supported by the backend compilers.

    :func:`aot_function` uses a compilation cache, based on input tensor
    properties, to detect when there is a need of recompilation. By default, its
    behavior is static, i.e., it recompiles if shape of any input tensor
    changes.

    :attr:`static_argnums` allows user to mark the arguments of the original
    :attr:`fn` as static. This is useful when an argument is a non-tensor, e.g.,
    ``int`` or ``bool``. A change in the actual value of static arg causes
    recompilation.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        fw_compiler (Callable): A Python function that accepts an Fx graph with
            Aten ops and input args, and returns a Callable that semantically is
            equivalent to the input Fx graph.
        bw_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph.  Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
        partition_fn (Callable): A Python function that takes a joint forward
            and backward graph, and partitions it into separate forward and
            backward graphs.
        decompositions (Optional[Dict]): A dictionary to define the
            decomposition of larger Aten ops into simpler or core Aten ops.
            Default: None (no extra decompositions).
        hasher_type (str): Name of the hashing strategy; it is passed to the
            compilation cache as part of every lookup and insert.
            Default: ``"StaticShapeHasher"``.
        static_argnums (Optional[Tuple[int]]): An optional tuple of ints to mark
            the arguments of the function as static. A single ``int`` is also
            accepted.

    Returns:
        Returns a ``Callable`` that retains the eager behavior of the original
        :attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compiler` and :attr:`bw_compiler`.

    A simple example usage of :func:`aot_function` is as follows. This example
    will print the forward and backward graphs of the function ``fn``

        >>> fn = lambda x : x.sin().cos()
        >>> def print_compile_fn(fx_module, args):
        ...     print(fx_module)
        ...     return fx_module
        >>> aot_fn = aot_function(fn, print_compile_fn)
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)

    The static argnums are used to mark the non-tensor arguments as static. An
    example is as follows where the dropout probability is as argument to the
    original function.

        >>> def fn(input, bias, residual, p: float):
        ...     a = torch.add(input, bias)
        ...     b = torch.nn.functional.dropout(a, p, training=True)
        ...     c = b + residual
        ...     return c
        >>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,))

    """
    global compile_cache
    if compile_cache is None:
        compile_cache = CompileCache()
    if bw_compiler is None:
        bw_compiler = fw_compiler
    if decompositions is None:
        decompositions = {}
    fn_id = id(fn)
    fw_compiler_id = id(fw_compiler)
    bw_compiler_id = id(bw_compiler)

    # Normalize static_argnums to None or a sorted list (rearrange relies on order).
    if isinstance(static_argnums, int):
        static_argnums = [static_argnums]
    elif static_argnums is not None and len(static_argnums) == 0:
        static_argnums = None
    elif static_argnums is not None:
        static_argnums = list(static_argnums)
        static_argnums.sort()

    def returned_function(*args, **kwargs):
        global compile_cache
        # clear_compile_cache() resets the global to None; recreate it so that
        # wrappers created before the reset keep working.
        if compile_cache is None:
            compile_cache = CompileCache()

        # Separate out static args if static_argnums is present
        tensor_args = args
        static_args = []
        # TODO - move the hashing part of static_args to C++.
        static_args_hashed = []
        if static_argnums is not None:
            (
                tensor_args,
                static_args,
                static_args_hashed,
            ) = filter_tensor_and_static_args(args, static_argnums)

        # Now flatten the tensor args
        if HAS_TREE:
            flat_tensor_args = tree.flatten((tensor_args, kwargs))
        else:
            flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))

        # Check if the fn is already compiled
        num_tensor_args = len(flat_tensor_args)
        flat_args_for_cache = flat_tensor_args + static_args_hashed
        cached_res = compile_cache.at(
            fn_id,
            fw_compiler_id,
            bw_compiler_id,
            num_tensor_args,
            hasher_type,
            *flat_args_for_cache,
        )

        # Compile the function and save it in the cache
        if cached_res is None:
            # Save the args_spec for flat_tensor_args to unflatten while tracing
            _, tensor_args_spec = pytree.tree_flatten((tensor_args, kwargs))
            out_spec = PytreeThunk()

            def flat_fn(*flat_tensor_args):
                # The input are flattened tensor args. Prepare the args in the
                # order that original function expects. Add static args as well.
                # They will appear as tensor constants in the traced graph.
                tensor_args, kwargs = pytree.tree_unflatten(
                    flat_tensor_args, tensor_args_spec
                )
                if static_argnums is None:
                    args = tensor_args
                else:
                    args = rearrange(tensor_args, static_args, static_argnums)
                tree_out = fn(*args, **kwargs)
                flat_out, spec = pytree.tree_flatten(tree_out)
                known_types = tuple(KNOWN_TYPES)
                for i in flat_out:
                    if not isinstance(i, known_types):
                        raise RuntimeError(
                            f"Found {type(i)} in output, which is not a known type. "
                            "If this type holds tensors, you need to register a pytree for it. "
                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
                            "explanation why. If you don't need to register a pytree, please "
                            "leave a comment explaining your use case and we'll make this more "
                            "ergonomic to deal with"
                        )
                out_spec.set(spec)
                return flat_out

            compiled_fn = create_aot_autograd_function(
                flat_fn,
                fw_compiler,
                bw_compiler,
                partition_fn,
                decompositions,
                grad_state=torch.is_grad_enabled(),
            ).apply
            cached_res = (compiled_fn, out_spec)

            # Save the compiled_fn in the cache
            compile_cache.insert(
                fn_id,
                fw_compiler_id,
                bw_compiler_id,
                num_tensor_args,
                hasher_type,
                cached_res,
                *flat_args_for_cache,
            )

        cached_fn, out_spec = cached_res
        out = cached_fn(*flat_tensor_args)
        return out_spec.unflatten(out)

    return returned_function
def num_of_recompilations():
    """
    Returns the number of recompilations since the last time the cache was
    cleared. This is equivalent to the number of entries in the compilation
    cache (0 if no cache has been created yet).
    """
    if compile_cache is None:
        return 0
    return compile_cache.size()


def clear_compile_cache():
    """
    Clears the compilation cache and drops the global reference to it.
    """
    global compile_cache
    if compile_cache is not None:
        compile_cache.clear()
        compile_cache = None
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Traces the forward and backward graph of :attr:`mod` using torch dispatch
    tracing mechanism. It is a wrapper function that uses
    :func:`aot_function` underneath to perform tracing and compilation.

    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
    to a new callable which is then compiled through :func:`aot_function`.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (nn.Module): A ``nn.Module`` module.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with forward and backward graph compiled.

    """

    def functional_call(named_params, named_buffers, *args, **kwargs):
        params_and_buffers = {**named_params, **named_buffers}
        return _stateless.functional_call(mod, params_and_buffers, args, kwargs)

    compiled_f = aot_function(functional_call, *args, **kwargs)

    class AOTModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.orig_module = mod

        def forward(self, *args, **kwargs):
            # Parameters/buffers are read from ``mod`` on every call and passed
            # as explicit (flattened) inputs to the compiled function.
            return compiled_f(
                dict(_named_parameters(mod, remove_duplicate=False)),
                dict(_named_buffers(mod, remove_duplicate=False)),
                *args,
                **kwargs,
            )

    return AOTModule()
# Alternative public names for the two entry points.
compiled_function = aot_function
compiled_module = aot_module