import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
from torch import SymInt
import torch.fx as fx
import torch.nn as nn
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)
import torch.utils._pytree as pytree


log = logging.getLogger(__name__)


# These canonicalizations are needed here (and not decompositions), as the ops
# we're trying to canonicalize to CompositeImplicitAutograd.
def _canonicalize(fx_g):
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten._to_copy:
            node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g


@contextmanager
def _disable_jit_autocast():
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g (fx.GraphModule): The input Fx graph module to be compiled.

    Returns:
        Torch scripted model.
    """

    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.nodes:
            if (
                node.target == torch.ops.aten._to_copy
                and len(node.args) == 1
                and len(node.kwargs) == 1
                and "dtype" in node.kwargs
            ):
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()

        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f
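
# A minimal usage sketch (illustrative only, not part of the library surface):
# `ts_compile` is normally handed to AOTAutograd as a forward/backward compiler
# rather than called directly. The function `fn` below is a made-up example.
#
#   from functorch.compile import aot_function, ts_compile
#
#   def fn(x):
#       return torch.sin(x).cos()
#
#   compiled_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
#   out = compiled_fn(torch.randn(8, requires_grad=True))
#   out.sum().backward()
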
@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
    and can be used to check accuracy.

    .. warning::
        This API is experimental and likely to change.

    """
    return fx_g
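
# A minimal usage sketch (illustrative only): because `nop` returns the traced
# graph unmodified, comparing its output against eager execution isolates whether
# a numerical mismatch comes from the backend compiler or from the tracing itself.
# `fn` and `x` below are made-up examples.
#
#   eager_out = fn(x)
#   traced_out = aot_function(fn, fw_compiler=nop, bw_compiler=nop)(x)
#   torch.testing.assert_close(eager_out, traced_out)
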
class DebugInterpreter(fx.Interpreter):
    def run(self, *args):
        self.symbol_mapping = bind_symbols(self.module, *args)
        super().run(*args)

    def run_node(self, n):
        import sympy

        def subst_symint(ni):
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert len(r.free_symbols) == 0, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if subst_symint(a.stride(idx)) != b.stride(idx) and subst_symint(a.size(idx)) > 1:
                        return False
            return True

        def check(nv, rv, desc):
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert subst_symint_tuple(nv.size()) == rv.size(), \
                f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            same_strides = check_significant_strides(nv, rv)
            assert same_strides, f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"

        r = super().run_node(n)
        if 'val' in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta['val'])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here, the error itself is
            # harmless enough as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r


@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns a (slow) interpreter over the FX graph module that also checks
    various debugging properties (e.g., that tracing strides matched real
    strides).
    """
    return DebugInterpreter(fx_g).run


@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f, static_argnums=None):
    return aot_function(f, simple_ts_compile, static_argnums=static_argnums)


aten = torch.ops.aten
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}

default_decompositions = get_decompositions(default_decompositions)


@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g
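
# A minimal usage sketch (illustrative only): `print_compile` is a drop-in
# compiler that dumps the generated Python for each traced graph, which is a
# quick way to inspect what AOTAutograd produced. `fn` is a made-up example.
#
#   compiled = aot_function(fn, fw_compiler=print_compile, bw_compiler=print_compile)
#   compiled(torch.randn(4, requires_grad=True)).sum().backward()
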
def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    static_argnums: Optional[Tuple[int]] = None,
    **kwargs,
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. It uses nvFuser to compile the generated forward and backward
    graphs.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        static_argnums (Optional[Tuple[Int]]): An optional tuple of ints to mark
            the arguments of the function as static.
        **kwargs: Any other overrides you want to make to the settings

    Returns:
        Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and the graphs have been
        compiled with nvFuser.

    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
        "static_argnums": static_argnums,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)
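
# A minimal usage sketch (illustrative only): wrapping an nn.Module keeps its
# eager call signature while routing the forward and backward passes through the
# rematerialized, compiled graphs. The model and input shapes below are made up.
#
#   model = nn.Sequential(nn.Linear(16, 16), nn.GELU()).cuda()
#   fused = memory_efficient_fusion(model)
#   loss = fused(torch.randn(8, 16, device="cuda")).sum()
#   loss.backward()
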
def debug_compile(fx_g, inps):
    fx_g.to_folder("foo")
    print(f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it  #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
    # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
""")
    from foo import FxModule
    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)


graph_index = 0


def get_inputs(input_data_path):
    """
    Return a random input for the given inputs meta generated from _save_fx_default.
    """
    inputs = []
    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)
        for meta in inputs_meta:
            if len(meta) == 1:
                type = meta[0]  # meta is a 1-tuple of the form (type,)
                input = type(random.random())
            else:
                type, shape, stride, dtype, device = meta
                if dtype in {
                    torch.int,
                    torch.int32,
                    torch.int64,
                    torch.bool,
                    torch.uint8,
                    int,
                    float,
                }:
                    input = torch.randint(0, 1, shape, dtype=dtype, device=device)
                else:
                    input = torch.rand(shape, dtype=dtype, device=device)
            inputs.append(input)
    return inputs


def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """
    The forward, backward, and joint computation graph will be stored in
    {folder_name}/{current_name}/{current_name}_forward_{graph_index},
    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
    The input shapes of the graphs will be stored in the .input files.
    These files can be loaded with pickle; each is a list of entries of the form
    (type, shape, stride, dtype, device). In the case of type = int or float,
    it is just (type,). For joint graph inputs, it is a nested list [[], []]
    where the two inner lists have the same format.
    If dump_example_input is True, example_inputs will be stored in a .pt file.
    Since each function might produce multiple graphs, the graph_index is used
    to distinguish different graphs.
    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta

    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.log(
                logging.WARNING,
                f"No nodes in graph {current_name}_{type_name}_{graph_index}.",
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        if not os.path.exists(f"{folder_name}/{current_name}"):
            os.makedirs(f"{folder_name}/{current_name}")
        gm.to_folder(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}")
        pickle.dump(
            input_meta,
            open(
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
                "wb",
            ),
        )  # noqa: E501
        if dump_example_input:
            torch.save(
                args,
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
            )  # noqa: E501

    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args):
        graph_saver_helper(gm, joint_args, "joint")
        return default_partition(gm, joint_args)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )


# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """
    Dump the forward, backward, and joint computation graph.
    Example Usage:

    save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input=False)

    optimize_ctx = torchdynamo.optimize(
        save_fx_func
    )

    with torch.enable_grad():
        with optimize_ctx:
            result = forward_and_backward_pass(model, example_inputs)
    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)
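
# A minimal usage sketch (illustrative only): graphs dumped by `_save_fx_default`
# can be re-imported from the dump folder and replayed with random inputs from
# `get_inputs`. The folder and module names below follow the naming scheme
# documented above but are otherwise hypothetical.
#
#   from my_dump.my_model.my_model_forward_0 import FxModule
#   inps = get_inputs("my_dump/my_model/my_model_forward_0/my_model_forward_0.input")
#   FxModule()(*inps)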