import torch
import torch.nn as nn
import torch.overrides
from torch.nn.modules.module import _addindent
from torch.package import PackageImporter, PackageExporter
import linecache
from typing import Type, Dict, List, Any, Union, Optional, Set
from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
from ._compatibility import compatibility
from torch.package import Importer, sys_importer
import copy
import itertools
import sys
import traceback
from pathlib import Path
import os
import warnings

__all__ = ["reduce_graph_module", "reduce_package_graph_module", "reduce_deploy_graph_module", "GraphModule"]

_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"


# Normal exec loses the source code, however we can work with
# the linecache module to recover it.
# Using _exec_with_source will add it to our local cache
# and then tools like TorchScript will be able to get source info.
class _EvalCacheLoader:
    """Keeps exec'd source strings and serves them to ``linecache`` by key."""

    def __init__(self):
        # Maps generated key (used as a dummy filename) -> source string.
        self.eval_cache = {}
        # Counter used by ``_get_key`` to make each key unique.
        self.next_id = 0

    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
        """Store the source in a private cache, and add a lazy entry in linecache
        that allows the source to be retrieved by 'filename'.

        Args:
            src (str): The module source to cache
            globals (dict): The module globals
            co_fields (dict, optional): When truthy, its ``co_filename``,
                ``co_firstlineno`` and ``co_name`` entries are appended to the
                generated key.

        Returns:
            str: The cache key (and dummy filename) generated for src.
        """
        key = self._get_key()
        if co_fields:
            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
        self.eval_cache[key] = src

        # Don't mutate globals so that this loader is only used
        # to populate linecache, and doesn't interact with other modules
        # that might check `__loader__`
        globals_copy = globals.copy()
        globals_copy['__file__'] = key
        globals_copy['__name__'] = key
        globals_copy['__loader__'] = self
        linecache.lazycache(key, globals_copy)

        return key

    # Part of the loader protocol (PEP 302)
    # linecache will use this method when trying to find source code
    def get_source(self, module_name) -> Optional[str]:
        """Return the cached source for ``module_name``, or ``None`` if unknown."""
        return self.eval_cache.get(module_name)

    def _get_key(self):
        """Return a fresh ``<eval_with_key>.N`` key and advance the counter."""
        key = f'<eval_with_key>.{self.next_id}'
        self.next_id += 1
        return key


_loader = _EvalCacheLoader()


def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
    """Register ``src`` with the module-level loader, then exec it into ``globals``.

    The key returned by the loader is used as the compiled code's filename so
    that the source can later be looked up through ``linecache``.
    """
    key = _loader.cache(src, globals, co_fields)
    exec(compile(src, key, 'exec'), globals)


def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
    """Exec ``src`` and return the ``forward`` function it defines.

    Raises:
        KeyError: if executing ``src`` does not define a name ``forward``.
    """
    # avoid mutating the passed in dict
    globals_copy = globals.copy()
    _exec_with_source(src, globals_copy, co_fields)
    forward_fn = globals_copy['forward']
    del globals_copy['forward']
    return forward_fn


def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
    """Return the import statement that binds ``name`` in generated code."""
    if name in _custom_builtins:
        return _custom_builtins[name].import_str
    if _is_from_torch(name):
        return 'import torch'
    module_name, attr_name = importer.get_name(obj)
    return f'from {module_name} import {attr_name} as {name}'


def _format_import_block(globals: Dict[str, Any], importer: Importer):
    """Return the newline-joined, de-duplicated import statements for ``globals``.

    The statements are sorted so the resulting block is stable from run to run;
    joining the bare set would emit them in arbitrary set-iteration order.
    """
    import_strs: Set[str] = {
        _format_import_statement(name, obj, importer)
        for name, obj in globals.items()
    }
    return '\n'.join(sorted(import_strs))


@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
    """Rebuild a GraphModule from its attribute dict ``body`` and ``import_block``.

    The forward function is recreated by executing ``import_block`` followed by
    the code stored in ``body``.
    """
    # BC: attribute name was changed from `code` to `_code` to facilitate
    # making `code` into a property and adding a docstring to it
    fn_src = body.get('_code') or body['code']
    forward = _forward_from_src(import_block + fn_src, {})
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
    """Rebuild a GraphModule whose forward lives in a module loaded via ``importer``."""
    forward = importer.import_module(generated_module_name).forward
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
    """Rebuild a GraphModule, executing its code with ``importer.patched_builtins``."""
    ns = {}
    ns["__builtins__"] = importer.patched_builtins
    fn_src = body.get('_code')
    assert fn_src is not None
    forward = _forward_from_src(import_block + fn_src, ns)
    return _deserialize_graph_module(forward, body)


# We create a dummy class here because symbolic_trace pulls the forward()
# function off of the class, rather than the instance. This class is used
# in _deserialize_graph_module() below.
class _CodeOnlyModule(torch.nn.Module):
    def __init__(self, body):
        super().__init__()
        # Adopt the deserialized attribute dict wholesale as instance state.
        self.__dict__ = body


def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
    """
    Deserialize a GraphModule given the dictionary of the original module,
    using the code to reconstruct the graph. We delete the actual graph before
    saving the dictionary so that changes to the in-memory graph format do not
    get serialized.

    Args:
        forward: The forward function to trace in order to rebuild the graph.
        body: Attribute dictionary of the original module (without its graph).

    Returns:
        The reconstructed ``GraphModule``.
    """
    # Try to retrieve the forward source in a backward-compatible way
    # NOTE: this assigns on the shared _CodeOnlyModule class, not an instance.
    _CodeOnlyModule.forward = forward

    tracer_cls = body.get('_tracer_cls')
    if tracer_cls is None:
        from ._symbolic_trace import Tracer
        tracer_cls = Tracer

    graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')

    # This is a workaround for a mypy linter issue related to
    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
    cls_tracer: Any = tracer_cls

    class KeepModules(cls_tracer):
        # we shouldn't trace into any of the submodules,
        # because they were not traced in the original GraphModule
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    com = _CodeOnlyModule(body)

    tracer_extras = body.get('_tracer_extras', {})
    graph = KeepModules().trace(com, **tracer_extras)

    # Manually set Tracer class on the reconstructed Graph, to avoid
    # referencing the private local subclass KeepModules.
    graph._tracer_cls = tracer_cls
    gm = GraphModule(com, graph, class_name=graphmodule_cls_name)

    # The GraphModule constructor only retains attributes referenced by the graph.
    # In this case, our goal is return a GraphModule as close to identical as the one
    # put into the package. If any additional attributes were present in body,
    # we should keep them.
    for k, v in body.items():
        if not hasattr(gm, k):
            setattr(gm, k, v)
    return gm


# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
    *prefix, field = target.split('.')
    for item in prefix:
        f = getattr(from_module, item)
        t = getattr(to_module, item, None)
        if f is t:
            # we have already installed one of its parents
            # (e.g. target = root.linear.weight, but we have already installed root.linear)
            # once we install a parent, we no longer need to copy the children
            # since all the needed properties will already be present
            return

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        from_module, to_module = f, t

    orig = getattr(from_module, field)
    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
        to_module.register_buffer(field, orig)
    else:
        setattr(to_module, field, orig)


# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
# This installs empty Modules where none exist yet if they are subpaths of target
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
    *prefix, field = target.split('.')
    for item in prefix:
        t = getattr(to_module, item, None)

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        to_module = t

    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter):
        to_module.register_buffer(field, from_obj)
    else:
        setattr(to_module, field, from_obj)


class _WrappedCall:
    """Call wrapper that annotates errors raised inside generated forward code."""

    def __init__(self, cls, cls_call):
        # Class whose instances are being called; used for the super() fallback.
        self.cls = cls
        # Explicit ``__call__`` to delegate to, or None to use super().__call__.
        self.cls_call = cls_call

    # Previously, if an error occurred when valid
    # symbolically-traced code was run with an invalid input, the
    # user would see the source of the error as coming from
    # `File "<eval_with_key_N">`, where N is some number. We use
    # this function to generate a more informative error message. We
    # return the traceback itself, a message explaining that the
    # error occurred in a traced Module's generated forward
    # function, and up to four lines of context around the faulty
    # line (the line before it, the line itself, and the two after)
    @staticmethod
    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
        # auxiliary variables (for readability)
        err_lineno = frame_summary.lineno
        assert err_lineno is not None
        line = frame_summary.line
        assert line is not None
        err_line_len = len(line)
        all_src_lines = linecache.getlines(frame_summary.filename)

        # constituent substrings of the error message
        tb_repr = traceback.format_exc()
        custom_msg = ("Call using an FX-traced Module, "
                      f"line {err_lineno} of the traced Module's "
                      "generated forward function:")
        # Clamp the start index: for an error on line 1 an unclamped
        # ``err_lineno - 2`` is negative, which wraps around to the end of
        # the list and drops the faulty line from the message.
        first_ctx = max(err_lineno - 2, 0)
        before_err = "".join(all_src_lines[first_ctx:err_lineno])
        marker = "~" * err_line_len + "~~~ <--- HERE"
        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])

        # joined message
        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])

    def __call__(self, obj, *args, **kwargs):
        try:
            if self.cls_call is not None:
                return self.cls_call(obj, *args, **kwargs)
            else:
                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            # Innermost frame of the traceback, i.e. where the error was raised.
            topmost_framesummary: traceback.FrameSummary = \
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]  # type: ignore[arg-type]
            if "eval_with_key" in topmost_framesummary.filename:
                # The error came from generated code: print the annotated
                # message, then re-raise with the original traceback dropped.
                print(_WrappedCall._generate_error_message(topmost_framesummary),
                      file=sys.stderr)
                raise e.with_traceback(None)
            else:
                raise e
@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """
    def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
        # each instance of a graph module needs its own forward method
        # so create a new singleton class for each instance.
        # it is a subclass of the user-defined class, the only difference
        # is an extra layer to install the forward method

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        for t in cls.__mro__:
            c = t.__qualname__.split('.')[-1]
            if c != 'GraphModuleImpl':
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass
        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(self,
                 root: Union[torch.nn.Module, Dict[str, Any]],
                 graph: Graph,
                 class_name: str = 'GraphModule'):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``class_name`` denotes the name of this GraphModule for debugging purposes.
                If it's unset, all error messages will report as originating from ``GraphModule``.
                It may be helpful to set this to ``root``'s original name or a name that makes sense
                within the context of your transform.

        Raises:
            RuntimeError: if ``root`` is a dict that lacks a target referenced by
                the graph, or if ``root`` is neither a Module nor a dict.
        """
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, 'training'):
                self.training = root.training

            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
            if isinstance(root, _CodeOnlyModule):
                for k, _ in root.named_children():
                    _copy_attr(root, self, k)

                for k, _ in root.named_buffers():
                    _copy_attr(root, self, k)

                for k, _ in root.named_parameters():
                    _copy_attr(root, self, k)

            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
                                           ' but that target was not provided in ``root``!')
                    targets_to_copy.append(node.target)
            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count('.'))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')

        # Goes through the ``graph`` property setter, which also recompiles.
        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta: Dict[str, Any] = {}

    # TorchScript breaks trying to compile the graph setter because of the
    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
    #
    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
    __jit_unused_properties__ = ['graph']

    @property
    def graph(self) -> Graph:
        """
        Return the ``Graph`` underlying this ``GraphModule``
        """
        return self._graph

    @graph.setter
    def graph(self, g: Graph) -> None:
        """
        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
        recompile the ``GraphModule`` so that the generated ``forward()`` function
        corresponds to ``g``
        """
        assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}'
        self._graph = g
        g.owning_module = self
        self.recompile()

    @compatibility(is_backward_compatible=False)
    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Args:

            folder (Union[str, os.PathLike]): The folder to write the code out to

            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        folder.mkdir(exist_ok=True)
        torch.save(self.state_dict(), folder / 'state_dict.pt')
        tab = " " * 4
        custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()])
        model_str = f"""
import torch
{custom_builtins}

from torch.nn import *
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

        # Module types whose repr() is emitted directly as constructor source.
        # The check below is an exact type match, so subclasses are pickled.
        safe_reprs = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                      nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

        def _gen_model_repr(module: torch.nn.Module) -> Optional[str]:
            if type(module) in safe_reprs:
                return f"{module.__repr__()}"
            return None

        blobified_modules = []
        # The loop variables are deliberately not called ``module_name`` so the
        # ``module_name`` parameter of this method is not shadowed.
        for child_name, child in self.named_children():
            module_str = _gen_model_repr(child)
            if module_str is None:
                module_file = folder / f'{child_name}.pt'
                torch.save(child, module_file)
                blobified_modules.append(child_name)
                module_repr = child.__repr__().replace('\r', ' ').replace('\n', ' ')
                module_str = f"torch.load(r'{module_file}') # {module_repr}"
            model_str += f"{tab*2}self.{child_name} = {module_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

        model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        model_str += f"{_addindent(self.code, 4)}\n"

        module_file = folder / 'module.py'
        module_file.write_text(model_str)

        init_file = folder / '__init__.py'
        init_file.write_text('from .module import *')

        if blobified_modules:
            warnings.warn("Was not able to save the following children modules as reprs - "
                          f"saved as pickled files instead: {blobified_modules}")

    @compatibility(is_backward_compatible=True)
    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
        """
        Adds the given submodule to ``self``.

        This installs empty Modules where none exist yet if they are
        subpaths of ``target``.

        Args:
            target: The fully-qualified string name of the new submodule
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)
            m: The submodule itself; the actual object we want to
                install in the current Module

        Return:
            bool: Whether or not the submodule could be inserted. For
                this method to return True, each object in the chain
                denoted by ``target`` must either a) not exist yet,
                or b) reference an ``nn.Module`` (not a parameter or
                other attribute)
        """
        *prefix, field = target.split('.')
        mod: torch.nn.Module = self

        for item in prefix:

            submod = getattr(mod, item, None)

            if submod is None:
                submod = torch.nn.Module()
                setattr(mod, item, submod)

            if not isinstance(submod, torch.nn.Module):
                return False

            mod = submod

        mod.add_module(field, m)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_submodule(self, target: str) -> bool:
        """
        Deletes the given submodule from ``self``.

        The module will not be deleted if ``target`` is not a valid
        target.

        Args:
            target: The fully-qualified string name of the submodule to
                delete (See example in ``nn.Module.get_submodule`` for
                how to specify a fully-qualified string.)

        Returns:
            bool: Whether or not the target string referenced a
                submodule we want to delete. A return value of ``False``
                means that the ``target`` was not a valid reference to
                a submodule.
        """
        atoms = target.split(".")
        path, target_submod = atoms[:-1], atoms[-1]
        mod: torch.nn.Module = self

        # Get the parent module
        for item in path:

            if not hasattr(mod, item):
                return False

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                return False

        if not hasattr(mod, target_submod):
            return False

        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
            return False

        delattr(mod, target_submod)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_all_unused_submodules(self) -> None:
        """
        Deletes all unused submodules from ``self``.

        A Module is considered "used" if any one of the following is
        true:
        1. It has children that are used
        2. Its forward is called directly via a ``call_module`` node
        3. It has a non-Module attribute that is used from a
        ``get_attr`` node

        This method can be called to clean up an ``nn.Module`` without
        manually calling ``delete_submodule`` on each unused submodule.
        """
        # A set, because it is only ever queried for membership below.
        used: Set[str] = set()

        # If we're looking at multiple parts of a path, join
        # them with a dot. Otherwise, return that single
        # element without doing anything to it.
        def join_fn(x: str, y: str) -> str:
            return '.'.join([x, y] if y else [x])

        for node in self.graph.nodes:

            if node.op == "call_module" or node.op == "get_attr":

                # A list of strings representing the different parts
                # of the path. For example, `foo.bar.baz` gives us
                # ["foo", "bar", "baz"]
                fullpath = node.target.split(".")

                # Progressively collect all the names of intermediate
                # modules. For example, if we have the target
                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
                # `foo.bar.baz` to the set.
                used.update(itertools.accumulate(fullpath, join_fn))

                # For a `call_module` node, also register all recursive submodules
                # as used
                if node.op == "call_module":
                    try:
                        submod = self.get_submodule(node.target)

                        for submod_name, _ in submod.named_modules():
                            if submod_name != '':
                                used.add('.'.join([node.target, submod_name]))
                    except AttributeError:
                        # Node referenced nonexistent submodule, don't need to
                        # worry about GCing anything
                        pass

        # Collect the names first so the hierarchy is not mutated mid-iteration.
        to_delete = [name for name, _ in self.named_modules()
                     if name not in used]

        for name in to_delete:
            self.delete_submodule(name)

    @property
    def code(self) -> str:
        """
        Return the Python code generated from the ``Graph`` underlying this
        ``GraphModule``.
        """
        if not hasattr(self, '_code'):
            raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
        return self._code

    @compatibility(is_backward_compatible=True)
    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.

        Returns:
            PythonCode: the code object produced by ``self._graph.python_code``.
        """
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module='self')
        self._code = python_code.src

        cls = type(self)
        co_fields = getattr(self._graph, '_co_fields', {})
        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if '_wrapped_call' not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped

        return python_code

    # Passing Tracer as argument allows subclasses extending fx.GraphModule
    # define their own Tracer (extending fx.Tracer).
    def __reduce_deploy__(self, importer: Importer):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
        del dict_without_graph['_graph']

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, importer)
        return (reduce_deploy_graph_module, (dict_without_graph, import_block))

    def __reduce_package__(self, exporter: PackageExporter):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
        del dict_without_graph['_graph']

        generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (reduce_package_graph_module, (dict_without_graph, generated_module_name))

    def __reduce__(self):
        """
        Serialization of GraphModule. We serialize only the generated code, not
        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
        backward-compatibility guarantees, whereas Python source code does.
        On the deserialization side, we symbolically trace through the generated
        code to regenerate the underlying ``Graph``
        """
        # NOTE(review): unlike __reduce_deploy__ / __reduce_package__, this does
        # not record `_graphmodule_cls_name`, so `_deserialize_graph_module`
        # falls back to its 'GraphModule' default. Confirm whether intentional.
        dict_without_graph = self.__dict__.copy()

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph['_graph']
        return (reduce_graph_module, (dict_without_graph, import_block))

    # because __reduce__ is defined for serialization,
    # we need to define deepcopy otherwise it will call __reduce__
    # and cause symbolic tracing to occur every time we try to copy the object
    def __deepcopy__(self, memo):
        res = type(self).__new__(type(self))
        memo[id(self)] = res
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        GraphModule.__init__(res, fake_mod, fake_mod.__dict__['_graph'])
        # hooks are lost during `GraphModule.__init__`, so we need to copy over
        # them explicitly, note right now we are only copying state_dict related
        # hooks, to reduce bc-related issues, we can copy forward/backward related
        # hooks in the future as well if needed
        extra_preserved_attrs = [
            "_state_dict_hooks",
            "_load_state_dict_pre_hooks",
            "_load_state_dict_post_hooks",
        ]
        for attr in extra_preserved_attrs:
            if attr in self.__dict__:
                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
        res.meta = copy.deepcopy(getattr(self, 'meta', {}), memo)
        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
                setattr(res, attr_name, attr)
        return res

    def __copy__(self):
        res = GraphModule(self, self.graph)
        # Shallow copy: the ``meta`` dict object is shared with ``self``.
        res.meta = getattr(self, 'meta', {})
        return res

    @compatibility(is_backward_compatible=False)
    def print_readable(self, print_output=True):
        """
        Return the Python code generated for current GraphModule and its children GraphModules

        Args:
            print_output: When true, the result is also printed to stdout.
        """
        verbose_python_code = self._graph.python_code(root_module='self', verbose=True)
        module_code = verbose_python_code.src
        module_code = module_code.lstrip('\n')
        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
        module_code = _addindent(module_code, 4)

        # The leading "" makes the join below start with a newline whenever
        # at least one child GraphModule contributes code.
        submodule_code_list = [""]
        for submodule in self.children():
            if isinstance(submodule, GraphModule):
                submodule_code_list.append(submodule.print_readable(print_output=False))
        submodule_code = "\n".join(submodule_code_list)
        submodule_code = _addindent(submodule_code, 4)

        output = module_code + submodule_code
        if print_output:
            print(output)
        return output

    def __str__(self) -> str:
        orig_str = super().__str__()
        print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`"
        return '\n'.join([orig_str, self._code, print_readable_reminder])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm
# workarounds for issues in __torch_function__

# WAR for __torch_function__ not handling tensor lists,
# fix is in https://github.com/pytorch/pytorch/pull/34725
# orig_cat = torch.cat
# def patched_cat(*args, **kwargs):
#     tensors = args[0]
#     for t in tensors:
#         if isinstance(t, Proxy):
#             return t.__torch_function__(patched_cat, (), args, kwargs)
#     return orig_cat(*args, **kwargs)
# patched_cat.__module__ = 'torch'
# patched_cat.__name__ = 'cat'
# torch.cat = patched_cat
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.