# Source code for torch.onnx

import functools
import types

import torch._C as _C

# Alias of the TensorProtoDataType object exposed by the C extension's
# ``_onnx`` submodule, so it can be reached as an attribute of this module.
TensorProtoDataType = _C._onnx.TensorProtoDataType

# NOTE(review): presumably the entry name under which the model proto is
# stored inside an exported archive -- confirm against torch.onnx.utils.
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"


class ExportTypes:
    """Namespace of integer constants (1 through 4) naming the export types."""

    # Consecutive integer codes, assigned in declaration order starting at 1.
    (
        PROTOBUF_FILE,
        ZIP_ARCHIVE,
        COMPRESSED_ZIP_ARCHIVE,
        DIRECTORY,
    ) = range(1, 5)


def _export(*args, **kwargs):
    """Forward all arguments to ``torch.onnx.utils._export``.

    The ``utils`` module is imported when this function is called rather
    than when this module is loaded.

    Returns:
        Whatever ``torch.onnx.utils._export`` returns for the given arguments.
    """
    from torch.onnx import utils as onnx_utils

    outcome = onnx_utils._export(*args, **kwargs)
    return outcome


def export(*args, **kwargs):
    """Forward all arguments to ``torch.onnx.utils.export``.

    The ``utils`` module is imported when this function is called rather
    than when this module is loaded.

    Returns:
        Whatever ``torch.onnx.utils.export`` returns for the given arguments.
    """
    from torch.onnx import utils
    return utils.export(*args, **kwargs)
def _optimize_trace(trace, aten):
    """Replace the graph of ``trace`` with its optimized version, in place.

    Args:
        trace: object exposing ``graph()`` and ``set_graph()``.
        aten: forwarded unchanged as the second argument of
            ``torch.onnx.utils._optimize_graph``.
    """
    from torch.onnx import utils
    trace.set_graph(utils._optimize_graph(trace.graph(), aten))


def set_training(*args, **kwargs):
    """Forward all arguments to ``torch.onnx.utils.set_training``."""
    from torch.onnx import utils
    return utils.set_training(*args, **kwargs)


def _run_symbolic_function(*args, **kwargs):
    """Forward all arguments to ``torch.onnx.utils._run_symbolic_function``."""
    from torch.onnx import utils
    return utils._run_symbolic_function(*args, **kwargs)


def _run_symbolic_method(*args, **kwargs):
    """Forward all arguments to ``torch.onnx.utils._run_symbolic_method``."""
    from torch.onnx import utils
    return utils._run_symbolic_method(*args, **kwargs)


def _symbolic_override_wrapper_maker(symbolic_fn, might_trace, fn):
    """Wrap ``fn`` so that, while tracing, ``symbolic_fn`` supplies its graph.

    Args:
        symbolic_fn: callable invoked as
            ``symbolic_fn(graph, *symbolic_args, **kwargs)``; its outputs are
            recorded as the traced values of ``fn``'s outputs.
        might_trace: cheap predicate called with the positional-argument
            tuple; a falsy result means ``fn`` is called directly.
        fn: the callable being wrapped.

    Returns:
        The wrapper. When ``fn`` is a plain Python function the wrapper
        carries ``fn``'s metadata via ``functools.wraps``.
    """

    def wrapper(*args, **kwargs):
        # fast pass: checked before the torch imports below so that the
        # non-tracing path does as little work as possible.
        if not might_trace(args):
            return fn(*args, **kwargs)

        import torch
        import torch.jit  # kept from the original; NOTE(review): looks unused here
        from torch.autograd import function

        flat_args = tuple(function._iter_tensors_permissive(args))
        flat_args_only_tensors = tuple(
            t for t in flat_args if isinstance(t, torch.Tensor))
        if not any(map(torch._C._jit_is_tracing, flat_args_only_tensors)):
            return fn(*args, **kwargs)

        tstate = torch._C._get_tracing_state(flat_args_only_tensors)

        # Tensors are replaced by their traced values; anything else is
        # passed through untouched.
        arg_values = [
            torch._C._get_value_trace(tstate, x)
            if isinstance(x, torch.Tensor) else x
            for x in flat_args
        ]

        # This must come after the calls to get_value_trace, lest we
        # lose information due to in-place operations.
        output_vars = fn(*args, **kwargs)

        symbolic_args = function._unflatten(arg_values, args)
        output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)

        # Pair each real output with the value produced by symbolic_fn and
        # register that value as the output's trace.
        for var, val in zip(
                function._iter_tensors(output_vars),
                function._iter_jit_values(output_vals)):
            val.inferTypeFrom(var.data)
            torch._C._set_value_trace(tstate, var, val)

        return output_vars

    # fn might be autograd.Function too, in this case wrapping doesn't work
    if isinstance(fn, types.FunctionType):
        wrapper = functools.wraps(fn)(wrapper)

    return wrapper


def symbolic_override(symbolic_fn):
    r"""
    Decorator to override ONNX export of a function with specified subgraph.

    Effectively allows to attach symbolic() implementation to an arbitrary
    python function or autograd.Function. Requirements for the decorated
    function:
     - being non-member function or autograd.Function
     - positional inputs are Tensors or (nested) lists or tuples of
       them (similar requirement to NestedIOFunction)
     - outputs are similarly Tensors or (nested) lists or tuples of them
     - non-tensor typed values should be keyword arguments both in definition
       and when called

    Example usage:

    ```
    def symb(g, x, y):
        return g.op('Sum', x, y[0], y[1])

    @symbolic_override(symb)
    def foo(x, y):
        return x + y[0] + y[1]
    ```
    """

    # The predicate always answers True, so the wrapper inspects every
    # tensor argument to decide whether tracing is active.
    return functools.partial(
        _symbolic_override_wrapper_maker, symbolic_fn, lambda x: True)


def symbolic_override_first_arg_based(symbolic_fn):
    r"""
    Decorator to override ONNX export of a function with specified subgraph.

    Equivalent to :func:`symbolic_override` but checks only the first argument
    of the function to figure out whether the tracing is on. Thus the first arg
    needs to be a Tensor.

    The wrapped function raises ``ValueError`` when called with a first
    positional argument that is not a Tensor.
    """

    def might_trace(args):
        import torch
        first_arg = args[0]
        if not isinstance(first_arg, torch.Tensor):
            raise ValueError('First argument of {} is expected to be a tensor, '
                             'but got an object of type {}'
                             .format(symbolic_fn.__name__, type(first_arg)))
        return torch._C._jit_is_tracing(first_arg)

    return functools.partial(
        _symbolic_override_wrapper_maker, symbolic_fn, might_trace)


def symbolic_override_packed_sequence_based(symbolic_fn):
    r"""
    Decorator to override ONNX export of a function with specified subgraph.

    Equivalent to :func:`symbolic_override` but checks only the first argument
    of the function to figure out whether the tracing is on. The first arg
    needs to be a PackedSequence; the tracing check is made on its first
    element.

    The wrapped function raises ``ValueError`` when called with a first
    positional argument that is not a PackedSequence.
    """

    def might_trace(args):
        import torch
        first_arg = args[0]
        if not isinstance(first_arg, torch.nn.utils.rnn.PackedSequence):
            raise ValueError('pad_packed_sequence expects sequence to be a '
                             'PackedSequence, but got an object of type {}'
                             .format(type(first_arg)))
        return torch._C._jit_is_tracing(first_arg[0])

    return functools.partial(
        _symbolic_override_wrapper_maker, symbolic_fn, might_trace)