Shortcuts

Source code for torch.onnx.ops

"""ONNX operators as native torch.fx operators.

This module provides a set of functions to create ONNX operators in the FX graph
which are exportable to ONNX.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch.onnx.ops import _symbolic_impl


if TYPE_CHECKING:
    from collections.abc import Sequence


# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
_TORCH_DTYPE_TO_ONNX_DTYPE = {
    torch.float32: 1,  # FLOAT
    torch.uint8: 2,  # UINT8
    torch.int8: 3,  # INT8
    torch.uint16: 4,  # UINT16
    torch.int16: 5,  # INT16
    torch.int32: 6,  # INT32
    torch.int64: 7,  # INT64
    str: 8,  # STRING
    torch.bool: 9,  # BOOL
    torch.float16: 10,  # FLOAT16
    torch.double: 11,  # DOUBLE
    torch.uint32: 12,  # UINT32
    torch.uint64: 13,  # UINT64
    torch.complex64: 14,  # COMPLEX64
    torch.complex128: 15,  # COMPLEX128
    torch.bfloat16: 16,  # BFLOAT16
    torch.float8_e4m3fn: 17,  # FLOAT8E4M3FN
    torch.float8_e4m3fnuz: 18,  # FLOAT8E4M3FNUZ
    torch.float8_e5m2: 19,  # FLOAT8E5M2
    torch.float8_e5m2fnuz: 20,  # FLOAT8E5M2FNUZ
}


def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
    splitted = domain_op.split("::", 1)
    if len(splitted) == 1:
        domain = ""
        op_type = splitted[0]
    else:
        domain = splitted[0]
        op_type = splitted[1]
    return domain, op_type


[docs]def symbolic( domain_op: str, /, inputs: Sequence[torch.Tensor | None], attrs: dict[ str, int | float | str | bool | Sequence[int] | Sequence[float] | Sequence[str] | Sequence[bool], ] | None = None, *, dtype: torch.dtype | int, shape: Sequence[int | torch.SymInt], version: int | None = None, metadata_props: dict[str, str] | None = None, ) -> torch.Tensor: """Create a symbolic FX operator to represent an arbitrary ONNX operator. This function is used to create a symbolic operator with a single output. To create an operator with multiple outputs, use :func:`symbolic_multi_out`. Example:: class CustomOp(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: # Normal torch operators can interleave with the symbolic ops during ONNX export x = x + 1 # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain. # The output tensor will have the specified dtype and shape val = torch.onnx.ops.symbolic( "custom_domain::CustomOp", (x,), dict(attr_key="attr_value"), dtype=x.dtype, shape=x.shape, version=1, ) # The result of the symbolic op can be used in normal torch operations during ONNX export return torch.nn.functional.relu(val) # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True). Args: domain_op: The domain and operator name, separated by "::". For example, "custom_domain::CustomOp". inputs: The input tensors to the operator. attrs: The attributes of the operator. The keys are attribute names and the values are attribute values. Valid attribute types are int, float, str, bool, and lists of int, float, str, and bool. Tensor attributes are unsupported. dtype: The data type of the output tensor.This can be either a torch.dtype or an integer representing the ONNX data type. shape: The shape of the output tensor. This can be a list of integers or SymInt values. version: The version of the opset used for the operator. metadata_props: Metadata properties for the ONNX node. This is a dictionary of str-str pairs. Returns: The output tensor of the operator. """ if not isinstance(dtype, int): torch._check( dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}" ) dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype] domain, op_type = _parse_domain_op_type(domain_op) if attrs is None: attrs = {} encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs) # TODO: Parse domain return _symbolic_impl._symbolic( inputs, op_type, dtype, shape=shape, attr_keys=encoded_attrs.attr_keys, attr_types=encoded_attrs.attr_types, attr_pos=encoded_attrs.attr_pos, attr_ints=encoded_attrs.attr_ints, attr_floats=encoded_attrs.attr_floats, attr_strs=encoded_attrs.attr_strs, metadata_props_keys=metadata_props.keys() if metadata_props else [], metadata_props_values=metadata_props.values() if metadata_props else [], domain=domain, version=version, )
[docs]def symbolic_multi_out( domain_op: str, /, inputs: Sequence[torch.Tensor | None], attrs: dict[ str, int | float | str | bool | Sequence[int] | Sequence[float] | Sequence[str] | Sequence[bool], ] | None = None, *, dtypes: Sequence[torch.dtype | int], shapes: Sequence[Sequence[int | torch.SymInt]], version: int | None = None, metadata_props: dict[str, str] | None = None, ) -> Sequence[torch.Tensor]: """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs. Example:: class CustomOp(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: # Normal torch operators can interleave with the symbolic ops during ONNX export x = x + 1 # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain. # The output tensors will have the specified dtypes and shapes (out1, out2) = torch.onnx.ops.symbolic( "custom_domain::CustomOp", (x,), dict(attr_key="attr_value"), dtypes=(x.dtype, torch.float32), shapes=(x.shape, [1, 2, 3]), version=1, ) # The result of the symbolic op can be used in normal torch operations during ONNX export return torch.nn.functional.relu(out1 + out2) # You may then export this model to ONNX using torch.onnx.export(..., dynamo=True). Args: domain_op: The domain and operator name, separated by "::". For example, "custom_domain::CustomOp". inputs: The input tensors to the operator. attrs: The attributes of the operator. The keys are attribute names and the values are attribute values. Valid attribute types are int, float, str, bool, and lists of int, float, str, and bool. Tensor attributes are unsupported. dtypes: The data types of the output tensors. This can be a list of torch.dtype or integers representing the ONNX data types. The length of this list must be the number of outputs. shapes: The shapes of the output tensors. This can be a list of lists of integers or SymInt values. The length of this list must be the number of outputs. version: The version of the opset used for the operator. metadata_props: Metadata properties for the ONNX node. This is a dictionary of str-str pairs. Returns: A list of output tensors of the operator. """ torch._check( len(shapes) == len(dtypes), lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})", ) onnx_dtypes = [] for dtype in dtypes: if not isinstance(dtype, int): torch._check( dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}", ) onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype]) else: onnx_dtypes.append(dtype) domain, op_type = _parse_domain_op_type(domain_op) if attrs is None: attrs = {} encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs) # Use the size of dtypes to determine the number of outputs return _symbolic_impl._symbolic_multi_out( inputs, op_type, onnx_dtypes, shapes=shapes, attr_keys=encoded_attrs.attr_keys, attr_types=encoded_attrs.attr_types, attr_pos=encoded_attrs.attr_pos, attr_ints=encoded_attrs.attr_ints, attr_floats=encoded_attrs.attr_floats, attr_strs=encoded_attrs.attr_strs, metadata_props_keys=metadata_props.keys() if metadata_props else [], metadata_props_values=metadata_props.values() if metadata_props else [], domain=domain, version=version, )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources