Shortcuts

Program Listing for File python.h

Return to documentation for file (torch/csrc/api/include/torch/python.h)

#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <torch/types.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace python {
namespace detail {
// Returns the `Device` held by a Python device object.
//
// Throws `TypeError("Expected device")` when `THPDevice_Check` rejects the
// object.
inline Device py_object_to_device(py::object object) {
  PyObject* const raw = object.ptr();
  if (!THPDevice_Check(raw)) {
    throw TypeError("Expected device");
  }
  return reinterpret_cast<THPDevice*>(raw)->device;
}

// Returns the scalar type held by a Python dtype object.
//
// Throws `TypeError("Expected dtype")` when `THPDtype_Check` rejects the
// object.
inline Dtype py_object_to_dtype(py::object object) {
  PyObject* const raw = object.ptr();
  if (!THPDtype_Check(raw)) {
    throw TypeError("Expected dtype");
  }
  return reinterpret_cast<THPDtype*>(raw)->scalar_type;
}

// The pybind11 class type used for every bound C++ module: `ModuleType` is
// registered with `torch::nn::Module` as its base class and
// `std::shared_ptr<ModuleType>` as its holder type.
template <typename ModuleType>
using PyModuleClass =
    py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;

// Publishes a bound C++ module class to Python as `module.<name>`, where the
// published class is a dynamically created subclass of
// `torch.nn.cpp.ModuleWrapper`.
//
// module:    Python module on which the attribute `name` is set.
// cpp_class: pybind11 class of the C++ module. It is called with the
//            constructor arguments received from Python, and the resulting
//            instance is handed to `ModuleWrapper.__init__`.
// name:      Name of the generated Python class and of the attribute set on
//            `module`.
template <typename ModuleType>
void bind_cpp_module_wrapper(
    py::module module,
    PyModuleClass<ModuleType> cpp_class,
    const char* name) {
  // Base class of the generated type: `torch.nn.cpp.ModuleWrapper`.
  py::object wrapper_base =
      py::module::import("torch.nn.cpp").attr("ModuleWrapper");

  // Python's builtin `type`, used as the metaclass that creates the new class
  // at runtime via `type(name, bases, namespace)`.
  py::object metaclass = py::reinterpret_borrow<py::object>(
      reinterpret_cast<PyObject*>(&PyType_Type));

  // The class is created with an empty namespace; its `__init__` is attached
  // afterwards, once the class object exists and can be passed to
  // `py::is_method`.
  py::dict class_namespace;
  py::object class_name = py::str(name);

  // Create the subclass of `ModuleWrapper` (itself a subclass of
  // `torch.nn.Module`) that delegates all calls to the C++ module being bound.
  py::object generated_class =
      metaclass(class_name, py::make_tuple(wrapper_base), class_namespace);

  // Constructor of the generated class: build an instance of the original C++
  // module from the forwarded arguments and pass it to
  // `ModuleWrapper.__init__()`, which replaces the wrapper's methods with
  // those of the C++ module.
  generated_class.attr("__init__") = py::cpp_function(
      [wrapper_base, cpp_class](
          py::object self, py::args args, py::kwargs kwargs) {
        wrapper_base.attr("__init__")(self, cpp_class(*args, **kwargs));
      },
      py::is_method(generated_class));

  // `my_module.my_class` now names a `ModuleWrapper` subclass whose methods
  // call into the C++ module.
  module.attr(name) = generated_class;
}
} // namespace detail

// Registers the Python-facing module API (methods and read-only properties)
// on the pybind11 class `module` and returns that class.
//
// Every binding forwards to the same-named member function of `ModuleType`,
// except where noted below. The keyword names and default values exposed to
// Python are the ones given through `py::arg`.
template <typename ModuleType, typename... Extra>
py::class_<ModuleType, Extra...> add_module_bindings(
    py::class_<ModuleType, Extra...> module) {
  // clang-format off

  // Training state and basic utilities.
  module.def(
      "train",
      [](ModuleType& self, bool mode) { self.train(mode); },
      py::arg("mode") = true);
  module.def("eval", [](ModuleType& self) { self.eval(); });
  module.def("clone", [](ModuleType& self) { return self.clone(); });
  module.def_property_readonly(
      "training",
      [](ModuleType& self) { return self.is_training(); });
  module.def("zero_grad", [](ModuleType& self) { self.zero_grad(); });

  // Parameters. `_parameters` exposes only the module's own (non-recursive)
  // named parameters.
  module.def_property_readonly(
      "_parameters",
      [](ModuleType& self) {
        return self.named_parameters(/*recurse=*/false);
      });
  module.def(
      "parameters",
      [](ModuleType& self, bool recurse) { return self.parameters(recurse); },
      py::arg("recurse") = true);
  module.def(
      "named_parameters",
      [](ModuleType& self, bool recurse) {
        return self.named_parameters(recurse);
      },
      py::arg("recurse") = true);

  // Buffers. `_buffers` exposes only the module's own (non-recursive) named
  // buffers.
  module.def_property_readonly(
      "_buffers",
      [](ModuleType& self) { return self.named_buffers(/*recurse=*/false); });
  module.def(
      "buffers",
      [](ModuleType& self, bool recurse) { return self.buffers(recurse); },
      py::arg("recurse") = true);
  module.def(
      "named_buffers",
      [](ModuleType& self, bool recurse) {
        return self.named_buffers(recurse);
      },
      py::arg("recurse") = true);

  // Submodules. `_modules` exposes the named direct children.
  module.def_property_readonly(
      "_modules",
      [](ModuleType& self) { return self.named_children(); });
  module.def("modules", [](ModuleType& self) { return self.modules(); });
  module.def(
      "named_modules",
      // `memo` and `remove_duplicate` are accepted but ignored; only `prefix`
      // is forwarded.
      [](ModuleType& self,
         py::object /* memo, unused */,
         std::string prefix,
         bool /* remove_duplicate, unused */) {
        return self.named_modules(std::move(prefix));
      },
      py::arg("memo") = py::none(),
      py::arg("prefix") = std::string(),
      py::arg("remove_duplicate") = true);
  module.def("children", [](ModuleType& self) { return self.children(); });
  module.def(
      "named_children",
      [](ModuleType& self) { return self.named_children(); });

  // `to(dtype_or_device, non_blocking=False)`: a device object moves the
  // module to that device; anything else must be a dtype (otherwise
  // `py_object_to_dtype` throws).
  module.def(
      "to",
      [](ModuleType& self, py::object object, bool non_blocking) {
        if (THPDevice_Check(object.ptr())) {
          self.to(detail::py_object_to_device(object), non_blocking);
        } else {
          self.to(detail::py_object_to_dtype(object), non_blocking);
        }
      },
      py::arg("dtype_or_device"),
      py::arg("non_blocking") = false);

  // `to(device, dtype, non_blocking=False)`: either argument may be `None`,
  // in which case only the other one is applied.
  module.def(
      "to",
      [](ModuleType& self,
         py::object device,
         py::object dtype,
         bool non_blocking) {
        if (device.is_none()) {
          self.to(detail::py_object_to_dtype(dtype), non_blocking);
        } else if (dtype.is_none()) {
          self.to(detail::py_object_to_device(device), non_blocking);
        } else {
          self.to(
              detail::py_object_to_device(device),
              detail::py_object_to_dtype(dtype),
              non_blocking);
        }
      },
      py::arg("device"),
      py::arg("dtype"),
      py::arg("non_blocking") = false);

  // Device / dtype convenience shortcuts.
  module.def("cuda", [](ModuleType& self) { self.to(kCUDA); });
  module.def("cpu", [](ModuleType& self) { self.to(kCPU); });
  module.def("float", [](ModuleType& self) { self.to(kFloat32); });
  module.def("double", [](ModuleType& self) { self.to(kFloat64); });
  module.def("half", [](ModuleType& self) { self.to(kFloat16); });

  // Both string conversions return the module's name.
  module.def("__str__", [](ModuleType& self) { return self.name(); });
  module.def("__repr__", [](ModuleType& self) { return self.name(); });

  // clang-format on
  return module;
}

// Binds `ModuleType` into Python and returns its pybind11 class.
//
// The C++ class is registered under `name` in the submodule `module.cpp` with
// the bindings from `add_module_bindings`, and a `ModuleWrapper` subclass of
// the same name is set on `module` itself.
//
// This overload is disabled when `ModuleType` has a `forward` method, unless
// `force_enable` is true.
template <typename ModuleType, bool force_enable = false>
torch::disable_if_t<
    torch::detail::has_forward<ModuleType>::value && !force_enable,
    detail::PyModuleClass<ModuleType>>
bind_module(py::module module, const char* name) {
  auto bound_class = add_module_bindings(
      detail::PyModuleClass<ModuleType>(module.def_submodule("cpp"), name));
  detail::bind_cpp_module_wrapper(module, bound_class, name);
  return bound_class;
}

// Binds a `ModuleType` that has a `forward` method.
//
// Performs the generic binding (forcing the other overload on) and also
// exposes `ModuleType::forward` to Python as both `forward` and `__call__`.
template <
    typename ModuleType,
    typename =
        torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
detail::PyModuleClass<ModuleType> bind_module(
    py::module module,
    const char* name) {
  auto bound_class =
      bind_module<ModuleType, /*force_enable=*/true>(module, name);
  bound_class.def("forward", &ModuleType::forward);
  bound_class.def("__call__", &ModuleType::forward);
  return bound_class;
}
} // namespace python
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources