from ._ops import OpOverload
from typing import Set
import traceback
import torch
import weakref

__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
]

# Module-wide registry of every kernel registered through a Library in this
# module. Each entry is a string built as
#     namespace + "/" + op_name + "/" + dispatch_key
# i.e. one entry per (namespace, operator, DispatchKey) combination.
# Keeping this registry lets two libraries be prevented from overriding the
# exact same functionality, so no library ends up calling into a kernel it
# never intended to call.
_impls: Set[str] = set()

# Namespaces that cannot be used to create "DEF" or "FRAGMENT" libraries.
# "prim" is reserved by the TorchScript interpreter.
_reserved_namespaces = ["prim"]
def fallthrough_kernel():
    """
    Marker function for requesting a fallthrough registration.

    Hand this function to ``Library.impl`` as the kernel argument in order to
    register a fallthrough. It is only a placeholder and is never meant to
    run as a real kernel.

    Raises:
        NotImplementedError: always, if it is ever invoked directly.
    """
    message = "fallthrough_kernel() should never be called."
    raise NotImplementedError(message)
class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch keyname if they only want to register
    kernels corresponding to only one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: one of "DEF", "IMPL", "FRAGMENT" (required; there is no default)
        dispatch_key: PyTorch dispatch key (default: "")

    Raises:
        ValueError: if ``kind`` is not one of the supported values, or if ``ns`` is a
            reserved namespace and ``kind`` is "DEF" or "FRAGMENT".
    """
    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
            # Build one message string: passing several positional arguments to
            # ValueError would make str(exc) render as a tuple.
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'):
            raise ValueError(f"{ns} is a reserved namespace. Please try creating a library with another name.")

        # Source location handed to the dispatcher together with the library.
        # With limit=3, entry [0] is the oldest of the three most recent frames,
        # i.e. two frames above this constructor.
        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        # Keys (same format as the module-level _impls) registered by this library.
        self._op_impls: Set[str] = set()
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to setup the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). finalizers help the
        # situation because it lets us capture references and keeps them alive
        weakref.finalize(self, _del_library, _impls, self._op_impls)

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis=""):
        r'''Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").
        Returns:
            name of the operator as inferred from the schema.

        Raises:
            RuntimeError: if ``alias_analysis`` is not "", "FROM_SCHEMA" or "CONSERVATIVE".

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY)
            >>> my_lib = Library("foo", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        '''
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        return self.m.define(schema, alias_analysis)

    def impl(self, op_name, fn, dispatch_key=''):
        r'''Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Raises:
            TypeError: if ``fn`` is not callable.
            RuntimeError: if ``op_name`` is neither a string nor an OpOverload; if a kernel for the same
                namespace, operator and dispatch key was already registered through this module; or if a
                "Meta" kernel is requested for an operator that has a CompositeImplicitAutograd kernel.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        '''
        if not callable(fn):
            raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        # NOTE(review): if dispatch_key is still "" here, the key ends in "/" even though the
        # kernel is registered below under "CompositeImplicitAutograd"; an explicit
        # "CompositeImplicitAutograd" registration therefore produces a different key.
        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(name.split("::")[-1], dispatch_key, self.ns))

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if '::' not in dispatcher_op_name:
                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into.")

        # An empty dispatch key means "register as CompositeImplicitAutograd".
        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)

        _impls.add(key)
        self._op_impls.add(key)
def _del_library(captured_impls, op_impls):
    """Finalizer callback for a Library instance.

    Removes the keys a library registered (``op_impls``) from the shared
    registry ``captured_impls``. For sets, ``-=`` updates the registry in place.
    """
    captured_impls -= op_impls


# decorator to register python functions for library ops
# Note: this decorator API should remain consistent with `Library.impl` API
def impl(lib, name, dispatch_key=""):
    """Decorator form of ``lib.impl(name, fn, dispatch_key)``.

    The decorated function is registered and then returned unchanged.
    """
    def decorator(kernel):
        lib.impl(name, kernel, dispatch_key)
        return kernel

    return decorator


def define(lib, schema, alias_analysis=""):
    """Decorator that defines an operator from ``schema`` and registers the
    decorated function as its implementation.

    The operator name returned by ``lib.define`` is passed to ``lib.impl``
    without an explicit dispatch key. The decorated function is returned
    unchanged.
    """
    def decorator(kernel):
        op_name = lib.define(schema, alias_analysis)
        lib.impl(op_name, kernel)
        return kernel

    return decorator
