@dataclasses.dataclass
class InputSpec:
    """A single input slot of an exported graph: what kind of input it is,
    which graph argument carries it, and (if lifted) the FQN it came from."""

    kind: InputKind
    arg: ArgumentSpec
    target: Optional[str]
    # Only meaningful for BUFFER inputs; must be explicitly set for them.
    persistent: Optional[bool] = None

    def __post_init__(self):
        # Buffers must declare persistence explicitly — there is no safe default.
        if self.kind == InputKind.BUFFER:
            assert (
                self.persistent is not None
            ), "Failed to specify persistent flag on BUFFER."
        # Restrict `arg` to the closed set of argument spec types.
        valid_arg_types = (
            TensorArgument,
            SymIntArgument,
            SymFloatArgument,
            SymBoolArgument,
            ConstantArgument,
            CustomObjArgument,
            TokenArgument,
        )
        assert isinstance(self.arg, valid_arg_types), f"got {type(self.arg)}"
@dataclasses.dataclass
class ExportGraphSignature:
    """
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariant guarantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    guarantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs.  Similarly, any mutations to buffers are not included
    in the graph either, instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0) # In-place addition

                return output

    Resulting Graph would be::

        graph():
            %arg0_1 := placeholder[target=arg0_1]
            %arg1_1 := placeholder[target=arg1_1]
            %arg2_1 := placeholder[target=arg2_1]
            %arg3_1 := placeholder[target=arg3_1]
            %arg4_1 := placeholder[target=arg4_1]
            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
            return (add_tensor_2, add_tensor_1)

    Resulting ExportGraphSignature would be::

        ExportGraphSignature(
            input_specs=[
                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
            ],
            output_specs=[
                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
            ]
        )
    """

    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]

    # A list of parameters uniquely identified by mangled fully qualified name
    @property
    def parameters(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            if isinstance(s.target, str)
        )

    # A list of buffers uniquely identified by mangled fully qualified name
    @property
    def buffers(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if isinstance(s.target, str)
        )

    # Buffers whose InputSpec explicitly carries persistent=False.
    @property
    def non_persistent_buffers(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if s.persistent is False
            if isinstance(s.target, str)
        )

    # A list of lifted constant tensors
    @property
    def lifted_tensor_constants(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            if isinstance(s.target, str)
        )

    # A list of lifted custom (script) objects, by fully qualified name.
    @property
    def lifted_custom_objs(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            if isinstance(s.target, str)
        )

    # Graph node names of pytree-flattened inputs of original program
    @property
    def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_inputs: List[Union[int, float, bool, None, str]] = []
        for s in self.input_specs:
            if s.kind != InputKind.USER_INPUT:
                continue

            # Node-backed args contribute their graph node name; constants
            # contribute their literal value instead.
            if isinstance(
                s.arg,
                (
                    TensorArgument,
                    SymIntArgument,
                    SymFloatArgument,
                    SymBoolArgument,
                    CustomObjArgument,
                ),
            ):
                user_inputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_inputs.append(s.arg.value)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user inputs")
        return tuple(user_inputs)

    # Graph node names of pytree-flattened outputs of original program
    # For joint-graph purposes, will include the loss output.
    @property
    def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_outputs: List[Union[int, float, bool, None, str]] = []
        for s in self.output_specs:
            if s.kind not in [
                OutputKind.USER_OUTPUT,
                OutputKind.LOSS_OUTPUT,
            ]:
                continue

            if isinstance(
                s.arg,
                (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
            ):
                user_outputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_outputs.append(s.arg.value)
            elif isinstance(s.arg, CustomObjArgument):
                user_outputs.append(s.arg.name)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user output")
        return tuple(user_outputs)

    # A dictionary mapping graph input node names to parameters. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted parameter.
    @property
    def inputs_to_parameters(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to buffers. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted buffer.
    @property
    def inputs_to_buffers(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)  # type: ignore[union-attr, misc]
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph output node names to buffers that are mutated in the
    # original program. Buffers that are not mutated will not be found in this dictionary.
    @property
    def buffers_to_mutate(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.BUFFER_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph output node names to the user inputs they mutate.
    @property
    def user_inputs_to_mutate(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.USER_INPUT_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to lifted tensor constants.
    @property
    def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to lifted custom objects.
    @property
    def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            and isinstance(s.arg, CustomObjArgument)
            and isinstance(s.target, str)
        )

    # Derives the backward (joint-graph) signature from the output specs.
    # Returns None when no LOSS_OUTPUT is present (i.e. inference-only graph).
    @property
    def backward_signature(self) -> Optional[ExportBackwardSignature]:
        loss_output = None
        gradients_to_parameters: Dict[str, str] = {}
        gradients_to_user_inputs: Dict[str, str] = {}
        for spec in self.output_specs:
            if spec.kind == OutputKind.LOSS_OUTPUT:
                # There may be at most one loss output.
                assert loss_output is None
                assert isinstance(spec.arg, TensorArgument)
                loss_output = spec.arg.name
            elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_parameters[spec.arg.name] = spec.target
            elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_user_inputs[spec.arg.name] = spec.target

        if loss_output is None:
            return None

        return ExportBackwardSignature(
            loss_output=loss_output,
            gradients_to_parameters=gradients_to_parameters,
            gradients_to_user_inputs=gradients_to_user_inputs,
        )

    # Map from assertion dependency token index to assertion dep token output
    # name in output. The shape of output after aot_autograd will be like:
    # (updated_inputs, user_outputs, dep_token).
    @property
    def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
        return None

    # Graph node names of the effect-token inputs (TOKEN kind).
    @property
    def input_tokens(self) -> Collection[str]:
        input_tokens = []
        for s in self.input_specs:
            if s.kind == InputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                input_tokens.append(s.arg.name)
        return tuple(input_tokens)

    # Graph node names of the effect-token outputs (TOKEN kind).
    @property
    def output_tokens(self) -> Collection[str]:
        output_tokens = []
        for s in self.output_specs:
            if s.kind == OutputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                output_tokens.append(s.arg.name)
        return tuple(output_tokens)

    def __post_init__(self) -> None:
        # Validate the assertion dep token invariant: if present, there is
        # exactly one, and it sits right after user outputs + buffer mutations.
        assertion_dep_token = self.assertion_dep_token
        if assertion_dep_token is None:
            return
        assert len(assertion_dep_token) == 1
        assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
        assert (
            len(self.user_outputs) + len(self.buffers_to_mutate)
            == assertion_dep_token_index
        )

    def replace_all_uses(self, old: str, new: str):
        """
        Replace all uses of the old name with new name in the signature.
        """
        assert isinstance(old, str)
        assert isinstance(new, str)
        # Only node-backed specs carry a mutable .name; ConstantArgument does not.
        arg_types = (
            TensorArgument,
            SymIntArgument,
            SymFloatArgument,
            SymBoolArgument,
            CustomObjArgument,
            TokenArgument,
        )
        for o in self.output_specs:
            if isinstance(o.arg, arg_types):
                if o.arg.name == old:
                    o.arg.name = new
        for i in self.input_specs:
            if isinstance(i.arg, arg_types):
                if i.arg.name == old:
                    i.arg.name = new
def _immutable_dict(items):
    """
    Creates a mapping where items cannot be added, deleted, or updated.
    NOTE: The immutability is shallow (like tuple is an immutable collection).
    """
    from types import MappingProxyType

    return MappingProxyType(dict(items))


def _make_argument_spec(node, token_names) -> ArgumentSpec:
    """Build the ArgumentSpec for one graph value.

    ``node`` is either a Python literal (constant output) or an fx node whose
    ``meta["val"]`` determines the spec type; ``token_names`` marks which node
    names are effect tokens.
    """
    from torch import ScriptObject, SymBool, SymFloat, SymInt
    from torch._library.fake_class_registry import FakeScriptObject
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(node, (int, bool, float, type(None), str)):
        # For const outputs we just directly return this
        return ConstantArgument(name="", value=node)

    assert (
        "val" in node.meta
    ), f"{node} is not a constant or a node with a 'val' metadata field"
    val = node.meta["val"]
    # Token check comes first: a token node's "val" must not be classified below.
    if node.name in token_names:
        return TokenArgument(name=node.name)
    elif isinstance(val, FakeTensor):
        return TensorArgument(name=node.name)
    elif isinstance(val, SymInt):
        return SymIntArgument(name=node.name)
    elif isinstance(val, SymFloat):
        return SymFloatArgument(name=node.name)
    elif isinstance(val, SymBool):
        return SymBoolArgument(name=node.name)
    elif isinstance(val, ScriptObject):
        return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name())  # type: ignore[attr-defined]
    elif isinstance(val, FakeScriptObject):
        return CustomObjArgument(
            name=node.name, class_fqn=val.script_class_name, fake_val=val
        )
    elif isinstance(val, (int, bool, str, float, type(None))):
        return ConstantArgument(name=node.name, value=val)
    else:
        raise AssertionError(
            f"Encountered an unsupported object of type {type(val)} "
            f"while writing the metadata for exported program"
        )


def _convert_to_export_graph_signature(
    graph_signature: "GraphSignature",
    gm: "torch.fx.GraphModule",
    non_persistent_buffers: Set[str],
) -> "ExportGraphSignature":
    """Translate an aot_autograd GraphSignature into an ExportGraphSignature,
    classifying each placeholder/output node of ``gm`` into an
    InputSpec/OutputSpec by name membership in the aot signature's maps.
    """
    from torch.utils import _pytree as pytree

    is_joint = graph_signature.backward_signature is not None

    # unpack objects
    user_inputs = set(graph_signature.user_inputs)
    inputs_to_parameters = graph_signature.inputs_to_parameters
    inputs_to_buffers = graph_signature.inputs_to_buffers
    user_outputs = set(graph_signature.user_outputs)
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    # NOTE(review): `gradients_to_parameter` (singular) is inconsistent with
    # `gradients_to_user_inputs` below and with this file's own
    # ExportBackwardSignature.gradients_to_parameters — confirm the attribute
    # name on GraphSignature.backward_signature before relying on the joint path.
    grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {}  # type: ignore[union-attr]
    grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}  # type: ignore[union-attr]
    loss_output = graph_signature.backward_signature.loss_output if is_joint else None  # type: ignore[union-attr]
    input_tokens = graph_signature.input_tokens
    output_tokens = graph_signature.output_tokens

    inputs = [
        _make_argument_spec(node, input_tokens)
        for node in gm.graph.nodes
        if node.op == "placeholder"
    ]
    # The last graph node is the output node; flatten its args to get the
    # individual output values.
    outputs = [
        _make_argument_spec(node, output_tokens)
        for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
    ]

    def to_input_spec(inp: ArgumentSpec) -> InputSpec:
        # Classify one placeholder: token / non-tensor user input / named
        # tensor (user input, parameter, or buffer).
        if isinstance(inp, TokenArgument):
            return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)

        if not isinstance(inp, TensorArgument):
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        name = inp.name
        if name in user_inputs:
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        elif name in inputs_to_parameters:
            return InputSpec(
                kind=InputKind.PARAMETER,
                arg=inp,
                target=inputs_to_parameters[name],  # type: ignore[index]
            )
        elif name in inputs_to_buffers:
            return InputSpec(
                kind=InputKind.BUFFER,
                arg=inp,
                target=inputs_to_buffers[name],  # type: ignore[index]
                # Persistence is determined by the buffer FQN, not the node name.
                persistent=(inputs_to_buffers[name] not in non_persistent_buffers),  # type: ignore[index]
            )
        else:
            raise AssertionError(f"Unknown tensor input kind: {name}")

    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
        # Classify one output. Outputs are positionally ordered: mutations and
        # tokens come first, then user outputs / gradients / loss — so the
        # index decides which name-map applies.
        if isinstance(o, TokenArgument):
            return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)

        if not isinstance(o, TensorArgument):
            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
        name = o.name
        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            if name in buffer_mutations:
                return OutputSpec(
                    kind=OutputKind.BUFFER_MUTATION,
                    arg=o,
                    target=buffer_mutations[name],  # type: ignore[index]
                )
            elif name in user_input_mutations:
                return OutputSpec(
                    kind=OutputKind.USER_INPUT_MUTATION,
                    arg=o,
                    target=user_input_mutations[name],  # type: ignore[index]
                )
            else:
                raise AssertionError(f"Unknown tensor mutation kind: {name}")
        else:
            if name in user_outputs:
                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
            elif name in grad_params:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_PARAMETER,
                    arg=o,
                    target=grad_params[name],
                )
            elif name in grad_user_inputs:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
                    arg=o,
                    target=grad_user_inputs[name],
                )
            elif name == loss_output:
                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
            else:
                raise AssertionError(f"Unknown tensor output kind: {name}")

    input_specs = [to_input_spec(inp) for inp in inputs]
    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
    return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.