from ._ops import OpOverload
from typing import Set
import traceback
import torch
import weakref

__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
]

# Registry of every (namespace, operator, DispatchKey) combination for which a
# kernel has been registered from Python. Each entry is a string of the form
# `namespace + "/" + op_name + "/" + dispatch_key`.
# It exists so that two libraries cannot override the exact same functionality,
# which would otherwise let a library call into a kernel it never intended to.
_impls: Set[str] = set()

# `prim` is reserved by the TorchScript interpreter.
_reserved_namespaces = ["prim"]
def fallthrough_kernel():
    """
    Placeholder to hand to ``Library.impl`` when a fallthrough should be registered.

    It is only meant to be passed around as a marker; invoking it always
    raises ``NotImplementedError``.
    """
    message = "fallthrough_kernel() should never be called."
    raise NotImplementedError(message)
class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch keyname if they only want to register
    kernels corresponding to only one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds, or if ``ns`` is a
            reserved namespace and ``kind`` is "DEF" or "FRAGMENT".
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
            # Build a single message string; passing several positional args to
            # ValueError would render as a tuple instead of a readable sentence.
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and kind in ('DEF', 'FRAGMENT'):
            raise ValueError(f"{ns} is a reserved namespace. Please try creating a library with another name.")

        # Record the oldest frame among the (at most) three innermost ones so the
        # dispatcher is told which file/line created this library.
        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        # Keys (same format as the module-level `_impls`) registered through this instance.
        self._op_impls: Set[str] = set()
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to setup the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). finalizers help the
        # situation because it lets us capture references and keeps them alive
        weakref.finalize(self, _del_library, _impls, self._op_impls)

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis=""):
        r'''Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").
        Returns:
            name of the operator as inferred from the schema.

        Raises:
            RuntimeError: if ``alias_analysis`` is not "", "FROM_SCHEMA" or "CONSERVATIVE".

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY)
            >>> my_lib = Library("foo", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        '''
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        return self.m.define(schema, alias_analysis)

    def impl(self, op_name, fn, dispatch_key=''):
        r'''Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Raises:
            TypeError: if ``fn`` is not callable.
            RuntimeError: if ``op_name`` is neither a string nor an OpOverload, if a kernel
                was already registered from Python for the same namespace/operator/dispatch key,
                or if a "Meta" kernel is registered for an operator that has a
                CompositeImplicitAutograd kernel.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        '''
        if not callable(fn):
            raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        # Strip any "namespace::" qualifier so the registry key is always built
        # from this library's namespace plus the bare operator name.
        unqualified_name = name.split("::")[-1]
        key = self.ns + "/" + unqualified_name + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(unqualified_name, dispatch_key, self.ns))

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if '::' not in dispatcher_op_name:
                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into.")

        # An empty dispatch key is registered as CompositeImplicitAutograd.
        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)

        # Only record the key once the registration call has returned.
        _impls.add(key)
        self._op_impls.add(key)
def _del_library(captured_impls, op_impls):
    """Drop every key in ``op_impls`` from ``captured_impls`` (in place for sets)."""
    captured_impls -= op_impls


# Decorator form of operator registration for library ops.
# Note: keep this decorator API consistent with the `Library.impl` API.
def impl(lib, name, dispatch_key=""):
    """Return a decorator that registers the decorated function via ``lib.impl``."""

    def register(fn):
        lib.impl(name, fn, dispatch_key)
        return fn

    return register


def define(lib, schema, alias_analysis=""):
    """Return a decorator that defines ``schema`` on ``lib`` and registers the
    decorated function as the implementation of the resulting operator."""

    def register(fn):
        op_name = lib.define(schema, alias_analysis)
        lib.impl(op_name, fn)
        return fn

    return register
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.