from ._ops import OpOverload
from typing import Set
import traceback
import torch

__all__ = ['Library', 'impl', 'define']

# Module-wide registry of kernels that have been registered from Python.
# Each entry is the string `namespace + "/" + op_name + "/" + dispatch_key`.
# It exists so that two libraries cannot both override the exact same
# (namespace, operator, DispatchKey) combination, which would otherwise let a
# library call into a kernel it never intended to call.
_impls: Set[str] = set()
class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.

    A user can optionally pass in a dispatch keyname if they only want to
    register kernels corresponding to only one specific dispatch key.

    Args:
        ns: library name
        kind: "DEF" or "IMPL" (required; any other value raises ValueError)
        dispatch_key: PyTorch dispatch key (default: "")
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind != "IMPL" and kind != "DEF":
            # Build a single message string; passing `kind` as a second
            # positional argument would make the exception print as a tuple.
            raise ValueError("Unsupported kind: {}".format(kind))

        # Record a source location (file + line) taken from the call stack and
        # hand it to the dispatcher along with the library description.
        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        # Keys (same format as the module-level `_impls`) registered through
        # this particular Library instance.
        self._op_impls = set()
        self.kind = kind
        self.dispatch_key = dispatch_key

    def __repr__(self):
        return "Library(kind={}, ns={}, dispatch_key={})".format(
            self.kind, self.ns, self.dispatch_key
        )

    def impl(self, op_name, fn, dispatch_key=''):
        """
        Register ``fn`` as the kernel for ``op_name`` under ``dispatch_key``.

        Args:
            op_name: operator name as a string, or an ``OpOverload`` object
                (its schema name plus ``.overload_name`` is used when the
                overload name is non-empty)
            fn: the function to register as the kernel
            dispatch_key: dispatch key to register for; when left as ``''``
                the dispatch key this library was created with is used

        Raises:
            RuntimeError: if ``op_name`` is neither a string nor an
                ``OpOverload``, or if a kernel for the same
                namespace/operator/dispatch-key combination was already
                registered from Python.
        """
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        # Strip any "ns::" qualifier so the key always uses the bare op name.
        short_name = name.split("::")[-1]
        key = self.ns + "/" + short_name + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(short_name, dispatch_key, self.ns))

        # Only record the key once the underlying registration has succeeded.
        self.m.impl(name, dispatch_key, fn)
        _impls.add(key)
        self._op_impls.add(key)

    def define(self, schema, alias_analysis=""):
        '''
        Takes a schema to define a new operator.

        Also, optionally takes `alias_analysis` argument to indicate if the
        aliasing properties of the arguments can be inferred from the schema
        (default behavior) or not ("CONSERVATIVE").

        Returns the name of the operator as inferred from the schema.

        Raises:
            RuntimeError: if `alias_analysis` is not one of "", "FROM_SCHEMA"
                or "CONSERVATIVE".
        '''
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError("Invalid alias_analysis type")
        return self.m.define(schema, alias_analysis)
# Decorators to register python functions for library ops.
# Note: the `impl` decorator API should remain consistent with the `Library.impl` API.
def impl(lib, name, dispatch_key=""):
    """
    Decorator form of ``lib.impl``.

    Calls ``lib.impl(name, f, dispatch_key)`` on the decorated function ``f``
    and returns ``f`` unchanged, so the decorated name stays bound to the
    function instead of being rebound to ``None``.

    Args:
        lib: object exposing ``impl(name, fn, dispatch_key)``
        name: operator name (or whatever ``lib.impl`` accepts as its first argument)
        dispatch_key: dispatch key forwarded to ``lib.impl`` (default: "")
    """
    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f
    return wrap


def define(lib, schema, alias_analysis=""):
    """
    Decorator that defines an operator on ``lib`` and registers its kernel.

    Calls ``lib.define(schema, alias_analysis)``, passes the value it returns
    as the operator name to ``lib.impl`` together with the decorated function
    ``f``, and returns ``f`` unchanged.

    Args:
        lib: object exposing ``define(schema, alias_analysis)`` and ``impl(name, fn)``
        schema: operator schema forwarded to ``lib.define``
        alias_analysis: alias-analysis mode forwarded to ``lib.define`` (default: "")
    """
    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f
    return wrap
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.