"""ONNX operators as native torch.fx operators.This module provides a set of functions to create ONNX operators in the FX graphwhich are exportable to ONNX."""from__future__importannotationsfromtypingimportTYPE_CHECKINGimporttorchfromtorch.onnx.opsimport_symbolic_implifTYPE_CHECKING:fromcollections.abcimportSequence# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597_TORCH_DTYPE_TO_ONNX_DTYPE={torch.float32:1,# FLOATtorch.uint8:2,# UINT8torch.int8:3,# INT8torch.uint16:4,# UINT16torch.int16:5,# INT16torch.int32:6,# INT32torch.int64:7,# INT64str:8,# STRINGtorch.bool:9,# BOOLtorch.float16:10,# FLOAT16torch.double:11,# DOUBLEtorch.uint32:12,# UINT32torch.uint64:13,# UINT64torch.complex64:14,# COMPLEX64torch.complex128:15,# COMPLEX128torch.bfloat16:16,# BFLOAT16torch.float8_e4m3fn:17,# FLOAT8E4M3FNtorch.float8_e4m3fnuz:18,# FLOAT8E4M3FNUZtorch.float8_e5m2:19,# FLOAT8E5M2torch.float8_e5m2fnuz:20,# FLOAT8E5M2FNUZ}def_parse_domain_op_type(domain_op:str)->tuple[str,str]:splitted=domain_op.split("::",1)iflen(splitted)==1:domain=""op_type=splitted[0]else:domain=splitted[0]op_type=splitted[1]returndomain,op_type
def symbolic(
    domain_op: str,
    /,
    inputs: Sequence[torch.Tensor | None],
    attrs: dict[
        str,
        int
        | float
        | str
        | bool
        | Sequence[int]
        | Sequence[float]
        | Sequence[str]
        | Sequence[bool],
    ]
    | None = None,
    *,
    dtype: torch.dtype | int,
    shape: Sequence[int | torch.SymInt],
    version: int | None = None,
    metadata_props: dict[str, str] | None = None,
) -> torch.Tensor:
    """Create a symbolic FX operator to represent an arbitrary ONNX operator.

    This function is used to create a symbolic operator with a single output.
    To create an operator with multiple outputs, use :func:`symbolic_multi_out`.

    Example::

        class CustomOp(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Normal torch operators can interleave with the symbolic ops during ONNX export
                x = x + 1

                # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
                # The output tensor will have the specified dtype and shape
                val = torch.onnx.ops.symbolic(
                    "custom_domain::CustomOp",
                    (x,),
                    dict(attr_key="attr_value"),
                    dtype=x.dtype,
                    shape=x.shape,
                    version=1,
                )

                # The result of the symbolic op can be used in normal torch operations during ONNX export
                return torch.nn.functional.relu(val)


        # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).

    Args:
        domain_op: The domain and operator name, separated by "::". For example,
            "custom_domain::CustomOp". Without "::" the domain is empty.
        inputs: The input tensors to the operator.
        attrs: The attributes of the operator. The keys are attribute names and
            the values are attribute values. Valid attribute types are int, float,
            str, bool, and lists of int, float, str, and bool. Tensor attributes
            are unsupported.
        dtype: The data type of the output tensor. This can be either a
            torch.dtype or an integer representing the ONNX data type.
        shape: The shape of the output tensor. This can be a list of integers
            or SymInt values.
        version: The version of the opset used for the operator.
        metadata_props: Metadata properties for the ONNX node.
            This is a dictionary of str-str pairs.

    Returns:
        The output tensor of the operator.

    Raises:
        RuntimeError: If ``dtype`` is a ``torch.dtype`` with no ONNX equivalent
            (raised through ``torch._check``).
    """
    if not isinstance(dtype, int):
        torch._check(
            dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
        )
        dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
    domain, op_type = _parse_domain_op_type(domain_op)
    if attrs is None:
        attrs = {}
    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)

    # Materialize the dict views into real lists: ``dict.keys()`` and
    # ``dict.values()`` are not ``Sequence`` objects, whereas the op declares
    # sequence parameters. Both lists come from the same dict, so the keys and
    # values stay paired by position.
    if metadata_props:
        metadata_props_keys = list(metadata_props.keys())
        metadata_props_values = list(metadata_props.values())
    else:
        metadata_props_keys = []
        metadata_props_values = []

    return _symbolic_impl._symbolic(
        inputs,
        op_type,
        dtype,
        shape=shape,
        attr_keys=encoded_attrs.attr_keys,
        attr_types=encoded_attrs.attr_types,
        attr_pos=encoded_attrs.attr_pos,
        attr_ints=encoded_attrs.attr_ints,
        attr_floats=encoded_attrs.attr_floats,
        attr_strs=encoded_attrs.attr_strs,
        metadata_props_keys=metadata_props_keys,
        metadata_props_values=metadata_props_values,
        domain=domain,
        version=version,
    )
def symbolic_multi_out(
    domain_op: str,
    /,
    inputs: Sequence[torch.Tensor | None],
    attrs: dict[
        str,
        int
        | float
        | str
        | bool
        | Sequence[int]
        | Sequence[float]
        | Sequence[str]
        | Sequence[bool],
    ]
    | None = None,
    *,
    dtypes: Sequence[torch.dtype | int],
    shapes: Sequence[Sequence[int | torch.SymInt]],
    version: int | None = None,
    metadata_props: dict[str, str] | None = None,
) -> Sequence[torch.Tensor]:
    """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.

    Example::

        class CustomOp(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Normal torch operators can interleave with the symbolic ops during ONNX export
                x = x + 1

                # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
                # The output tensors will have the specified dtypes and shapes
                (out1, out2) = torch.onnx.ops.symbolic_multi_out(
                    "custom_domain::CustomOp",
                    (x,),
                    dict(attr_key="attr_value"),
                    dtypes=(x.dtype, torch.float32),
                    shapes=(x.shape, [1, 2, 3]),
                    version=1,
                )

                # The result of the symbolic op can be used in normal torch operations during ONNX export
                return torch.nn.functional.relu(out1 + out2)


        # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).

    Args:
        domain_op: The domain and operator name, separated by "::". For example,
            "custom_domain::CustomOp". Without "::" the domain is empty.
        inputs: The input tensors to the operator.
        attrs: The attributes of the operator. The keys are attribute names and
            the values are attribute values. Valid attribute types are int, float,
            str, bool, and lists of int, float, str, and bool. Tensor attributes
            are unsupported.
        dtypes: The data types of the output tensors. This can be a list of
            torch.dtype or integers representing the ONNX data types. The length
            of this list must be the number of outputs.
        shapes: The shapes of the output tensors. This can be a list of lists of
            integers or SymInt values. The length of this list must be the number
            of outputs.
        version: The version of the opset used for the operator.
        metadata_props: Metadata properties for the ONNX node.
            This is a dictionary of str-str pairs.

    Returns:
        A list of output tensors of the operator.

    Raises:
        RuntimeError: If ``len(shapes) != len(dtypes)``, or if an entry of
            ``dtypes`` is a ``torch.dtype`` with no ONNX equivalent (both raised
            through ``torch._check``).
    """
    torch._check(
        len(shapes) == len(dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
    )

    # Normalize every dtype to its ONNX integer code; ints pass through as-is.
    onnx_dtypes = []
    for dtype in dtypes:
        if isinstance(dtype, int):
            onnx_dtypes.append(dtype)
            continue
        # torch._check evaluates the message lambda immediately on failure, so
        # capturing the loop variable here is safe.
        torch._check(
            dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
            lambda: f"Unsupported dtype: {dtype}",
        )
        onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])

    domain, op_type = _parse_domain_op_type(domain_op)
    if attrs is None:
        attrs = {}
    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)

    # Materialize the dict views into real lists: ``dict.keys()`` and
    # ``dict.values()`` are not ``Sequence`` objects, whereas the op declares
    # sequence parameters. Both lists come from the same dict, so the keys and
    # values stay paired by position.
    if metadata_props:
        metadata_props_keys = list(metadata_props.keys())
        metadata_props_values = list(metadata_props.values())
    else:
        metadata_props_keys = []
        metadata_props_values = []

    # Use the size of dtypes to determine the number of outputs
    return _symbolic_impl._symbolic_multi_out(
        inputs,
        op_type,
        onnx_dtypes,
        shapes=shapes,
        attr_keys=encoded_attrs.attr_keys,
        attr_types=encoded_attrs.attr_types,
        attr_pos=encoded_attrs.attr_pos,
        attr_ints=encoded_attrs.attr_ints,
        attr_floats=encoded_attrs.attr_floats,
        attr_strs=encoded_attrs.attr_strs,
        metadata_props_keys=metadata_props_keys,
        metadata_props_values=metadata_props_values,
        domain=domain,
        version=version,
    )
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.