# Source code for torch.jit._script
"""TorchScript
This module contains functionality to support the JIT's scripting frontend, notably:
- torch.jit.script
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import functools
import collections
import enum
import inspect
import copy
import pickle
import warnings
from typing import Any, Dict, List, Tuple, Union, Callable
import torch
import torch._jit_internal as _jit_internal
from torch.utils import set_module
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile, _compile_and_register_class
from torch.nn import Module
from torch.jit._state import _enabled
from torch.jit._builtins import _register_builtin
from torch._six import with_metaclass
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
from torch._jit_internal import _qualified_name
from torch.jit._fuser import _graph_for
from torch.jit._state import (
_try_get_jit_cached_function,
_try_get_jit_cached_overloads,
_set_jit_function_cache,
_set_jit_overload_cache,
)
from torch.overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic)
from torch.package import PackageExporter, PackageImporter
from ._serialization import validate_map_location
from torch.jit._monkeytype_config import (
monkeytype_trace,
JitTypeTraceConfig ,
JitTypeTraceStore
)
from torch._classes import classes
# Store that collects every call trace recorded through MonkeyType.
type_trace_db = JitTypeTraceStore()

# Make ``graph_for`` available on compiled methods as well as compiled free functions.
for _scripted_kind in (torch._C.ScriptMethod, torch._C.ScriptFunction):
    _scripted_kind.graph_for = _graph_for  # type: ignore[attr-defined]
del _scripted_kind

# Public alias for compiled free functions; documented here and re-homed under
# ``torch.jit`` so it is reported as ``torch.jit.ScriptFunction``.
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")
# With TorchScript enabled, ``Attribute`` is a ``(value, type)`` namedtuple that
# ``ScriptModule.__setattr__`` unwraps (recording ``type`` in ``__annotations__``);
# with TorchScript disabled it is a plain pass-through function.
if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:
    def Attribute(value, type):  # type: ignore[no-redef]
        return value

# The example below previously called ``super(M, self)`` (NameError) and used a
# plain ``nn.Module``, where ``self.foo`` would still be the namedtuple, so its
# eager-mode asserts could not pass. ``torch.jit.ScriptModule`` unwraps the
# Attribute on assignment, which makes the example runnable as written.
Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in `__init__` method of `nn.Module`
    subclasses.

    Though TorchScript can infer correct type for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`
    - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
      it is type `T` rather than `Optional[T]`

    In eager mode, it is simply a pass-through function that returns `value`
    without other implications.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.jit.ScriptModule):
            def __init__(self):
                super(AttributeModule, self).__init__()
                self.foo = torch.jit.Attribute(0.1, float)

                # we should be able to use self.foo as a float here
                assert 0.0 < self.foo

                self.names_ages = torch.jit.Attribute({}, Dict[str, int])
                self.names_ages["someone"] = 20
                assert isinstance(self.names_ages["someone"], int)

        m = AttributeModule()
        # m will contain two attributes
        # 1. foo of type float
        # 2. names_ages of type Dict[str, int]

    .. testcleanup::

        del AttributeModule
        del m

    Args:
        value: An initial value to be assigned to attribute.
        type: A Python type

    Returns:
        Returns `value`
"""
def _get_type_trace_db():
    """Return the module-level MonkeyType trace store.

    Private API; relying on it from outside ``torch.jit`` is discouraged.
    """
    store = type_trace_db
    return store
def _get_function_from_type(cls, name):
    """Look up attribute ``name`` on the type ``cls``.

    Returns ``None`` instead of raising when the attribute does not exist.
    """
    try:
        return getattr(cls, name)
    except AttributeError:
        return None
def _is_new_style_class(cls):
    """Report whether ``cls`` looks like a new-style class.

    ScriptClasses must be new-style because they are constructed through their
    ``__new__`` method. A class qualifies when it exposes an instance
    ``__dict__`` or declares ``__slots__``. If ``cls`` has no ``__class__`` at all,
    ``None`` is returned.
    """
    if not hasattr(cls, "__class__"):
        return None
    return "__dict__" in dir(cls) or hasattr(cls, "__slots__")
# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of Module.
# This allows us to reuse most of nn.Module while still storing the
# data in C++.
# Each OrderedDict needs to support:
#  x not in view
#  x in view
#  view[name] = ...
#  view.values()
#  del view[name]
#  view.items()
#  view.keys()
#  len(view)
class OrderedDictWrapper(object):
    """Dict-like view over a C++-backed container of named values.

    Every operation is forwarded to ``_c``, which must provide ``items()``,
    ``contains(k)``, ``getattr(k)`` and ``setattr(k, v)``. Existing entries can
    be read and overwritten, but entries cannot be added or deleted.
    """

    def __init__(self, _c):
        self._c = _c

    def keys(self):
        return [k for k, _ in self.items()]

    def values(self):
        return [v for _, v in self.items()]

    def __len__(self):
        return len(self.values())

    def __delitem__(self, k):
        raise RuntimeError("cannot delete methods or parameters of a script module")

    def items(self):
        return self._c.items()

    def __setitem__(self, k, v):
        # Only overwriting an existing entry is allowed; the key set is fixed
        # once the ScriptModule is constructed.
        if k not in self:
            raise RuntimeError(
                "Can't add a new parameter after ScriptModule construction."
                " Tried to add '{}'".format(k)
            )
        self._c.setattr(k, v)

    def __contains__(self, k):
        return self._c.contains(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self._c.getattr(k)
class OrderedModuleDict(OrderedDictWrapper):
    """Submodule view of a ScriptModule.

    Reads are answered from a Python dict holding both scripted and
    Python-only submodules. Assignments are written to the C++ ``ModuleDict``
    and mirrored into that Python dict.
    """

    def __init__(self, module, python_dict):
        super(OrderedModuleDict, self).__init__(torch._C.ModuleDict(module))
        # Script modules are subclassed in Python and the C++ Module class does
        # not hold references to those Python objects. Keeping them in a Python
        # dict as well guarantees that lookups always return the same Python value.
        self._python_modules = python_dict

    def items(self):
        return self._python_modules.items()

    def __contains__(self, k):
        return k in self._python_modules

    def __setitem__(self, k, v):
        # Re-assigning a submodule after ScriptModule construction is only
        # legitimate when:
        #   1. the attribute has a module interface type, so the module is
        #      never inlined into the graph and a new ScriptModule can be
        #      swapped in safely; or
        #   2. the new value is a ScriptModule with the same JIT type, so the
        #      IR does not change.
        # Either way the value must be a ScriptModule; plain nn.Modules are
        # rejected. NOTE(review): the interface/same-type conditions are not
        # checked here — presumably the C++ setattr enforces them; confirm.
        if not isinstance(v, ScriptModule):
            raise RuntimeError(
                "Cannot re-assign modules in a ScriptModule with non-scripted "
                "module, tried to replace existing module '{}': {}".format(k, v)
            )
        self._c.setattr(k, v)
        self._python_modules[k] = v

    def __getitem__(self, k):
        return self._python_modules[k]
# For each user-defined class that subclasses ScriptModule, this meta-class:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
#     removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
#     all of the script_methods with the module after the original __init__ has
#     run. This has to occur after the user-defined __init__ so that submodules and
#     parameters are initialized _before_ the script compiler resolves references to
#     `self.param` or `self.module`.
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        # Aggregate all the ScriptMethods and constants from superclasses.
        # Bases are walked in reverse so that, for a name defined by several
        # bases, the entry from the earliest-listed base is written last and wins.
        cls._methods: Dict[str, Any] = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for base in reversed(bases):
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # Find all the script methods of the current class: each
        # ScriptMethodStub produced by @script_method is removed from the class
        # namespace and recorded in cls._methods under the original function name.
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        if getattr(cls, "_disable_script_meta", False):
            # We leave built-in ScriptModule types alone, since this metaclass
            # is only for compiling user classes that inherit from
            # ScriptModule.
            return super(ScriptMeta, cls).__init__(name, bases, attrs)

        original_init = getattr(cls, "__init__", lambda self: None)

        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            # Run the user's __init__ first. ScriptModule.define() called during
            # __init__ adds stubs to cls._methods, so a growing count means
            # methods were added in __init__; type sharing is then disabled
            # for this compilation (share_types=False below).
            num_methods = len(cls._methods)
            original_init(self, *args, **kwargs)
            added_methods_in_init = len(cls._methods) > num_methods

            # Only the most-derived class compiles; the wrapped __init__ of a
            # base class (reached via super().__init__()) skips this step.
            if type(self) == cls:

                def make_stubs(module):
                    # Local `cls` deliberately shadows the enclosing class: it
                    # is the type of whichever (sub)module is being compiled.
                    cls = type(module)
                    if hasattr(cls, "_methods"):
                        return [v for k, v in sorted(cls._methods.items())]
                    else:
                        return infer_methods_to_compile(module)

                self.__dict__[
                    "_actual_script_module"
                ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script  # type: ignore[misc]
        super(ScriptMeta, cls).__init__(name, bases, attrs)
class _CachedForward(object):
    """Non-data descriptor installed as ``ScriptModule.forward``.

    ``_CachedForward`` defines no ``__getattr__``, so ``self.__getattr__`` inside
    ``__get__`` raises ``AttributeError``. When a descriptor's ``__get__`` raises
    ``AttributeError`` during instance attribute lookup, Python falls back to
    ``type(obj).__getattr__(obj, "forward")``. That fallback is what routes
    ``module.forward`` to ScriptModule's ``__getattr__`` instead of
    ``nn.Module.forward``.

    NOTE(review): this reliance on the AttributeError appears deliberate (hence
    the ``type: ignore``); do not "fix" it without replacing the mechanism.
    """

    def __get__(self, obj, cls):
        return self.__getattr__("forward")  # type: ignore[attr-defined]
class ScriptWarning(Warning):
    """Warning category for the TorchScript scripting frontend.

    NOTE(review): no emitter is visible in this section; presumably it is
    raised by callers elsewhere in ``torch.jit``.
    """
def script_method(fn):
    """Mark ``fn`` for compilation as a method of a legacy ``ScriptModule``.

    When TorchScript is disabled, ``fn`` is returned unchanged. Otherwise a
    ``ScriptMethodStub`` is returned. It holds a resolution callback for the
    scope surrounding the class body, the parsed definition of ``fn``, and
    ``fn`` itself. ``ScriptMeta`` later collects these stubs from the class
    namespace and compiles them.
    """
    if not _enabled:
        return fn
    # NOTE: we need to traverse two frames here because the meta-class frame
    # for ScriptModule will be present, as opposed to invoking @script on a
    # a function or invoking define() on a CompilationUnit.
    # The stack will look like:
    #
    # 0. createResolutionCallback()
    # 1. script_method()
    # 2. ScriptModule metaclass frame
    # 3. Surrounding scope
    #
    # createResolutionCallback internally adds 1 to get us to the scope of this
    # function (the calling function). Adding 2 gets us to the proper surrounding scope.
    # Keep this call directly in script_method: moving it into a helper
    # would change the frame count.
    _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
    ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
    return ScriptMethodStub(_rcb, ast, fn)
class ConstMap:
    """Attribute-style read access to a mapping of constant names to values.

    ``code_with_constants`` wraps its constant table in a ``ConstMap`` so a
    constant can be read as ``const_map.<name>``.
    """

    def __init__(self, const_mapping):
        self.const_mapping = const_mapping

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read the mapping via __dict__
        # rather than ``self.const_mapping``: on an instance created without
        # __init__ (copy.copy, unpickling) that lookup would re-enter
        # __getattr__ forever.
        instance_dict = self.__dict__
        if "const_mapping" not in instance_dict:
            raise AttributeError(attr)
        try:
            return instance_dict["const_mapping"][attr]
        except KeyError:
            # __getattr__ must raise AttributeError so hasattr()/getattr(default)
            # and the copy/pickle protocol probes behave correctly.
            raise AttributeError(
                "ConstMap has no constant named '{}'".format(attr)
            ) from None
def unpackage_script_module(importer: PackageImporter, script_module_id: str) -> torch.nn.Module:
    """
    Load a ScriptModule out of a ``torch.package`` archive.

    Invoked from the ``persistent_load`` hook of ``torch.package.PackageImporter``'s
    unpickler. Only importers backed by a package archive file (a
    ``torch._C.PyTorchFileReader``) are supported; directory-backed importers
    raise ``RuntimeError``.
    """
    zip_reader = importer.zip_reader
    if not isinstance(zip_reader, torch._C.PyTorchFileReader):
        raise RuntimeError(
            "Loading ScriptObjects from a PackageImporter created from a "
            "directory is not supported. Use a package archive file instead."
        )

    compilation_unit = torch._C.CompilationUnit()
    storage_context = importer.storage_context
    map_location = validate_map_location(importer.last_map_location)
    cpp_module = torch._C._import_ir_module_from_package(
        compilation_unit,
        zip_reader,
        storage_context,
        map_location,
        script_module_id,
    )
    return wrap_cpp_module(cpp_module)
if _enabled:
# Magic methods that RecursiveScriptClass forwards to the wrapped ScriptObject.
_magic_methods = [
    "__iter__",
    "__len__",
    "__neg__",
    "__mul__",
    "__contains__",
    "__add__",
    "__sub__",
    "__pow__",
    "__truediv__",
    "__mod__",
    "__ne__",
    "__eq__",
    "__lt__",
    "__gt__",
    "__le__",
    "__ge__",
    "__and__",
    "__or__",
    "__xor__",
    "__getitem__",
    "__setitem__",
    "__call__",
    "__int__",
    "__float__",
    "__bool__",
    "__str__",
    "__enter__",
    "__exit__",
]


class RecursiveScriptClass(object):
    """
    An analogue of RecursiveScriptModule for regular objects that are not modules.
    This class is a wrapper around a torch._C.ScriptObject that represents an instance
    of a TorchScript class and allows it to be used in Python.

    Attributes:
        _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method
            calls are forwarded.
        _props [Dict[str, property]]: A dictionary of properties fetched from self._c and
            exposed on this wrapper.
    """

    def __init__(self, cpp_class):
        super(RecursiveScriptClass, self).__init__()
        # While _initializing is set, __getattr__/__setattr__ use the default
        # object behaviour, so _c and _props are stored in __dict__.
        self.__dict__["_initializing"] = True
        self._c = cpp_class

        # Add wrapped object's properties to this class instance.
        self._props = {prop.name: property(prop.getter, prop.setter) for prop in self._c._properties()}

        self.__dict__["_initializing"] = False

    def __getattr__(self, attr):
        if "_initializing" in self.__dict__ and self.__dict__["_initializing"]:
            return super(RecursiveScriptClass, self).__getattr__(attr)  # type: ignore[misc]

        # Properties take precedence; everything else is read from the C++ object.
        if attr in self._props:
            return self._props[attr].fget()

        return getattr(self._c, attr)

    def __setattr__(self, attr, value):
        if "_initializing" in self.__dict__ and self.__dict__["_initializing"]:
            return super(RecursiveScriptClass, self).__setattr__(attr, value)

        if attr in self._props:
            return self._props[attr].fset(value)

        setattr(self._c, attr, value)

    # Delegate calls to magic methods like __len__ to the C++ object backing the
    # RecursiveScriptClass.
    def forward_magic_method(self, method_name, *args, **kwargs):
        """Call ``method_name`` on the wrapped object.

        Raises:
            TypeError: if the TorchScript class does not define ``method_name``.
        """
        if not self._c._has_method(method_name):
            raise TypeError(
                "TorchScript class does not define method '{}'".format(method_name)
            )

        self_method = self.__getattr__(method_name)
        return self_method(*args, **kwargs)

    def __getstate__(self):
        raise pickle.PickleError("ScriptClasses cannot be pickled")

    def __iadd__(self, other):
        # Fall back to __add__ when the TorchScript class has no in-place add.
        if self._c._has_method("__iadd__"):
            return self.forward_magic_method("__iadd__", other)
        else:
            return self.forward_magic_method("__add__", other)


def _make_magic_method_forwarder(magic_name):
    """Return a method that forwards ``magic_name`` through ``forward_magic_method``.

    A factory is required so each generated method captures its own name. A
    closure defined directly in the loop below would read the loop variable at
    call time and forward every magic method to the last name in the list.
    """

    def method_template(self, *args, **kwargs):
        return self.forward_magic_method(magic_name, *args, **kwargs)

    method_template.__name__ = magic_name
    return method_template


for method_name in _magic_methods:
    setattr(RecursiveScriptClass, method_name, _make_magic_method_forwarder(method_name))
# ``forward = _CachedForward()`` in the class below is a Python 'non-data
# descriptor' that causes the first access to ScriptModule's forward to look up
# the forward method and stash it in the object's dict. Due to the standard
# rules for attribute lookup, subsequent lookups will just directly return the
# previously looked up method.
# This is necessary because nn.Module defines forward as a method. If we
# did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward
# which always throws an exception.
class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore[misc]
    r"""
    A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
    contain methods, attributes, parameters, and
    constants. These can be accessed the same way as on a normal ``nn.Module``.
    """

    # Python-only properties; presumably the TorchScript compiler skips these
    # names when compiling the class — TODO confirm against the frontend.
    __jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']

    def __init__(self):
        super(ScriptModule, self).__init__()

    # See the descriptor note above the class.
    forward = _CachedForward()

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Until ScriptMeta's wrapped
        # __init__ stores `_actual_script_module`, behave like a plain
        # nn.Module; afterwards every lookup is answered by the compiled module.
        if "_actual_script_module" not in self.__dict__:
            return super(ScriptModule, self).__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        if "_actual_script_module" not in self.__dict__:
            # Unwrap torch.jit.Attribute into a regular setattr + record
            # the provided type in __annotations__.
            #
            # This ensures that if we use the attr again in `__init__`, it
            # will look like the actual value, not an instance of Attribute.
            if isinstance(value, Attribute):
                # NB: Ensure that we set __annotations__ on the specific
                # class in question, and not on a superclass (which would
                # be wrong wrong wrong!).
                # See also https://github.com/pytorch/pytorch/issues/39463
                if "__annotations__" not in self.__class__.__dict__:
                    self.__class__.__annotations__ = {}
                self.__annotations__[attr] = value.type
                value = value.value
            return super(ScriptModule, self).__setattr__(attr, value)

        # After compilation, writes go to the compiled module.
        setattr(self._actual_script_module, attr, value)

    def define(self, src):
        if "_actual_script_module" in self.__dict__:
            # If we have completed initialization, just defer to the
            # backing RecursiveScriptModule to eagerly compile the provided
            # source.
            return self._actual_script_module.define(src)

        # Otherwise, we are still in the object's __init__.
        # In that case, add `src` as a stub to be compiled.
        #
        # We use frames_up=1 to get to the proper surrounding scope. The stack
        # will look like:
        # 0. createResolutionCallback
        # 1. define()
        # 2. surrounding scope.
        #
        # createResolutionCallback internally adds 1 to get us to our frame, then
        # we add 1 to get to the proper surrounding scope.
        rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
        ast = torch._C._parse_source_def(src)
        # `self._methods` is the class-level dict created by ScriptMeta, so the
        # stub is picked up (and counted) by ScriptMeta's wrapped __init__.
        self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)

    def _replicate_for_data_parallel(self):
        # Delegate to the compiled module.
        return self._actual_script_module._replicate_for_data_parallel()

    def __reduce_package__(self, exporter: PackageExporter):
        """
        Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when
        saving TorchScript objects. Performs act of saving a ScriptModule inside of
        a ``torch.package`` archive.

        Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s
        Pickler's ``persistent_load`` function.
        """
        script_module_id = exporter.get_unique_id()
        exporter.script_module_serializer.serialize(self._c, int(script_module_id))
        return (unpackage_script_module, (script_module_id,))
class RecursiveScriptModule(ScriptModule):
    # XXX: RecursiveScriptModule inherits from ScriptModule for the sole
    # reason that it retains the existing isinstance(ScriptModule)
    # behavior.
    r"""
    The core data structure in TorchScript is the ``ScriptModule``. It is an
    analogue of torch's ``nn.Module`` and represents an entire model as a tree of
    submodules. Like normal modules, each individual module in a ``ScriptModule`` can
    have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
    as Python functions, but in ``ScriptModule``\s methods are implemented as
    TorchScript functions, a statically-typed subset of Python that contains all
    of PyTorch's built-in Tensor operations. This difference allows your
    ``ScriptModule``\s code to run without the need for a Python interpreter.

    ``ScriptModule``\s should not be created manually, instead use
    either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`.
    Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`.

    * Tracing records the tensor operations as executed with a set of example inputs and uses these
      operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing,
      but values other than Tensors and control flow aren't captured in the graph.

    * Scripting inspects the Python code of the model
      and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow.
      Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary.
    """

    # Tells ScriptMeta not to wrap __init__ for this built-in type.
    _disable_script_meta = True

    def __init__(self, cpp_module):
        # While _initializing is True, __getattr__/__setattr__ defer to the
        # nn.Module implementations.
        self.__dict__["_initializing"] = True
        self._c = cpp_module
        super(RecursiveScriptModule, self).__init__()
        # Delete the 'training' attribute set up by `Module.__init__`. It
        # will get set on the underlying cpp module, so we delete it here
        # to avoid this version shadowing the cpp module version.
        delattr(self, "training")

    @staticmethod
    def _construct(cpp_module, init_fn):
        """
        Construct a RecursiveScriptModule that's ready for use. PyTorch
        code should use this to construct a RecursiveScriptModule instead
        of calling `__init__` directly, as it makes sure the
        object is properly finalized (and in the future, we may take
        control of how the RecursiveScriptModule instance is created).

        Args:
            cpp_module: The C++ Module that will hold the actual state of
                this RecursiveScriptModule instance.
            init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
        """
        script_module = RecursiveScriptModule(cpp_module)
        init_fn(script_module)

        # Finalize the ScriptModule: replace the nn.Module state with our
        # custom implementations and flip the _initializing bit.
        RecursiveScriptModule._finalize_scriptmodule(script_module)
        return script_module

    @staticmethod
    def _finalize_scriptmodule(script_module):
        # Swap nn.Module's plain dicts for views backed by the C++ module.
        script_module._parameters = OrderedDictWrapper(
            torch._C.ParameterDict(script_module._c)
        )
        script_module._buffers = OrderedDictWrapper(
            torch._C.BufferDict(script_module._c)
        )
        script_module._modules = OrderedModuleDict(
            script_module._c, script_module._modules
        )
        script_module._initializing = False

    def _reconstruct(self, cpp_module):
        """
        Re-construct an instance of RecursiveScriptModule using an instance of a C++ module.

        Args:
            cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around.
        """
        self.__init__(cpp_module)  # type: ignore[misc]

        # Copy the concrete type from the C++ module to this ScriptModule.
        self._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
            self._c._type()
        )

        # Copy submodules from the C++ module to this ScriptModule. The loop
        # variable is distinct from the `cpp_module` parameter to avoid shadowing.
        modules = {}
        for name, cpp_submodule in torch._C.ModuleDict(self._c).items():
            modules[name] = wrap_cpp_module(cpp_submodule)
        self._modules = OrderedModuleDict(self._c, modules)

        # Copy parameters and buffers.
        self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c))
        self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c))

        # Get rid of the functions from the old C++ module.
        self.__dict__ = {
            k: v
            for k, v in self.__dict__.items()
            if not isinstance(v, torch._C.ScriptMethod)
        }
        self.__dict__["_initializing"] = False

    @property
    def graph(self):
        r"""
        Returns the internal graph for the
        ``forward`` method. See :ref:`interpreting-graphs` for details.
        """
        return self._c._get_method("forward").graph

    @property
    def inlined_graph(self):
        r"""
        Returns the internal graph for the
        ``forward`` method, preprocessed to inline all function and method calls.
        See :ref:`interpreting-graphs` for details.
        """
        return self.forward.inlined_graph

    @property
    def code(self):
        r"""
        Returns a pretty-printed representation (as valid Python syntax) of
        the internal graph for the ``forward`` method. See
        :ref:`inspecting-code` for details.
        """
        return self.forward.code

    @property
    def code_with_constants(self):
        r"""
        Returns a tuple of:

        [0] a pretty-printed representation (as valid Python syntax) of
        the internal graph for the ``forward`` method. See `code`.
        [1] a ConstMap following the CONSTANT.cN format of the output in [0].
        The indices in the [0] output are keys to the underlying constant's values.

        See :ref:`inspecting-code` for details.
        """
        r = self.forward.code_with_constants
        return (r[0], ConstMap(r[1]))

    def save(self, f, **kwargs):
        r"""
        save(f, _extra_files={})

        See :func:`torch.jit.save <torch.jit.save>` for details.
        """
        return self._c.save(str(f), **kwargs)

    def _save_for_lite_interpreter(self, *args, **kwargs):
        r"""
        _save_for_lite_interpreter(f)

        Add (or update) the bytecode session to the script model. The updated model is used
        in lite interpreter for mobile applications.

        Args:
            f: a string containing a file name.
            _extra_files: Map from filename to contents which will be stored as part of 'f'.

        """
        return self._c._save_for_mobile(*args, **kwargs)

    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs):
        return self._c._save_to_buffer_for_mobile(*args, **kwargs)

    def save_to_buffer(self, *args, **kwargs):
        return self._c.save_to_buffer(*args, **kwargs)

    def get_debug_state(self, *args, **kwargs):
        # Extra arguments are accepted but not forwarded.
        return self._c.get_debug_state()

    def extra_repr(self):
        return "original_name={}".format(self.original_name)

    def graph_for(self, *args, **kwargs):
        return self.forward.graph_for(*args, **kwargs)

    @property
    def original_name(self):
        # Name of the underlying JIT type. (A former guard compared
        # `type(self)` with this string; a class never equals a str, so that
        # branch could not fire and has been dropped.)
        return str(self._c._type().name())

    def define(self, src):
        # We use frames_up=1 to get to the proper surrounding scope. The stack
        # will look like:
        # 0. createResolutionCallback
        # 1. define()
        # 2. surrounding scope.
        #
        # createResolutionCallback internally adds 1 to get us to our frame, then
        # we add 1 to get to the proper surrounding scope.
        rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
        self._c._define(self._concrete_type, src, rcb)

    def __getattr__(self, attr):
        if "_initializing" not in self.__dict__:
            raise RuntimeError(
                "ScriptModule has not been initialized, did you forget to call super's init?"
            )

        if self._initializing:
            return super(RecursiveScriptModule, self).__getattr__(attr)

        # _modules check is before hasattr since modules are included as attributes in _c,
        # but we want to get the python wrapper from _modules instead of the raw _c object.
        if attr in self._modules:
            return self._modules[attr]
        elif self._c.hasattr(attr):
            return self._c.getattr(attr)
        elif self._c._has_method(attr):
            script_method = self._c._get_method(attr)
            # cache method so future calls do not go through __getattr__
            # to improve invocation performance
            self.__dict__[attr] = script_method
            return script_method

        return super(RecursiveScriptModule, self).__getattr__(attr)

    def __setattr__(self, attr, value):
        if self._initializing:
            return super(RecursiveScriptModule, self).__setattr__(attr, value)

        if attr in self._modules:
            self._modules[attr] = value
        elif self._c.hasattr(attr):
            self._c.setattr(attr, value)
        elif (
            hasattr(self, "_concrete_type")
            and attr in self._concrete_type.get_constants().keys()
        ):
            # TODO: we don't have _concrete_type set after load(), and in general we lose constant information.
            # We should encode constants as class type attributes (or something) so it persists across save/load.
            raise AttributeError(
                "Cannot mutate TorchScript constant value: '{}'. Value: '{}'".format(
                    attr, value
                )
            )
        else:
            # We allow setting Python attributes on the ScriptModule, for
            # when people want to stash some convenience info on it.
            # TODO: it's possible that the following is confusing:
            #   s = torch.jit.script(...)
            #   s.python_attr = ...
            #   s.save()   <--- this doesn't have `python_attr`
            # It's fairly trivial to save enough info to warn in this case.
            return super(RecursiveScriptModule, self).__setattr__(attr, value)

    def __copy__(self):
        return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))

    def __deepcopy__(self, memo):
        return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo))

    # Python magic methods do method lookups on an object's class type, instead of looking up
    # the method defines on the class instance. In order to continue to expose the magic methods
    # of builtin-containers (ModuleList, Sequential, ModuleDict) to Python, we
    # define magic methods here as a shim to the correct attribute.
    def forward_magic_method(self, method_name, *args, **kwargs):
        # If lookup lands on the shim defined on this class, nothing overrides
        # it, so the operation is unsupported.
        self_method = getattr(self, method_name)
        if getattr(self_method, "__func__", None) == getattr(
            RecursiveScriptModule, method_name
        ):
            raise NotImplementedError()
        return self_method(*args, **kwargs)

    def __iter__(self):
        return self.forward_magic_method("__iter__")

    def __getitem__(self, idx):
        return self.forward_magic_method("__getitem__", idx)

    def __len__(self):
        return self.forward_magic_method("__len__")

    def __contains__(self, key):
        return self.forward_magic_method("__contains__", key)

    # dir is defined by the base nn.Module, so instead of throwing if
    # it is not overridden, we call into the nn.Module __dir__ method
    def __dir__(self):
        self_method = self.__dir__
        if self_method.__func__ == _get_function_from_type(  # type: ignore[attr-defined]
            RecursiveScriptModule, "__dir__"
        ):
            return super(RecursiveScriptModule, self).__dir__()
        return self_method()

    # to resolve bool(value), Python looks if __bool__ is defined then __iter__
    # is defined then returns true for classes. Since __iter__() on this
    # class throws if it isn't overridden, we define __bool__ to preserve default behavior
    def __bool__(self):
        self_method = self.__bool__
        if self_method.__func__ == _get_function_from_type(  # type: ignore[attr-defined]
            RecursiveScriptModule, "__bool__"
        ):
            return True
        return self_method()

    def _replicate_for_data_parallel(self):
        # we have to initialize ScriptModule properly so that
        # it works with pybind11
        def init_fn(script_module):
            # Don't do anything here, we'll initialize the ScriptModule below
            return

        return RecursiveScriptModule._construct(
            self._c._replicate_for_data_parallel(), init_fn
        )
# Expose every RecursiveScriptModule method and property on ScriptModule itself.
#
# `super(MyScriptModule, self).foo()` resolves `foo` on the class hierarchy and
# never consults `__getattr__`, so such names must exist on ScriptModule
# directly. Apart from that, ScriptModule behaves exactly like
# RecursiveScriptModule, so the implementations are shared as-is. Dunder names
# and names ScriptModule already has are left untouched.
for name, item in RecursiveScriptModule.__dict__.items():
    if (callable(item) or isinstance(item, property)) and not (
        name.startswith("__") or hasattr(ScriptModule, name)
    ):
        setattr(ScriptModule, name, item)
def _get_methods(cls):
    """Return ``(name, value)`` pairs for the functions and bound methods of ``cls``.

    Built on :func:`inspect.getmembers`, so the pairs are sorted by name.
    """

    def _is_function_or_method(member):
        return inspect.isfunction(member) or inspect.ismethod(member)

    return inspect.getmembers(cls, predicate=_is_function_or_method)
# nn.Module methods that stay usable on RecursiveScriptModule; every other
# public nn.Module method is replaced by a stub that raises (see _make_fail).
_compiled_methods_allowlist = {
    "_apply",
    "_get_name",
    "_load_from_state_dict",
    "_named_members",
    "_save_to_state_dict",
    "_slow_forward",
    "_tracing_name",
    "add_module",
    "apply",
    "buffers",
    "children",
    "cpu",
    "cuda",
    "double",
    "eval",
    "extra_repr",
    "float",
    "forward",
    "get_extra_state",
    "half",
    "load_state_dict",
    "modules",
    "named_buffers",
    "named_children",
    "named_modules",
    "named_parameters",
    "parameters",
    "register_buffer",
    "register_parameter",
    "set_extra_state",
    "share_memory",
    "state_dict",
    "to",
    "train",
    "type",
    "zero_grad",
}
def _make_fail(name):
    """Build a stand-in method that always raises for an unsupported nn.Module API."""

    def fail(self, *args, **kwargs):
        message = name + " is not supported on ScriptModules"
        raise RuntimeError(message)

    return fail
# Replace every public nn.Module method that RecursiveScriptModule does not
# implement itself, and that is not allow-listed, with a stub that raises.
for name, method in _get_methods(torch.nn.Module):
    if name.startswith("__"):
        continue
    if (
        name not in RecursiveScriptModule.__dict__
        and name not in _compiled_methods_allowlist
    ):
        # Install the stub under `name`, the attribute name that was just
        # checked, rather than `method.__name__`. The two differ for aliased
        # attributes (`alias = original`), and using `__name__` would overwrite
        # `original` even when `original` is allow-listed.
        setattr(RecursiveScriptModule, name, _make_fail(name))
else:
# TODO: make sure that disabling TorchScript actually works with these stubs.
class RecursiveScriptClass(object):  # type: ignore[no-redef]
    """Placeholder used when TorchScript is disabled; it holds no state."""

    def __init__(self):
        super(RecursiveScriptClass, self).__init__()
class ScriptModule(torch.nn.Module):  # type: ignore[no-redef]
    """Plain ``nn.Module`` stand-in used when TorchScript is disabled.

    ``arg`` is accepted and ignored.
    """

    def __init__(self, arg=None):
        super().__init__()
class RecursiveScriptModule(ScriptModule):  # type: ignore[no-redef]
    """Placeholder used when TorchScript is disabled; ``arg`` is accepted and ignored."""

    def __init__(self, arg=None):
        super(RecursiveScriptModule, self).__init__()
def call_prepare_scriptable_func_impl(obj, memo):
    """Recursively apply ``__prepare_scriptable__`` to ``obj`` and its submodules.

    Args:
        obj: Any value. Values that are not ``torch.nn.Module`` are returned unchanged.
        memo: Dict mapping ``id()`` of already-visited modules to their prepared
            replacement. It is shared across the recursion so each module is
            prepared once and cycles in the module hierarchy terminate.

    Returns:
        ``obj.__prepare_scriptable__()`` if that method exists, else ``obj``,
        with its submodules prepared in turn.
    """
    if not isinstance(obj, torch.nn.Module):
        return obj

    obj_id = id(obj)

    # If obj_id is in memo, obj has already been prepared or is being
    # prepared in another call up the stack.
    if obj_id in memo:
        return memo[obj_id]

    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore[operator]
    # Record obj in memo to avoid infinite recursion in the case of cycles in the module
    # hierarchy when recursing below.
    memo[obj_id] = obj

    new_obj_dict = {}

    for name, sub_module in obj.__dict__.items():
        if name == '_modules':
            # Registered children: replaced in place inside the _modules dict
            # (overwriting existing keys while iterating is safe).
            for k, v in sub_module.items():
                sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
            new_obj_dict[name] = sub_module
        elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
            new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
        else:
            new_obj_dict[name] = sub_module

    # Write each entry back under its own key. Using the stale loop variable
    # `name` here would repeatedly overwrite only the last attribute and drop
    # every other prepared submodule.
    for k, v in new_obj_dict.items():
        obj.__dict__[k] = v

    return obj
def call_prepare_scriptable_func(obj):
    """Entry point: prepare ``obj`` (and its submodules) with a fresh memo table."""
    # Fresh memo per top-level call: maps id(original module) -> prepared module.
    return call_prepare_scriptable_func_impl(obj, {})
def create_script_dict(obj):
    """
    Build a ``torch._C.ScriptDict`` initialised with the contents of ``obj``.

    Args:
        obj (dict): Python dictionary whose data seeds the returned ``ScriptDict``.

    Returns:
        A ``torch._C.ScriptDict`` holding the same data as ``obj``. It can be
        passed between Python and TorchScript with reference semantics and
        zero copy overhead.
    """
    script_dict = torch._C.ScriptDict(obj)  # type: ignore[attr-defined]
    return script_dict
def create_script_list(obj, type_hint=None):
    """
    Create a ``torch._C.ScriptList`` instance with the data from ``obj``.

    Args:
        obj (list): The Python list that is used to initialize the ``ScriptList``
                    returned by this function.
        type_hint: Ignored by this implementation; it is not forwarded to
                   ``torch._C.ScriptList``.

    Returns:
        An instance of ``torch._C.ScriptList`` that has the same data as ``obj``
        and can be passed between Python and TorchScript with reference semantics and
        zero copy overhead.
    """
    return torch._C.ScriptList(obj)  # type: ignore[attr-defined]
def script(obj, optimize=None, _frames_up=0, _rcb=None,
           example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
    r"""
    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    Scripting a dictionary or list copies the data inside it into a TorchScript instance that can be
    subsequently passed by reference between Python and TorchScript with zero copy overhead.

    ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists
    and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Objects that are already scripted (a :class:`ScriptModule`, a :class:`ScriptFunction`, or an
    instance of a scripted class) are returned unchanged, as is any ``obj`` when the JIT is disabled.

    Args:
        obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type,
            dictionary, or list to compile.
        optimize: Deprecated and ignored. Passing anything other than ``None`` only emits a
            warning; use ``with torch.jit.optimized_execution()`` instead.
        example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
            to annotate the arguments for a function or ``nn.Module``. A list of argument tuples
            is used to call ``obj`` itself. A dict maps each callable to the list of argument
            tuples to call it with. This requires MonkeyType; if it is not installed, a warning
            is emitted and the example inputs are not run.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then
        ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``,
        then ``script`` returns an instance of `torch._C.ScriptList`. If ``obj`` is a class,
        it is compiled and registered, and the class itself is returned (``enum.Enum``
        subclasses are returned as-is without compilation).

    Raises:
        RuntimeError: If ``obj`` is an ``nn.Module`` subclass (rather than an instance), a class
            that is not new-style or that uses inheritance, or an overloaded function.
        ValueError: If MonkeyType is installed and ``example_inputs`` is neither a list nor a dict.

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting a function using example_inputs**
        Example inputs can be used to annotate a function's arguments.

        Example (annotating a function before scripting):

        .. testcode::

            import torch

            def test_sum(a, b):
                return a + b

            # Annotate the arguments to be int
            scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

            print(type(scripted_fn))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(scripted_fn.code)

            # Call the function using the TorchScript interpreter
            scripted_fn(20, 100)

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super(MyModule, self).__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()
                    # torch.jit.trace produces ScriptModules for conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))

        Example (annotating forward of nn.Module using example_inputs)::

            import torch
            import torch.nn as nn
            from typing import List, NamedTuple

            class MyModule(NamedTuple):
                result: List[int]

            class TestNNModule(torch.nn.Module):
                def forward(self, a) -> MyModule:
                    result = MyModule(result=a)
                    return result

            pdt_model = TestNNModule()

            # Runs the pdt_model in eager mode with the inputs provided and annotates the arguments of forward
            scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

            # Run the scripted_model with actual inputs
            print(scripted_model([20]))
    """
    global type_trace_db
    if not _enabled:
        return obj

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution()` instead"
        )

    # No-op for modules, functions, class instances that are already scripted
    if isinstance(obj, RecursiveScriptClass):
        return obj
    if isinstance(obj, ScriptModule):
        return obj
    if isinstance(obj, ScriptFunction):
        return obj

    if example_inputs:
        # If MonkeyType is installed, enable profile directed type annotation.
        # Check if example_inputs are defined and generate call traces
        # for the method by running the eager mode version of the method with
        # the provided example inputs. This logs all the traces in type_trace_db.
        type_trace_db = JitTypeTraceStore()
        if monkeytype_trace:
            monkeytype_config = JitTypeTraceConfig(type_trace_db)
            with monkeytype_trace(monkeytype_config):
                if isinstance(example_inputs, Dict):
                    # If the obj is an nn.Module or a class, then each method is
                    # executed with the arguments provided in the example inputs.
                    # example inputs here will be of type Dict(class.method, (arguments))
                    # This is used to infer type annotations for those methods
                    # which are not called directly under the hood of monkeytype.
                    for module, example_input in example_inputs.items():
                        for example in example_input:
                            module(*example)
                elif isinstance(example_inputs, List):
                    for examples in example_inputs:
                        obj(*examples)
                else:
                    raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
                                     " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
        else:
            warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
                          "to enable Profile-Directed Typing in TorchScript. Refer to "
                          "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")

    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile
        )

    if isinstance(obj, dict):
        return create_script_dict(obj)
    if isinstance(obj, list):
        return create_script_list(obj)

    if inspect.isclass(obj):
        qualified_name = _qualified_name(obj)
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError(
                "Type '{}' cannot be compiled since it inherits"
                " from nn.Module,"
                " pass an instance instead".format(obj)
            )

        # Enums are automatically usable in TorchScript, explicitly scripting
        # is not necessary, but not harmful either.
        if issubclass(obj, enum.Enum):
            return obj

        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'."
            )
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes does not support inheritance yet. "
                "Please directly inherit from 'object'."
            )
        if _rcb is None:
            # +1 accounts for this function's own frame when resolving names
            # in the caller's scope.
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj
    elif inspect.isfunction(obj) or inspect.ismethod(obj):
        qualified_name = _qualified_name(obj)
        # This is a decorated fn, and we need to get the underlying fn and its rcb
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

        _check_directly_compile_overloaded(obj)
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            return maybe_already_compiled_fn
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        fn = torch._C._jit_script_compile(
            qualified_name, ast, _rcb, get_default_args(obj)
        )
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        _set_jit_function_cache(obj, fn)
        return fn
    else:
        return torch.jit._recursive.create_script_class(obj)
# overloads are registered in _jit_internal and compiled here so that _overload
# can be used in nn/functional.py without an import cycle
def _check_overload_defaults(impl_defaults, overload_defaults, loc):
    """Verify that every default declared on an overload equals the implementation's default.

    Overload defaults have no runtime effect, so a default that is missing on
    the implementation, or that differs from it, is reported as an error.

    Args:
        impl_defaults: Mapping of parameter name -> default value for the
            implementation (presumably as returned by ``get_default_args``).
        overload_defaults: Mapping of parameter name -> default value declared
            on the overload.
        loc: Source range attached to the raised error.

    Raises:
        torch.jit.frontend.FrontendError: For the first offending parameter.
    """
    for param, declared_default in overload_defaults.items():
        if param in impl_defaults and not (impl_defaults[param] != declared_default):
            continue
        raise torch.jit.frontend.FrontendError(
            loc,
            "Default parameters on overloads do not affect the runtime so they "
            "must equal to the default parameter on the implementation function. Found on "
            "parameter {name}".format(name=param),
        )
def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
    """Compile ``impl_fn``'s body against the declaration and signature of ``overload_fn``.

    Args:
        overload_fn: The overload declaration; only its declaration, signature
            and defaults are used.
        qual_name: Qualified name passed to the compiler for this overload.
        impl_fn: The implementation whose body is compiled. Names are resolved
            from its closure.

    Returns:
        The object returned by ``torch._C._jit_script_compile_overload``.

    Raises:
        torch.jit.frontend.FrontendError: If a default declared on the overload
            does not match the implementation's default.
    """
    decl = get_jit_def(overload_fn, overload_fn.__name__).decl()
    signature = torch.jit.annotations.get_signature(
        overload_fn, None, None, inspect.ismethod(overload_fn)
    )
    body_ast = get_jit_def(impl_fn, impl_fn.__name__)
    declared_defaults = get_default_args(overload_fn)
    impl_defaults = get_default_args(impl_fn)
    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)

    _check_overload_defaults(impl_defaults, declared_defaults, decl.range())

    return torch._C._jit_script_compile_overload(
        qual_name,
        decl,
        body_ast,
        resolution_cb,
        impl_defaults,
        signature,
    )
def _get_overloads(obj):
    """Return the compiled overloads of function ``obj``, compiling any pending ones first.

    Overload declarations registered in ``_jit_internal`` under ``obj``'s
    qualified name are compiled with ``obj`` as the implementation. They are
    appended after any already-cached compiled overloads, and the combined list
    is cached. The pending declarations are then cleared from ``_jit_internal``.

    Returns:
        The list of compiled overloads. If there are no pending declarations,
        it returns whatever the cache lookup yields (possibly ``None``).

    Raises:
        RuntimeError: If ``obj`` itself is among the pending overload
            declarations (reported via
            ``get_overload_no_implementation_error_message``).
    """
    cached_fns = _try_get_jit_cached_overloads(obj)
    qual_name = _qualified_name(obj)
    pending = _jit_internal._get_fn_overloads(qual_name)

    if pending is None:
        return cached_fns

    if obj in pending:
        raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
            'function', obj))

    fresh_fns = [
        _compile_function_with_overload(overload_fn, qual_name, obj)
        for overload_fn in pending
    ]
    all_fns = (cached_fns + fresh_fns) if cached_fns else fresh_fns

    # Cache the compiled result and drop the bookkeeping used to compile it.
    _set_jit_overload_cache(obj, all_fns)
    _jit_internal._clear_fn_overloads(qual_name)
    return all_fns
def _check_directly_compile_overloaded(obj):
    """Raise if ``obj`` is an overloaded function, which cannot be scripted on its own.

    An error is raised when overloads are either pending registration in
    ``_jit_internal`` or already compiled and cached for ``obj``.

    Raises:
        RuntimeError: If ``obj`` has overloads.
    """
    qual_name = _qualified_name(obj)
    pending_overloads = _jit_internal._get_fn_overloads(qual_name)
    if not (pending_overloads or _try_get_jit_cached_overloads(obj)):
        return
    raise RuntimeError(
        "Function {} cannot be directly compiled because it"
        " is overloaded. It must be used in a context of a function"
        " where its inputs can determine which overload to call.".format(qual_name)
    )
def interface(obj):
    """Decorator that compiles class ``obj`` as a TorchScript interface type.

    ``obj`` must either inherit directly from ``object``, or (for a module
    interface) have the MRO ``obj -> torch.nn.Module -> object``. The compiled
    interface name is stored on ``obj.__torch_script_interface__``, and ``obj``
    itself is returned.

    Raises:
        RuntimeError: If ``obj`` is not a class, is not a new-style class, or
            uses any other inheritance.
    """
    if not inspect.isclass(obj):
        raise RuntimeError("interface must be applied to a class")
    if not _is_new_style_class(obj):
        raise RuntimeError("TorchScript interfaces must inherit from 'object'")

    # A module interface has exactly the MRO:
    #   User module -> torch.nn.modules.module.Module -> object
    is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3
    if len(obj.mro()) > 2 and not is_module_interface:
        raise RuntimeError(
            "TorchScript interface does not support inheritance yet. "
            "Please directly inherit from 'object' or 'nn.Module'."
        )

    qualified_name = _qualified_name(obj)
    # Must be called directly from this function: the frame depth (1)
    # presumably selects the decorator's caller for name resolution.
    rcb = _jit_internal.createResolutionCallbackFromFrame(1)
    # For an nn.Module subclass a module interface type is generated, which only
    # compiles the user-provided methods; otherwise a class interface type.
    class_ast = get_jit_class_def(obj, obj.__name__)
    obj.__torch_script_interface__ = torch._C._jit_script_interface_compile(
        qualified_name, class_ast, rcb, is_module_interface
    )
    return obj
def _recursive_compile_class(obj, loc):
    """Compile and register the TorchScript class ``obj`` that was referenced during compilation.

    Args:
        obj: The Python class to compile.
        loc: Source location passed to ``torch._C.CallStack`` together with the
            class's qualified name (presumably the site that referenced the
            class -- confirm against callers).

    Returns:
        Whatever ``_compile_and_register_class`` returns for ``obj``.
    """
    _qual_name = _qualified_name(obj)
    # We're starting a new compilation, so update the error call stack in
    # case it fails
    # NOTE(review): ``error_stack`` is never read, but the binding keeps the
    # CallStack object alive until this function returns; it presumably pops
    # its entry when destroyed, so do not remove or inline this assignment.
    error_stack = torch._C.CallStack(_qual_name, loc)
    rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
    return _compile_and_register_class(obj, rcb, _qual_name)
# Re-export the C++ ``CompilationUnit`` binding from this module; ``set_module``
# presumably rewrites its ``__module__`` so it presents as ``torch.jit.CompilationUnit``.
CompilationUnit = torch._C.CompilationUnit
set_module(CompilationUnit, "torch.jit")
def pad(s: str, padding: int, offset: int = 0, char: str = ' ') -> str:
    """Right-align ``s`` in a field ``padding`` characters wide, then prepend ``offset`` more fill.

    Args:
        s: The text to align.
        padding: Total field width for ``s``. If ``s`` is already at least this
            wide, no field fill is added (only ``offset``).
        offset: Extra fill characters always prepended.
        char: The fill string.

    Returns:
        ``char`` repeated ``max(padding - len(s), 0) + offset`` times, followed by ``s``.
    """
    # Clamp at zero: a string wider than the field must not gain ``padding``
    # extra characters (the previous code prepended the full ``padding`` then).
    fill = max(padding - len(s), 0) + offset
    return char * fill + s
class _ScriptProfileColumn:
    """One column of a script-profile table: a header plus cells keyed by source line number."""

    def __init__(self, header: str, alignment: int = 4, offset: int = 0):
        # Title rendered in the header row.
        self.header = header
        # Column width is rounded to a multiple of this; 0 disables width padding.
        self.alignment = alignment
        # Extra fill characters prepended to the header and every cell.
        self.offset = offset
        # Cell values keyed by source line number.
        self.rows: Dict[int, Any] = {}

    def add_row(self, lineno: int, value: Any):
        """Set the cell for ``lineno`` to ``value``, replacing any previous value."""
        self.rows[lineno] = value

    def materialize(self):
        """Render the header and every cell as right-aligned strings of a common width.

        Returns:
            ``(header_text, cells)`` where ``cells`` is a list of
            ``(lineno, text)`` tuples in insertion order.
        """
        rendered = [(lineno, str(value)) for lineno, value in self.rows.items()]
        widest = max([len(self.header)] + [len(text) for _, text in rendered])

        width = 0
        if self.alignment > 0:
            # Smallest multiple of ``alignment`` strictly greater than ``widest``.
            width = widest + self.alignment
            width -= width % self.alignment

        cells = [(lineno, pad(text, width, self.offset)) for lineno, text in rendered]
        return pad(self.header, width, self.offset), cells
class _ScriptProfileTable:
    """Lays out ``_ScriptProfileColumn`` objects side by side over a list of source line numbers."""

    def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
        # Columns in left-to-right display order.
        self.cols = cols
        # Line numbers to emit one row for, in order.
        self.source_range = source_range

    def dump_string(self):
        """Render the table: a header line, an ``=`` rule of the same width, then one line per source line.

        A column without a cell for some line contributes spaces as wide as its header.
        """
        materialized = [col.materialize() for col in self.cols]
        header_line = ''.join(header for header, _ in materialized)
        lookup = [(header, dict(cells)) for header, cells in materialized]

        lines: List[str] = [header_line, pad('', len(header_line), 0, '=')]
        for lineno in self.source_range:
            pieces = []
            for header, cells in lookup:
                text = cells.get(lineno)
                pieces.append(pad('', len(header)) if text is None else text)
            lines.append(''.join(pieces))
        return '\n'.join(lines)
class _ScriptProfile:
    """Line-level profiler for TorchScript code, wrapping ``classes.profiling._ScriptProfile``."""

    def __init__(self):
        self.profile = classes.profiling._ScriptProfile()

    def enable(self):
        """Forward ``enable()`` to the underlying profiler object."""
        self.profile.enable()

    def disable(self):
        """Forward ``disable()`` to the underlying profiler object."""
        self.profile.disable()

    @staticmethod
    def _dedent_source(lines: List[str]) -> List[str]:
        """Remove the common leading-space indentation from ``lines``.

        Whitespace-only lines are ignored when computing the common indent, so
        an empty line inside a function body does not disable dedenting. An
        empty ``lines`` list is returned as-is instead of raising ``ValueError``
        from ``min``.
        """
        indents = [len(line) - len(line.lstrip(' ')) for line in lines if line.strip()]
        dedent = min(indents, default=0)
        return [line[dedent:] for line in lines]

    def dump_string(self) -> str:
        """Render the collected statistics as text tables, one per profiled source.

        Each table has the columns ``Line #``, ``Hits``, ``Time (ns)`` and
        ``Line Contents``. Lines without recorded stats get blank Hits/Time
        cells. Tables are separated by a blank line.
        """
        outputs: List[str] = []
        for source_stats in self.profile._dump_stats():
            source_ref = source_stats.source()
            source_lines = self._dedent_source(source_ref.text().splitlines())
            start_line = source_ref.starting_lineno()
            end_line = start_line + len(source_lines)
            source_range = range(start_line, end_line)
            lineno = _ScriptProfileColumn("Line #")
            hits = _ScriptProfileColumn("Hits")
            time_ns = _ScriptProfileColumn("Time (ns)")
            line_contents = _ScriptProfileColumn("Line Contents", 0, 1)
            stats = source_stats.line_map()
            for line in source_range:
                lineno.add_row(line, line)
                line_contents.add_row(line, source_lines[line - start_line])
                stat = stats.get(line)
                if stat is not None:
                    hits.add_row(line, stat.count())
                    time_ns.add_row(line, stat.duration_ns())

            table = _ScriptProfileTable([lineno, hits, time_ns, line_contents], list(source_range))
            outputs.append(table.dump_string())
        return '\n\n'.join(outputs)

    def dump(self):
        """Print ``dump_string()`` to stdout."""
        print(self.dump_string())
def _unwrap_optional(x):
    """Return ``x`` unchanged, failing if it is ``None``.

    This is the Python implementation that is registered as the builtin
    ``aten::_unwrap_optional`` via ``_register_builtin``.

    Raises:
        AssertionError: If ``x`` is ``None``. It is raised explicitly rather than
            via ``assert`` so the check is not stripped under ``python -O``.
    """
    if x is None:
        raise AssertionError("Unwrapping null optional")
    return x
# Associate these Python helpers with TorchScript builtin operator names, presumably so
# that scripted code calling them resolves to the named aten op instead of compiling
# the Python body. Note that all three has_torch_function* variants share the same op
# name "aten::has_torch_function".
_register_builtin(_unwrap_optional, "aten::_unwrap_optional")
_register_builtin(_jit_internal.is_scripting, "aten::is_scripting")
_register_builtin(has_torch_function, "aten::has_torch_function")
_register_builtin(has_torch_function_unary, "aten::has_torch_function")
_register_builtin(has_torch_function_variadic, "aten::has_torch_function")