
Program Listing for File library.h

Return to documentation for file (torch/library.h)

  // m is a torch::Library; methods on it will define
  // operators in the myops namespace
  m.def("add", add_impl);
  // m is a torch::Library; methods on it will define
  // CPU implementations of operators in the myops namespace.
  // It is NOT valid to call torch::Library::def()
  // in this context.
  m.impl("add", add_cpu_impl);
// You must define all of the operators for this library in
// this namespace.
TORCH_LIBRARY(myops, m) {
  // Define an operator with exactly one implementation for all backends.
  m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);

  // Define a schema for an operator, but provide no implementation
  // (use this syntax if you want to use the dispatcher)
  m.def("mul(Tensor self, Tensor other) -> Tensor");

  // Provide an implementation for a defined operator (you can
  // provide multiple; one per backend).  The dispatcher takes care of
  // calling the correct implementation depending on if we get a CPU
  // tensor or a CUDA tensor
  m.impl("mul", torch::kCPU, &mul_cpu_impl);
  m.impl("mul", torch::kCUDA, &mul_cuda_impl);

// Define implementations for operators for a non-standard backend,
// e.g., XLA (valid values are entries of DispatchKey).  This can
// be used to define operators in a different file than the initial
// TORCH_LIBRARY definition (e.g., if it is in an external library)
  m.impl("mul", &mul_xla_impl);
#pragma once

#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>

// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/enum_tag.h>
#include <ATen/core/op_registration/op_registration.h>

namespace torch {

#if defined C10_MOBILE
// Empty tag type, available only in C10_MOBILE builds. Library::impl passes a
// NoInferSchemaTag() as an extra CppFunction constructor argument in its
// C10_MOBILE branch. The conditional is closed right after the struct because
// the declarations that follow are used unconditionally.
struct NoInferSchemaTag {};
#endif


// Mode flag accepted by the Library::def / Library::impl entry points (their
// `rv` parameter defaults to REGISTER). Exists for the multipy/torchdeploy
// use case.
enum class _RegisterOrVerify {
  REGISTER,
  VERIFY,
};

// Forward declaration of the class template returned by Library::class_().
template <class CurClass> class class_;


class TORCH_API CppFunction final {
  // TODO: This is morally the same thing as KernelRegistrationConfig, but it's
  // opaque to the user.

  // Constructor taking a pointer to a function; the kernel is built with
  // c10::KernelFunction::makeFromUnboxedRuntimeFunction(f) and the debug
  // string starts empty.
  // NOTE(review): this listing looks truncated here -- the constraint that
  // ends in `std::nullptr_t>` has lost its opening line, and no initializers
  // for cpp_signature_/schema_ are visible; confirm against the upstream
  // header.
  template <typename Func>
  explicit CppFunction(
      Func* f,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        debug_() {}

  // Constructor taking a FuncPtr object (a type exposing a nested FuncType);
  // the kernel is built with c10::KernelFunction::makeFromUnboxedFunction(f).
  // NOTE(review): the two lines between the func_ initializer and debug_()
  // look like remnants of truncated signature/schema initializers, and the
  // constraint ending in `std::nullptr_t>` has lost its opening line; confirm
  // against the upstream header.
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
                typename FuncPtr::FuncType>()),
        debug_() {}

  // Constructor taking a callable by forwarding reference; the kernel is built
  // with c10::KernelFunction::makeFromUnboxedLambda.
  // NOTE(review): the argument list of makeFromUnboxedLambda and the
  // constraint ending in `std::nullptr_t>` are cut off in this listing;
  // confirm against the upstream header.
  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
        debug_() {}

// C10_MOBILE-only constructor overloads mirroring the three constructors
// above (function pointer, FuncPtr object, lambda).
// NOTE(review): Library::impl calls these with an extra NoInferSchemaTag
// argument in its C10_MOBILE branch, but no such parameter is visible below,
// and neither is the `#endif` closing this conditional -- the listing appears
// truncated; confirm against the upstream header.
#if defined C10_MOBILE
  template <typename Func>
  explicit CppFunction(
      Func* f,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        // TODO: Don't go through WrapRuntimeKernelFunctor
        debug_() {}

  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
        // TODO: Don't go through WrapRuntimeKernelFunctor
        debug_() {}

  template <typename Lambda>
  explicit CppFunction(
      Lambda&& f,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedLambda(
        // TODO: Don't go through WrapRuntimeKernelFunctor
        debug_() {}


  // CppFunction is move-only: declaring the move operations suppresses the
  // implicit copy operations (and the unique_ptr member is not copyable).
  CppFunction(CppFunction&&) noexcept = default;

  // NOTE(review): unlike the move constructor, the move assignment is not
  // marked noexcept -- confirm whether that asymmetry is intentional.
  CppFunction& operator=(CppFunction&&) = default;

  // Builds a CppFunction from a boxed kernel. No C++ signature and no schema
  // are recorded for it (std::nullopt / nullptr below).
  // NOTE(review): the constructor being called takes three arguments (kernel,
  // signature, schema) but only the last two are visible here, and the closing
  // brace is absent -- the listing appears truncated; confirm against the
  // upstream header.
  static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) {
    return CppFunction(
        /* cpp_signature */ std::nullopt, // not known for boxed functions
        /* schema */ nullptr);

  // Returns a CppFunction wrapping the kernel produced by
  // c10::BoxedKernel::makeFallthrough().
  static CppFunction makeFallthrough() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough());
  }

  // Returns a CppFunction wrapping the kernel produced by
  // c10::BoxedKernel::makeNamedNotSupported().
  static CppFunction makeNamedNotSupported() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported());
  }

  // Returns a CppFunction for a boxed kernel function supplied as a non-type
  // template argument; the function is wrapped with
  // c10::BoxedKernel::makeFromFunction<func>().
  template <c10::BoxedKernel::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
  }

  // Variant that takes in a boxed kernel function with a plumbed
  // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
  // details. Apart from the function-pointer type of the template argument it
  // is identical to the overload above.
  template <c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
  static CppFunction makeFromBoxedFunction() {
    return makeFromBoxedKernel(c10::BoxedKernel::makeFromFunction<func>());
  }

  // Builds a CppFunction from a heap-allocated kernel functor whose ownership
  // is passed in via unique_ptr.
  // NOTE(review): the argument of makeFromBoxedKernel and the closing brace
  // are cut off in this listing; confirm against the upstream header.
  template <class KernelFunctor>
  static CppFunction makeFromBoxedFunctor(
      std::unique_ptr<KernelFunctor> kernelFunctor) {
    return makeFromBoxedKernel(

  // Named factory equivalent to calling the CppFunction constructor with a
  // function pointer.
  // NOTE(review): the constraint ending in `std::nullptr_t>` has lost its
  // opening line and the closing brace is absent in this listing; confirm
  // against the upstream header.
  template <
      typename FuncPtr,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
    return CppFunction(f);

  // Named factory equivalent to calling the CppFunction constructor with a
  // FuncPtr object passed by value.
  // NOTE(review): the constraint ending in `std::nullptr_t>` has lost its
  // opening line and the closing brace is absent in this listing; confirm
  // against the upstream header.
  template <
      typename FuncPtr,
          std::nullptr_t> = nullptr>
  static CppFunction makeFromUnboxedFunction(FuncPtr f) {
    return CppFunction(f);

  // Replaces the debug string of this CppFunction with d and returns the
  // object as an rvalue so the call can be chained. Rvalue-qualified: it can
  // only be invoked on a temporary or an explicitly moved object.
  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

  // Dispatch key this function is registered for; written by torch::dispatch,
  // which stores std::nullopt for DispatchKey::CatchAll.
  std::optional<c10::DispatchKey> dispatch_key_;
  // The kernel itself, built by one of the c10::KernelFunction factories.
  c10::KernelFunction func_;
  // C++ signature of the kernel, when known (std::nullopt for boxed kernels).
  std::optional<c10::impl::CppSignature> cpp_signature_;
  // Function schema, may be null (boxed kernels pass nullptr).
  std::unique_ptr<c10::FunctionSchema> schema_;
  // Free-form debug string, set through debug().
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (does so
  // destructively, felt too lazy to write accessors that I don't even
  // want users to use)
  friend class Library;

      c10::KernelFunction func,
      std::optional<c10::impl::CppSignature> cpp_signature,
      std::unique_ptr<c10::FunctionSchema> schema);

// Wraps raw_f in a CppFunction and tags it with dispatch key k.
// DispatchKey::CatchAll is recorded as "no dispatch key" (std::nullopt); any
// other key is stored as given.
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
  CppFunction f(std::forward<Func>(raw_f));
  if (k == c10::DispatchKey::CatchAll) {
    f.dispatch_key_ = std::nullopt;
  } else {
    f.dispatch_key_ = k;
  }
  return f;
}

// Convenience overload of dispatch() that accepts a device type. The device
// type is mapped to the dispatch key of the same name by the local lambda and
// the result is forwarded to the DispatchKey overload.
// NOTE(review): the listing is truncated inside the switch -- the branch that
// owns the "Device type ... cannot be overloaded at dispatch time" message
// (presumably a default case reporting an error) and the closing braces of the
// switch, the lambda and the function are not visible; confirm against the
// upstream header.
template <typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
  auto deviceTypeToDispatchKey = [](c10::DeviceType t) {
    switch (t) {
      // This list is synchronized with the k-constants in c10/core/DeviceType.h
      case c10::DeviceType::CPU:
        return c10::DispatchKey::CPU;
      case c10::DeviceType::CUDA:
        return c10::DispatchKey::CUDA;
      case c10::DeviceType::IPU:
        return c10::DispatchKey::IPU;
      case c10::DeviceType::XLA:
        return c10::DispatchKey::XLA;
      case c10::DeviceType::Lazy:
        return c10::DispatchKey::Lazy;
      case c10::DeviceType::XPU:
        return c10::DispatchKey::XPU;
      case c10::DeviceType::MPS:
        return c10::DispatchKey::MPS;
      case c10::DeviceType::Meta:
        return c10::DispatchKey::Meta;
      case c10::DeviceType::HIP:
        return c10::DispatchKey::HIP;
      case c10::DeviceType::MAIA:
        return c10::DispatchKey::MAIA;
      case c10::DeviceType::HPU:
        return c10::DispatchKey::HPU;
      case c10::DeviceType::MTIA:
        return c10::DispatchKey::MTIA;
      case c10::DeviceType::PrivateUse1:
        return c10::DispatchKey::PrivateUse1;
            "Device type ",
            " cannot be overloaded at dispatch time, "
            "please file a bug report explaining what you were trying to do.");
  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));

// Parses the schema string `str` into a c10::FunctionSchema and applies the
// requested alias analysis kind to it.
//   str            - textual function schema
//   k              - alias analysis kind to set on the parsed schema
//   allow_typevars - forwarded to torch::jit::parseSchema
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k, bool allow_typevars=false) {
  c10::FunctionSchema s = torch::jit::parseSchema(str, /*allow_typevars*/allow_typevars);
  // Without this call the `k` argument would be silently ignored.
  s.setAliasAnalysis(k);
  return s;
}

// Parses the schema string `s` using the FROM_SCHEMA alias analysis kind;
// shorthand for the three-argument overload.
inline c10::FunctionSchema schema(const char* s, bool allow_typevars=false) {
  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA, allow_typevars);
}

// Identity overload: an already constructed schema is passed through
// unchanged (as an rvalue), so Library::def can call schema() on any accepted
// argument type.
inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
  return std::move(s);
}

namespace detail {

inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    c10::FunctionSchema&& s) {
  return std::move(s);
inline std::variant<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
    c10::OperatorName&& n) {
  return std::move(n);
inline std::variant<c10::OperatorName, c10::FunctionSchema>
constructSchemaOrName(const char* str) {
  auto s = torch::jit::parseSchemaOrName(str);
  if (std::holds_alternative<c10::FunctionSchema>(s)) {
  return s;

class TorchLibraryInit;

} // namespace detail

// Note [Selective build]
// ~~~~~~~~~~~~~~~~~~~~~~
// In some settings, especially mobile, it is important to avoid compiling any
// references to functions that you aren't actually going to use, so that they
// can be eliminated by the linker.  We call this capability "selective build".
// A very easy way to implement selective build which results in a lot of
// boilerplate is to just add ifdef's around every registration call, but this
// means you have to write a lot of extra lines of code at every registration
// site, and it also means you have to define some munging scheme to map
// operators to macros.
// Instead of doing this, we have a different mechanism centered around the
// concept of a SelectiveStr.  A selective name is like a const char* string,
// except it also carries at compile time a boolean saying whether or not a
// registration should actually happen.  We then have extra overloads
// which bypass registration entirely if a selective name is disabled.  We do a
// constexpr test to see if an operator should be enabled or not; this is
// currently implemented in ATen/core/op_registration/op_allowlist.h

namespace detail {

// dummy class for non selected custom torchbind classes: it is what
// Library::class_ returns for a disabled SelectiveStr. Every registration call
// accepts arbitrary arguments, does nothing, and returns *this so that call
// chains still compile.
class ClassNotSelected {
 public:
  ClassNotSelected& def_pickle(...) {
    return *this;
  }
  ClassNotSelected& def(...) {
    return *this;
  }
};

// A SelectiveStr is like a const char*, except that it also comes
// with a type brand that says whether or not the name is enabled or
// not.  If the string is disabled, then (at compile time) we DON'T generate
// a registration call for it.  This class is not intended to be called
// directly; use TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros below
// to create it.
template <bool enabled>
class SelectiveStr {
 public:
  // Stores the pointer only; the string itself is not copied.
  constexpr explicit SelectiveStr(const char* name) : name_(name) {}
  // Yields the wrapped string. Const-qualified so that it also works on const
  // (including constexpr) objects.
  constexpr operator const char*() const {
    return name_;
  }

 private:
  const char* name_;
};

} // namespace detail

class TORCH_API Library final {
  enum Kind {
    DEF, // from TORCH_LIBRARY (no qualifier)

      Kind kind,
      std::string ns,
      std::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line);

  // Library is move-only: copying is disabled, moving is the compiler default.
  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;
  Library(Library&&) = default;
  Library& operator=(Library&&) = default;

  // Some notes about the API design here.  We had the following constraints:
  //  - We need to support multiple "types" of arguments for schema and
  //    functions (e.g., unnamed lambda types, regular functions, const char*,
  //    fully instantiated schemas)
  //  - We don't want to write exponentially many overloads
  //  - We don't want to rely on implicit conversion to a common type,
  //    because the C++ compiler will only be willing to do a single
  //    implicit conversion (reducing the set of valid types which you
  //    can invoke with); also error messages are worse when an implicit
  //    conversion is not selected (as the compiler will not explain
  //    why it didn't select an implicit conversion; this is different
  //    from overloads where it will explain each candidate overload and
  //    why it didn't apply)
  // To solve all of these constraints at the same time, we use a trick taken
  // from the pybind11 library: template over the argument in the user visible
  // API, and inside of the templated function explicitly call an overloaded
  // function to resolve the argument to a real type.  You get the good error
  // messages from overloads, but at the same time you only need to write the
  // overload for any given argument type once.

  // Declares an operator from a schema without providing an implementation.
  // raw_schema may be anything the overloaded torch::schema() helpers accept;
  // the resulting FunctionSchema is handed to _def together with the tags and
  // the register/verify mode. Returns *this for chaining.
  template <typename Schema>
  Library& def(
      Schema&& raw_schema,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s), nullptr, tags, rv);
  }

  // Records a Python module name and a context string on this library and
  // returns *this. Only the pointers are stored (python_module_ is a pair of
  // const char*), so both strings must outlive the Library.
  Library& set_python_module(const char* pymodule, const char* context = "") {
    python_module_ = {pymodule, context};
    return *this;
  }

  // Alias of set_python_module(): forwards both arguments unchanged.
  Library& impl_abstract_pystub(const char* pymodule, const char* context = "") {
    return set_python_module(pymodule, context);
  }

  // Declares an operator and provides its implementation in one call.
  // raw_name_or_schema is normalized by detail::constructSchemaOrName into the
  // name-or-schema variant that _def takes; raw_f is wrapped in a CppFunction.
  // Returns *this for chaining.
  template <typename NameOrSchema, typename Func>
  Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f,
      const std::vector<at::Tag>& tags = {}) & {
    CppFunction f(std::forward<Func>(raw_f));
    return _def(
        detail::constructSchemaOrName(
            ::std::forward<NameOrSchema>(raw_name_or_schema)),
        ::std::move(f), tags);
  }

  // Registers raw_f as an implementation of the operator called `name`.
  // raw_f is wrapped in a CppFunction; C10_MOBILE builds pass an additional
  // NoInferSchemaTag to the CppFunction constructor. Returns *this for
  // chaining.
  template <typename Name, typename Func>
  Library& impl(
      Name name,
      Func&& raw_f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
#if defined C10_MOBILE
    CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
    CppFunction f(std::forward<Func>(raw_f));
#endif
    return _impl(name, std::move(f), rv);
  }

#if defined C10_MOBILE
  // Note: This overload is needed only for C10_MOBILE, since the automatically
  // defined copy constructor for the CppFunction doesn't have the additional
  // NoInferSchemaTag argument. We define the overload for the impl() function
  // to accept a CppFunction&& argument. The already constructed CppFunction
  // object may or may not have the inferred schema, but it doesn't matter
  // for our purposes since if it already has the inferred schema, then we
  // might as well just pass it through directly.
  template <typename Name>
  Library& impl(Name name, CppFunction&& raw_f) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
    CppFunction f(std::forward<CppFunction>(raw_f));
    return _impl(name, std::move(f));
  }
#endif

  // Helper for getting an OperatorName for a const char*.  You probably
  // don't need this.
  c10::OperatorName _resolve(const char* name) const;

  // Registers raw_f for operator `name` under an explicit dispatch key or
  // device type: key and raw_f are combined by torch::dispatch() and the
  // result goes to the two-argument impl(). Returns *this for chaining.
  template <typename Name, typename Dispatch, typename Func>
  Library& impl(Name name, Dispatch&& key, Func&& raw_f) & {
    return impl(
        name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
  }

  // Removed API kept only to produce a diagnostic; both arguments are ignored.
  // NOTE(review): the string below is the message of a check whose opening
  // line (presumably a static_assert) is missing from this listing, as is the
  // closing brace; confirm against the upstream header.
  template <typename Name, typename Func>
  Library& impl_UNBOXED(Name /*name*/, Func* /*raw_f*/) & {
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;

  // These overloads cover cases when a SelectiveStr (see Note [Selective
  // build]) has been disabled at compile time.  In that case, don't generate
  // any code referencing the passed in functions at all.
  // The SelectiveStr<false> overloads are no-ops that return *this; the
  // SelectiveStr<true> overloads unwrap the string and forward to the regular
  // def() overloads.
  Library& def(detail::SelectiveStr<false>, const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
    return *this;
  }
  Library& def(detail::SelectiveStr<true> raw_schema, const std::vector<at::Tag>& tags = {}) & {
    return def(raw_schema.operator const char*(), tags);
  }
  template <typename Func>
  Library& def(detail::SelectiveStr<false>, Func&& /*raw_f*/, const std::vector<at::Tag>& tags [[maybe_unused]] = {}) & {
    return *this;
  }
  template <typename Func>
  Library& def(detail::SelectiveStr<true> raw_name_or_schema, Func&& raw_f, const std::vector<at::Tag>& tags = {}) & {
    return def(
        raw_name_or_schema.operator const char*(), std::forward<Func>(raw_f), tags);
  }

  // Disabled selective name: register nothing and do not reference raw_f.
  template <typename Func>
  Library& impl(detail::SelectiveStr<false>, Func&& /*raw_f*/) & {
    return *this;
  }
  // Disabled selective name with an explicit dispatch key: register nothing
  // and reference neither the key nor raw_f. The leading SelectiveStr<false>
  // parameter makes this the counterpart of the SelectiveStr<true> overload
  // with the same arity.
  template <typename Dispatch, typename Func>
  Library& impl(
      detail::SelectiveStr<false>,
      Dispatch&& /*key*/,
      Func&& /*raw_f*/) & {
    return *this;
  }
  // Removed API (disabled selective name variant); arguments are ignored.
  // NOTE(review): the string below is the message of a check whose opening
  // line (presumably a static_assert) is missing from this listing, as is the
  // closing brace; confirm against the upstream header.
  template <typename Func>
  Library& impl_UNBOXED(
      detail::SelectiveStr<false> /*name*/,
      Func* /*raw_f*/) & {
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;

  // Enabled selective name: unwrap the string and forward to the regular
  // impl(name, f) overload.
  template <typename Func>
  Library& impl(detail::SelectiveStr<true> name, Func&& raw_f) & {
    return impl(name.operator const char*(), std::forward<Func>(raw_f));
  }
  // Enabled selective name with an explicit dispatch key: unwrap the string
  // and forward all three arguments to the regular impl(name, key, f)
  // overload.
  template <typename Dispatch, typename Func>
  Library& impl(
      detail::SelectiveStr<true> name,
      Dispatch&& key,
      Func&& raw_f) & {
    return impl(
        name.operator const char*(),
        std::forward<Dispatch>(key),
        std::forward<Func>(raw_f));
  }
  // Removed API (enabled selective name variant); arguments are ignored.
  // NOTE(review): the string below is the message of a check whose opening
  // line (presumably a static_assert) is missing from this listing, as is the
  // closing brace; confirm against the upstream header.
  template <typename Func>
  Library& impl_UNBOXED(
      detail::SelectiveStr<true> /*name*/,
      Func* /*raw_f*/) & {
        ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
    return *this;

  // Registers raw_f as a fallback: it is wrapped in a CppFunction and handed
  // to _fallback. Returns *this for chaining.
  template <typename Func>
  Library& fallback(Func&& raw_f) & {
    // The extra parentheses keep this from being parsed as a function
    // declaration.
    CppFunction f((std::forward<Func>(raw_f)));
    return _fallback(std::move(f));
  }

  // Declaration only (defined elsewhere): returns a torch::class_<CurClass>
  // for the custom class with the given name.
  template <class CurClass>
  inline torch::class_<CurClass> class_(const std::string& className);

  // These overloads enable the use of selective build on classes registered
  // within a library. The API is the same as before with 1 minor change.
  // Instead of m.class_<foo>("foo") you instead do
  // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
  template <class CurClass>
  inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);

  // Disabled selective name: yields the inert detail::ClassNotSelected
  // stand-in instead of a torch::class_.
  template <class CurClass>
  inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);

  // De-registers all registrations created with this Library
  void reset();

  // NOTE(review): kind_, ns_, dispatch_key_, file_ and line_ mirror the
  // constructor's parameters and are presumably initialized from them; the
  // constructor body is not part of this header -- confirm.
  Kind kind_;
  std::optional<std::string> ns_;
  std::optional<c10::DispatchKey> dispatch_key_;
  // (module, context) pair written by set_python_module(); stores raw
  // pointers, not copies of the strings.
  std::optional<std::pair<const char*, const char*>> python_module_;
  const char* file_;
  uint32_t line_;

  // RAII registration handles owned by this library -- presumably what
  // reset() releases; verify against the implementation.
  std::vector<c10::RegistrationHandleRAII> registrars_;

  friend class detail::TorchLibraryInit;

  // Non-user visible actual implementations of functions.  These aren't
  // public because we only implement & qualifier and not && qualifier
  // The templated def()/impl()/fallback() entry points above convert their
  // arguments and then delegate to these out-of-line functions.
  Library& _def(
      c10::FunctionSchema&& schema,
      c10::OperatorName* out_name = nullptr,
      const std::vector<at::Tag>& tags = {},
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _def(
      std::variant<c10::OperatorName, c10::FunctionSchema>&&,
      CppFunction&& f,
      const std::vector<at::Tag>& tags = {}) &;
  Library& _impl(
      const char* name,
      CppFunction&& f,
      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
  Library& _fallback(CppFunction&& f) &;

  // Declaration only; judging by the name it turns an operator name string
  // into an at::OperatorName for this library -- confirm against the
  // implementation.
  // NOTE(review): the class's closing `};` is not visible after this
  // declaration in the listing.
  at::OperatorName _parseNameForLib(const char* name_str) const;

namespace detail {

class TorchLibraryInit final {
  using InitFn = void(Library&);
  Library lib_;

      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      std::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {

} // namespace detail

} // namespace torch

// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/

// TORCH_LIBRARY(ns, m) { ... } expands to:
//   1. a forward declaration of the init function TORCH_LIBRARY_init_<ns>,
//   2. a static TorchLibraryInit object that receives Library::DEF, that
//      function, the namespace as a string, no dispatch key, and the source
//      location,
//   3. the head of the init function's definition, so that the braces written
//      after the macro become its body with `m` naming the torch::Library&.
#define TORCH_LIBRARY(ns, m)                                                   \
  static void TORCH_LIBRARY_init_##ns(torch::Library&);                        \
  static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
      torch::Library::DEF,                                                     \
      &TORCH_LIBRARY_init_##ns,                                                \
      #ns,                                                                     \
      std::nullopt,                                                            \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void TORCH_LIBRARY_init_##ns(torch::Library& m)


// Public entry point for library fragments. Like TORCH_LIBRARY_IMPL, it only
// appends a unique id (C10_UID) and delegates to the underscore-prefixed
// helper below.
#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)

// Helper for TORCH_LIBRARY_FRAGMENT. Same shape as TORCH_LIBRARY, except that
// the Library kind is FRAGMENT and `uid` is pasted onto the generated
// identifiers, so the same namespace can be used by several expansions.
#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid)                       \
  static void C10_CONCATENATE(                                    \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(   \
      TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(           \
      torch::Library::FRAGMENT,                                   \
      &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
      #ns,                                                        \
      std::nullopt,                                               \
      __FILE__,                                                   \
      __LINE__);                                                  \
  void C10_CONCATENATE(                                           \
      TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)

// Public entry point for registering implementations for dispatch key `k` in
// namespace `ns`; appends a unique id (C10_UID) and delegates to
// _TORCH_LIBRARY_IMPL.
// NB: if the dispatch key is not allowlisted (as decided by
// c10::impl::dispatch_key_allowlist_check), the user's init function is
// replaced by an empty one, so nothing gets registered.
#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

// Helper for TORCH_LIBRARY_IMPL. Same shape as TORCH_LIBRARY with these
// differences: the Library kind is IMPL, the dispatch key
// c10::DispatchKey::<k> is passed along, `uid` is pasted onto the generated
// identifiers, and the init function handed to TorchLibraryInit is chosen by
// a conditional -- the user's function when the allowlist check passes,
// otherwise a lambda with an empty body.
#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                                \
  static void C10_CONCATENATE(                                            \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);       \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(           \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(                 \
      torch::Library::IMPL,                                               \
      (c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::k)       \
           ? &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid) \
           : [](torch::Library&) -> void {}),                             \
      #ns,                                                                \
      std::make_optional(c10::DispatchKey::k),                            \
      __FILE__,                                                           \
      __LINE__);                                                          \
  void C10_CONCATENATE(                                                   \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

// These are variants of the macros above which are to be used for testing (they
// don't setup the static initializer, so you can control the visibility of
// the allocated library yourself).
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.

// Expands to a torch::Library temporary of kind DEF for namespace `ns`, with
// no dispatch key and the current source location. No static initializer is
// created.
#define MAKE_TORCH_LIBRARY(ns) \
  torch::Library(torch::Library::DEF, #ns, std::nullopt, __FILE__, __LINE__)
// Expands to a torch::Library temporary of kind IMPL for namespace `ns` and
// dispatch key c10::DispatchKey::<k>, with the current source location. The
// argument list matches the five-parameter Library constructor, as in
// MAKE_TORCH_LIBRARY.
#define MAKE_TORCH_LIBRARY_IMPL(ns, k)         \
  torch::Library(                              \
      torch::Library::IMPL,                    \
      #ns,                                     \
      std::make_optional(c10::DispatchKey::k), \
      __FILE__,                                \
      __LINE__)

// Make the custom class API visible, so it is available from
// torch::Library.

#include <torch/custom_class.h>


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources