Program Listing for File custom_class.h
↰ Return to documentation for file (torch/custom_class.h)
#pragma once
#include <ATen/core/builtin_function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/class_type.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/custom_class_detail.h>
#include <torch/library.h>
#include <functional>
#include <sstream>
namespace torch {
// Produces the tag value accepted by class_::def() to bind a constructor of
// the custom class. The tag carries no data; its type arguments list the
// parameter types (`Types...`) of the constructor to expose.
template <class... Types>
detail::types<void, Types...> init() {
  using CtorSignature = detail::types<void, Types...>;
  return CtorSignature{};
}
// Pairs a user-supplied factory callable with the parameter types recorded
// for it in the template argument list. torch::init(Func&&) creates these and
// class_::def() consumes them. Kept as a plain aggregate so that it can be
// brace-initialized directly from the callable.
template <class Func, class... ParameterTypeList>
struct InitLambda {
  // The wrapped callable.
  Func f;
};
template <typename Func>
decltype(auto) init(Func&& f) {
using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
using ParameterTypeList = typename InitTraits::parameter_types;
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
return init;
}
template <class CurClass>
class class_ : public ::torch::detail::class_base {
static_assert(
std::is_base_of_v<CustomClassHolder, CurClass>,
"torch::class_<T> requires T to inherit from CustomClassHolder");
public:
explicit class_(
const std::string& namespaceName,
const std::string& className,
std::string doc_string = "")
: class_base(
namespaceName,
className,
std::move(doc_string),
typeid(c10::intrusive_ptr<CurClass>),
typeid(c10::tagged_capsule<CurClass>)) {}
// Binds a constructor of CurClass as the `__init__` method. Used in
// combination with torch::init<Types...>(), whose type arguments name the
// constructor's parameter types.
//
// `doc_string` documents the method. `default_args` is passed on to
// defineMethod(), which requires it to be empty or to describe every
// constructor argument.
template <typename... Types>
class_& def(
    torch::detail::types<void, Types...>,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  auto construct = [](c10::tagged_capsule<CurClass> self, Types... args) {
    // Create the C++ instance first, then store it as a capsule in slot 0 of
    // the object that `self` refers to.
    auto instance = c10::make_intrusive<CurClass>(args...);
    self.ivalue.toObject()->setSlot(
        0, c10::IValue::make_capsule(std::move(instance)));
  };
  defineMethod(
      "__init__", std::move(construct), std::move(doc_string), default_args);
  return *this;
}
// Binds a factory callable as the `__init__` method. Used in combination with
// torch::init([](...) { ... }): the callable receives the constructor
// arguments (`ParameterTypes...`) and its result must be convertible to
// c10::intrusive_ptr<CurClass>.
//
// `doc_string` documents the method. `default_args` is passed on to
// defineMethod(), which requires it to be empty or to describe every
// argument of the callable.
template <typename Func, typename... ParameterTypes>
class_& def(
    InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  auto init_lambda_wrapper = [func = std::move(init.f)](
                                 c10::tagged_capsule<CurClass> self,
                                 ParameterTypes... arg) {
    c10::intrusive_ptr<CurClass> classObj =
        std::invoke(func, std::forward<ParameterTypes>(arg)...);
    auto object = self.ivalue.toObject();
    // classObj is not used again, so hand it over instead of copying it. This
    // avoids an extra reference-count increment/decrement and matches the
    // torch::init<...>() overload of def().
    object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
  };
  defineMethod(
      "__init__",
      std::move(init_lambda_wrapper),
      std::move(doc_string),
      default_args);
  return *this;
}
// Binds `f` as a method called `name` on the custom class. `f` is first
// passed through detail::wrap_func<CurClass, Func> and the result is handed
// to defineMethod() together with `doc_string` and `default_args`.
// Returns *this so that calls can be chained.
template <typename Func>
class_& def(
    std::string name,
    Func f,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  defineMethod(
      std::move(name),
      detail::wrap_func<CurClass, Func>(std::move(f)),
      std::move(doc_string),
      default_args);
  return *this;
}
// Binds `func` as a static method called `name`. The schema is inferred from
// Func, the callable is wrapped so that it operates on a jit::Stack, and the
// resulting jit::BuiltinOpFunction is added to the class type as a static
// method before ownership is passed to registerCustomClassMethod().
// Returns *this so that calls can be chained.
template <typename Func>
class_& def_static(std::string name, Func func, std::string doc_string = "") {
  using RetType =
      typename c10::guts::infer_function_traits_t<Func>::return_type;

  // Compose the qualified name before `name` is moved into the schema.
  auto qualified = qualClassName + "." + name;
  auto schema =
      c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");

  auto boxed = [func = std::move(func)](jit::Stack& stack) mutable -> void {
    detail::BoxedProxy<RetType, Func>()(stack, func);
  };

  auto static_method = std::make_unique<jit::BuiltinOpFunction>(
      std::move(qualified),
      std::move(schema),
      std::move(boxed),
      std::move(doc_string));

  // The class type only receives a raw pointer; the unique_ptr itself is
  // handed to registerCustomClassMethod() afterwards.
  classTypePtr->addStaticMethod(static_method.get());
  registerCustomClassMethod(std::move(static_method));
  return *this;
}
// Defines a read/write property called `name`. Two methods are defined,
// `<name>_getter` and `<name>_setter`, each built from the corresponding
// callable via detail::wrap_func, and both are then attached to the class
// type with addProperty(). `doc_string` is used for both methods.
// Returns *this so that calls can be chained.
template <typename GetterFunc, typename SetterFunc>
class_& def_property(
    const std::string& name,
    GetterFunc getter_func,
    SetterFunc setter_func,
    std::string doc_string = "") {
  auto bound_getter =
      detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
  torch::jit::Function* getter =
      defineMethod(name + "_getter", bound_getter, doc_string);

  auto bound_setter =
      detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func));
  torch::jit::Function* setter =
      defineMethod(name + "_setter", bound_setter, doc_string);

  classTypePtr->addProperty(name, getter, setter);
  return *this;
}
// Defines a read-only property called `name`. A `<name>_getter` method is
// built from `getter_func` via detail::wrap_func and attached to the class
// type with addProperty(), passing nullptr as the setter.
// Returns *this so that calls can be chained.
template <typename GetterFunc>
class_& def_property(
    const std::string& name,
    GetterFunc getter_func,
    std::string doc_string = "") {
  auto bound_getter =
      detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
  torch::jit::Function* getter =
      defineMethod(name + "_getter", bound_getter, doc_string);

  classTypePtr->addProperty(name, getter, nullptr);
  return *this;
}
// Exposes the data member `field` of CurClass as a read/write property called
// `name`. The getter returns a copy of the member; the setter assigns the
// passed value to it. Both are registered through def_property().
template <typename T>
class_& def_readwrite(const std::string& name, T CurClass::*field) {
  auto read_field = [field](const c10::intrusive_ptr<CurClass>& self) {
    auto* instance = self.get();
    return instance->*field;
  };
  auto write_field = [field](
                         const c10::intrusive_ptr<CurClass>& self, T value) {
    auto* instance = self.get();
    instance->*field = value;
  };
  return def_property(name, read_field, write_field);
}
// Exposes the data member `field` of CurClass as a read-only property called
// `name`. Only a getter is registered (through the getter-only
// def_property() overload); it returns a copy of the member.
template <typename T>
class_& def_readonly(const std::string& name, T CurClass::*field) {
  auto read_field = [field](const c10::intrusive_ptr<CurClass>& self) {
    auto* instance = self.get();
    return instance->*field;
  };
  return def_property(name, read_field);
}
// Adds a method called `name` whose implementation already operates directly
// on a jit::Stack and whose schema is supplied by the caller rather than
// inferred. The jit::BuiltinOpFunction is added to the class type and then
// handed to registerCustomClassMethod().
// Returns *this so that calls can be chained.
class_& _def_unboxed(
    const std::string& name,
    std::function<void(jit::Stack&)> func,
    c10::FunctionSchema schema,
    std::string doc_string = "") {
  auto qualified = qualClassName + "." + name;
  auto unboxed_method = std::make_unique<jit::BuiltinOpFunction>(
      std::move(qualified),
      std::move(schema),
      std::move(func),
      std::move(doc_string));

  // The class type only receives a raw pointer; the unique_ptr itself is
  // handed to registerCustomClassMethod() afterwards.
  classTypePtr->addMethod(unboxed_method.get());
  registerCustomClassMethod(std::move(unboxed_method));
  return *this;
}
// Binds `__getstate__` and `__setstate__` for the custom class.
//
// `get_state` is bound as `__getstate__` through def(). `set_state` is
// wrapped so that the c10::intrusive_ptr<CurClass> it returns is stored as a
// capsule in slot 0 of the object being restored; the wrapper is bound as
// `__setstate__`. Both callables must be stateless lambdas (static_assert).
//
// After both methods are defined, their inferred schemas are validated with
// TORCH_CHECK:
//   * __getstate__ takes exactly one argument, of the custom class type;
//   * __getstate__ returns exactly one value;
//   * that return type is a subtype of the state argument of __setstate__.
// Returns *this so that calls can be chained.
template <typename GetStateFn, typename SetStateFn>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
  static_assert(
      c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
          c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
      "def_pickle() currently only supports lambdas as "
      "__getstate__ and __setstate__ arguments.");
  def("__getstate__", std::forward<GetStateFn>(get_state));

  // __setstate__ needs to be registered with some custom handling:
  // We need to wrap the invocation of the user-provided function
  // such that we take the return value (i.e. c10::intrusive_ptr<CurClass>)
  // and assign it to the `capsule` attribute.
  using SetStateTraits =
      c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
  // Only the first parameter of set_state is used as the state argument.
  using SetStateArg = typename c10::guts::typelist::head_t<
      typename SetStateTraits::parameter_types>;
  auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
                              c10::tagged_capsule<CurClass> self,
                              SetStateArg arg) {
    c10::intrusive_ptr<CurClass> classObj =
        std::invoke(set_state, std::move(arg));
    auto object = self.ivalue.toObject();
    // classObj is not used again, so hand it over instead of copying it. This
    // avoids an extra reference-count increment/decrement and matches how
    // the torch::init<...>() overload of def() fills slot 0.
    object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
  };
  defineMethod(
      "__setstate__",
      detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
          std::move(setstate_wrapper)));

  // type validation
  auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
#ifndef STRIP_ERROR_MESSAGES
  // Only defined when error messages are kept; it is referenced solely from
  // the TORCH_CHECK message arguments below.
  auto format_getstate_schema = [&getstate_schema]() {
    std::stringstream ss;
    ss << getstate_schema;
    return ss.str();
  };
#endif
  TORCH_CHECK(
      getstate_schema.arguments().size() == 1,
      "__getstate__ should take exactly one argument: self. Got: ",
      format_getstate_schema());
  auto first_arg_type = getstate_schema.arguments().at(0).type();
  TORCH_CHECK(
      *first_arg_type == *classTypePtr,
      "self argument of __getstate__ must be the custom class type. Got ",
      first_arg_type->repr_str());
  TORCH_CHECK(
      getstate_schema.returns().size() == 1,
      "__getstate__ should return exactly one value for serialization. Got: ",
      format_getstate_schema());

  auto ser_type = getstate_schema.returns().at(0).type();
  auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
  // Argument 0 of __setstate__ is self; argument 1 is the serialized state.
  auto arg_type = setstate_schema.arguments().at(1).type();
  TORCH_CHECK(
      ser_type->isSubtypeOf(*arg_type),
      "__getstate__'s return type should be a subtype of "
      "input argument of __setstate__. Got ",
      ser_type->repr_str(),
      " but expected ",
      arg_type->repr_str());

  return *this;
}
private:
// Shared implementation behind the def*() helpers: infers a schema for
// `func`, wraps it so that it operates on a jit::Stack, adds the resulting
// jit::BuiltinOpFunction to the class type under `name`, and returns a
// non-owning pointer to it.
//
// `default_args` must be empty or contain one entry per schema argument
// other than self; otherwise the TORCH_CHECK below fails.
template <typename Func>
torch::jit::Function* defineMethod(
    std::string name,
    Func func,
    std::string doc_string = "",
    std::initializer_list<arg> default_args = {}) {
  using RetType =
      typename c10::guts::infer_function_traits_t<Func>::return_type;

  // Compose the qualified name before `name` is moved into the schema.
  auto qualified = qualClassName + "." + name;
  auto schema =
      c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");

  // Default values are all-or-nothing (self excluded). Argument names are not
  // extracted by inferFunctionSchemaSingleReturn, so a torch::arg entry is
  // needed even for arguments that have no actual default value.
  TORCH_CHECK(
      default_args.size() == 0 ||
          default_args.size() == schema.arguments().size() - 1,
      "Default values must be specified for none or all arguments");

  // When defaults were supplied, copy the argument names and default values
  // into the schema.
  if (default_args.size() != 0) {
    schema = withNewArguments(schema, default_args);
  }

  auto boxed = [func = std::move(func)](jit::Stack& stack) mutable -> void {
    // TODO: we need to figure out how to profile calls to custom functions
    // like this! Currently can't do it because the profiler stuff is in
    // libtorch and not ATen
    detail::BoxedProxy<RetType, Func>()(stack, func);
  };

  auto owned_method = std::make_unique<jit::BuiltinOpFunction>(
      qualified, std::move(schema), std::move(boxed), std::move(doc_string));

  // ClassTypes do not hold ownership of their methods (normally those are
  // held by the CompilationUnit), so the method is registered here to keep it
  // alive. Grab the raw pointer before the unique_ptr is moved away.
  auto* registered = owned_method.get();
  classTypePtr->addMethod(registered);
  registerCustomClassMethod(std::move(owned_method));
  return registered;
}
};
// Constructs a CurClass instance from `args` (perfectly forwarded to
// c10::make_intrusive) and returns it wrapped in a c10::IValue.
template <typename CurClass, typename... CtorArgs>
c10::IValue make_custom_class(CtorArgs&&... args) {
  return c10::IValue(
      c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...));
}
// Alternative API to torch::class_ for creating a torchbind class. This API
// is preferred to prevent size regressions on Edge use cases. It must be used
// in conjunction with the TORCH_SELECTIVE_CLASS macro, e.g.
//   selective_class_<foo>("foo_namespace", TORCH_SELECTIVE_CLASS("foo"))
// This overload handles a selected class (SelectiveStr<true>) and returns a
// regular torch::class_<CurClass>.
template <class CurClass>
inline class_<CurClass> selective_class_(
    const std::string& namespace_name,
    detail::SelectiveStr<true> className) {
  const char* raw_name = className.operator const char*();
  return torch::class_<CurClass>(namespace_name, std::string(raw_name));
}
// Overload for a class that was not selected (SelectiveStr<false>): both
// arguments are ignored and a detail::ClassNotSelected placeholder is
// returned instead of a torch::class_.
template <class CurClass>
inline detail::ClassNotSelected selective_class_(
    const std::string&,
    detail::SelectiveStr<false>) {
  return detail::ClassNotSelected{};
}
// Backward-compatibility aliases. These names were originally defined in
// torch::jit and were later moved to torch to better reflect that the
// features are not limited to TorchScript; re-export them so that code
// spelling them as torch::jit::... keeps compiling.
namespace jit {
using ::torch::class_,
    ::torch::getCustomClass,
    ::torch::init,
    ::torch::isCustomClass;
} // namespace jit
// Starts the definition of a custom class named `className` in this
// library's namespace (`*ns_`) and returns the torch::class_ builder for it.
// Fails a TORCH_CHECK unless kind_ is DEF or FRAGMENT; per the error text,
// classes may not be defined inside a TORCH_LIBRARY_IMPL block. file_ and
// line_ are included in the messages to point at the offending block.
template <class CurClass>
inline class_<CurClass> Library::class_(const std::string& className) {
TORCH_CHECK(
kind_ == DEF || kind_ == FRAGMENT,
"class_(\"",
className,
"\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
"All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
"(Error occurred at ",
file_,
":",
line_,
")");
// ns_ is dereferenced below, so it must hold a value at this point.
TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
return torch::class_<CurClass>(*ns_, className);
}
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably returns the names of all registered custom
// classes — confirm against the definition.
const std::unordered_set<std::string> getAllCustomClassesNames();
// Selective-build overload for a selected class (SelectiveStr<true>): the
// name is converted to std::string and then the same checks are applied as
// in the std::string overload. Fails a TORCH_CHECK unless kind_ is DEF or
// FRAGMENT, asserts that ns_ holds a value, and returns the torch::class_
// builder for the class in namespace `*ns_`.
template <class CurClass>
inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
  const std::string selected_name(className.operator const char*());
  TORCH_CHECK(
      kind_ == DEF || kind_ == FRAGMENT,
      "class_(\"",
      selected_name,
      "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
      "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
      "(Error occurred at ",
      file_,
      ":",
      line_,
      ")");
  // ns_ is dereferenced below, so it must hold a value at this point.
  TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
  return torch::class_<CurClass>(*ns_, selected_name);
}
// Selective-build overload for a class that was not selected
// (SelectiveStr<false>): nothing is checked or defined; a
// detail::ClassNotSelected placeholder is returned.
template <class CurClass>
inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) {
  return detail::ClassNotSelected{};
}
} // namespace torch