import torch
import torch.fx as fx
import torch.nn as nn
from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union
from .aot_autograd import aot_function, aot_module
from .decompositions import decomposition_table
from .partitioners import draw_graph, min_cut_rematerialization_partition
import time

def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
    """
    Compiles the :attr:`fx_g` with Torchscript compiler.

    .. warning::
        This API is experimental and likely to change.

    Before scripting, the graph of :attr:`fx_g` is patched in place:

    * an empty size list passed to ``aten.new_zeros`` / ``aten.new_empty`` is
      replaced with ``[1]``;
    * an empty stride list passed to ``aten.avg_pool2d_backward`` is replaced
      with ``[1, 1]``;
    * ``torch.device`` keyword arguments are replaced by their device type
      string (``device.type``, which drops any device index);
    * tensor constants ``_tensor_constant0``, ``_tensor_constant1``, ... are
      moved to CUDA.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled.
        _: Unused (the example inputs in the AOT compiler calling convention).

    Returns:
        Torch scripted model, frozen and optimized for inference.
    """
    for node in fx_g.graph.nodes:
        if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
            # args[1] is the size list; presumably an empty list literal does
            # not script as expected -- NOTE(review): confirm the exact reason.
            if node.args[1] == []:
                args = list(node.args)
                args[1] = [1]
                node.args = tuple(args)
        elif node.target == torch.ops.aten.avg_pool2d_backward:
            # Handle empty strides
            if node.args[3] == []:
                args = list(node.args)
                args[3] = [1, 1]
                node.args = tuple(args)

    for node in fx_g.graph.nodes:
        new_kwargs = {}
        for k, v in node.kwargs.items():
            if isinstance(v, torch.device):
                v = v.type
            new_kwargs[k] = v
        node.kwargs = new_kwargs

    fx_g.graph.lint()

    # Works around an NVFuser issue by moving tensor constants to CUDA.
    # Constants are numbered contiguously, so stop at the first missing one.
    idx = 0
    while hasattr(fx_g, f"_tensor_constant{idx}"):
        attr = f"_tensor_constant{idx}"
        setattr(fx_g, attr, getattr(fx_g, attr).cuda())
        idx += 1

    fx_g.recompile()

    f = torch.jit.script(fx_g)

    torch._C._jit_pass_remove_mutation(f.graph)

    f = torch.jit.freeze(f.eval())
    f = torch.jit.optimize_for_inference(f)
    return f
def tensorexpr_compile(fx_module: fx.GraphModule, flat_args) -> Callable:
    """Compiles the given fx_module using TensorExpr Kernel.

    Graph outputs that are graph inputs, or that repeat an earlier output, are
    removed from the kernel's outputs and re-materialized by the returned
    wrapper instead.

    Args:
        fx_module: The FX graph module to compile. Its graph is modified in place.
        flat_args: Example arguments used for tracing and shape annotation. All
            tensor arguments must be on a single device.

    Returns:
        A callable taking the graph's positional arguments and returning a list
        with one entry per original graph output.
    """
    inp_devices = {i.device for i in flat_args if isinstance(i, torch.Tensor)}
    assert len(inp_devices) == 1
    inp_device = next(iter(inp_devices))
    inputs = []
    # One (from_kernel, idx) pair per original output:
    # from_kernel=True  -> read the kernel's outputs[idx]
    # from_kernel=False -> read the caller's args[idx]
    output_refs = []
    for node in fx_module.graph.nodes:
        if node.op == "placeholder":
            inputs.append(node)
        elif node.op == "output":
            outputs = node.args[0]
            if not isinstance(outputs, Iterable):
                outputs = (outputs,)
            new_outputs = []
            for idx, output in enumerate(outputs):
                if output in inputs:
                    output_refs.append((False, inputs.index(output)))
                elif output in outputs[:idx]:
                    # Duplicate output: reuse the ref of its first occurrence.
                    output_refs.append((True, output_refs[outputs.index(output)][1]))
                else:
                    output_refs.append((True, len(new_outputs)))
                    new_outputs.append(output)
            node.args = (tuple(new_outputs),)
    fx_module.graph.lint()
    fx_module.recompile()

    # Move tensor constants to the input device. Constants are numbered
    # contiguously, so stop at the first missing one.
    const_idx = 0
    while hasattr(fx_module, f"_tensor_constant{const_idx}"):
        attr = f"_tensor_constant{const_idx}"
        setattr(fx_module, attr, getattr(fx_module, attr).to(inp_device))
        const_idx += 1

    jit_module = torch.jit.trace(fx_module, flat_args)
    jit_module = torch.jit.freeze(jit_module.eval())
    torch._C._jit_trace_module(jit_module._c, tuple(flat_args))
    torch._C._te.remove_unused_self_argument(jit_module.graph)
    torch._C._te.annotate_input_shapes(jit_module.graph, tuple(flat_args))
    torch._C._jit_pass_lower_all_tuples(jit_module.graph)
    te_kernel = torch._C._te.TensorExprKernel(jit_module.graph)

    def f(*args):
        outs = te_kernel.run(args)
        if not isinstance(outs, (tuple, list)):
            outs = (outs,)
        return [
            outs[idx] if from_kernel else args[idx] for from_kernel, idx in output_refs
        ]

    return f


def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    """Print ``fx_g``'s code, draw it via ``draw_graph`` under ``name``, return it unchanged."""
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    """Return a compiler callable that draws each graph under ``name``."""
    return partial(_draw_graph_compile, name=name)


def _tvm_compile(
    fx_module, example_inputs, target=None, tuning_logfile=None, use_ansor_tuning=False
):
    """Compile ``fx_module`` with TVM Relay and return a callable that runs it.

    Args:
        fx_module: The FX graph module to compile.
        example_inputs: Example tensors; their shapes become the Relay input
            shapes, named ``inp_{idx}``.
        target: A TVM target string or ``tvm.target.Target``. Required.
        tuning_logfile: Base name (without ``.json``) of an auto-scheduler log
            to write (when tuning) and/or apply.
        use_ansor_tuning: If True, run auto-scheduler tuning before building.

    Returns:
        A callable taking the input tensors and returning a list of outputs.

    Raises:
        ValueError: If ``target`` is None.
        FileNotFoundError: If a tuning log is requested but does not exist.
    """
    import tvm
    from tvm import relay, auto_scheduler
    from tvm.contrib import graph_executor
    import os

    # Find the target and device for TVM.
    dev = tvm.cpu(0)
    if target is None:
        raise ValueError("Setup the TVM target correctly.")
    elif isinstance(target, str):
        if "cuda" in target:
            dev = tvm.cuda(0)
        target = tvm.target.Target(target)
    elif isinstance(target, tvm.target.Target):
        if "cuda" in target.keys:
            dev = tvm.cuda(0)

    # JIT the model and pass it to Torchscript to Relay frontend parser. TVM
    # tutorials suggest tracing instead of scripting. The main reason is to
    # avoid Pythonic computation to show up in JIT module. However, with Python
    # key tracing, AOT Autograd leads to simpler graphs. Therefore, we use
    # scripting here to retrieve the JIT module.
    jit_mod = torch.jit.script(fx_module)

    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)

    # TVM Autotuning
    if use_ansor_tuning:
        tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
        if tuning_logfile is None:
            log_file = f"{time.time()}.json"
        else:
            log_file = f"{tuning_logfile}.json"
        if len(tasks) != 0:
            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=20000,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            tuner.tune(tune_option)
    elif tuning_logfile is not None:
        log_file = f"{tuning_logfile}.json"

    if use_ansor_tuning or tuning_logfile is not None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(log_file):
            raise FileNotFoundError(f"TVM tuning log {log_file!r} does not exist.")
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(
                opt_level=3, config={"relay.backend.use_auto_scheduler": True}
            ):
                lib = relay.build(mod, target=target, params=params)
    else:
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target=target, params=params)

    # Get a graph executor graph module
    m = graph_executor.GraphModule(lib["default"](dev))

    def exec_tvm(*args):
        for idx, arg in enumerate(args, 0):
            # 0-dim inputs are skipped (not set on the executor).
            if arg.dim() != 0:
                m.set_input(
                    f"inp_{idx}",
                    tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(arg.contiguous())),
                )
        # The graph must be executed before its outputs are read.
        m.run()
        outs = [
            torch.utils.dlpack.from_dlpack(m.get_output(i).to_dlpack())
            for i in range(m.get_num_outputs())
        ]
        return outs

    return exec_tvm


def tvm_compile(target, tuning_logfile=None, use_ansor_tuning=False):
    """Return a compiler callable that compiles graphs with TVM for ``target``."""
    return partial(
        _tvm_compile,
        target=target,
        tuning_logfile=tuning_logfile,
        use_ansor_tuning=use_ansor_tuning,
    )
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    No-op compiler: hands back :attr:`fx_g` untouched.

    Useful for checking accuracy, since the graph runs exactly as traced.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The Fx graph module to pass through.
        _: Ignored.

    Returns:
        The very same :attr:`fx_g` object.
    """
    unchanged_module = fx_g
    return unchanged_module
def simple_ts_compile(fx_g, _):
    """Script ``fx_g`` with TorchScript and freeze it in eval mode.

    Unlike ``ts_compile``, no graph patching or inference optimization is done.
    """
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f, static_argnums=None):
    """Wrap ``f`` with :func:`aot_function`, compiling graphs via :func:`simple_ts_compile`."""
    return aot_function(f, simple_ts_compile, static_argnums=static_argnums)


aten = torch.ops.aten

# Ops whose entries in ``decomposition_table`` form ``default_decompositions``.
_default_decomposition_ops = {
    aten.detach,
    aten.gelu_backward,
    aten._log_softmax_backward_data,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
}
# Op -> decomposition mapping; ops absent from ``decomposition_table`` are
# simply not included.
default_decompositions = {
    k: v for k, v in decomposition_table.items() if k in _default_decomposition_ops
}


def print_compile(fx_g, _):
    """Print the generated code of ``fx_g`` and return it unchanged."""
    print(fx_g.code)
    return fx_g
def memory_efficient_fusion(
    fn: Union[Callable, nn.Module], static_argnums: Optional[Tuple[int]] = None
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. It uses NVFuser to compile the generated forward and backward
    graphs.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        static_argnums (Optional[Tuple[int]]): An optional tuple of ints to mark
            the arguments of the function as static.

    Returns:
        Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and the graphs have been
        compiled with nvfuser.
    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "hasher_type": "StaticShapeHasher",
        "decompositions": default_decompositions,
        "static_argnums": static_argnums,
    }
    # Modules go through aot_module; plain callables through aot_function.
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)
def debug_compile(fx_g, inps):
    """Dump ``fx_g`` for minification, sanity-run it on CUDA, then ``ts_compile`` it.

    Writes the graph to the ``./foo`` folder via ``to_folder`` and prints a
    script that reloads it and runs the minifier. It then imports the dumped
    ``FxModule`` and runs it once on CUDA with ``inps``.

    Args:
        fx_g: The FX graph module to debug.
        inps: Example input tensors (only their shapes/dtypes are printed).

    Returns:
        The result of ``ts_compile(fx_g, inps)``.
    """
    import importlib
    import sys

    fx_g.to_folder("foo")
    print(
        f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
    )

    # `to_folder` overwrites ./foo on every call, but Python caches imported
    # modules in sys.modules. Without evicting them, a second call in the same
    # process would sanity-run the previously dumped graph instead of this one.
    for mod_name in [m for m in sys.modules if m == "foo" or m.startswith("foo.")]:
        del sys.modules[mod_name]
    importlib.invalidate_caches()
    from foo import FxModule

    FxModule().cuda()(*inps)
    return ts_compile(fx_g, inps)