Shortcuts

Program Listing for File common.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/common.h)

 public:
  explicit MImpl(int value_) : value(value_) {}
  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }
 private:
  int value;
};
TORCH_MODULE(M);
If MImpl's forward() method has default arguments, please make sure
the forward() method is declared with a corresponding
`FORWARD_HAS_DEFAULT_ARGS` macro.
#pragma once

#include <c10/util/irange.h>

/// Declares default arguments for a module's `forward()` method.
///
/// Expand this macro inside the module's class body. Each macro argument
/// describes one defaulted parameter of `forward()` as
/// `{index, torch::nn::AnyValue(default_value)}`, where `index` is the
/// zero-based position of that parameter in `forward()`'s parameter list.
/// Entries must be listed in ascending, consecutive index order, and the last
/// entry must describe the last parameter of `forward()`, because defaults are
/// appended to the argument list in entry order. For example, for
/// `forward(int a, int b = 2, double c = 3.0)`:
///
///   FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)},
///                            {2, torch::nn::AnyValue(3.0)})
///
/// The expansion befriends `torch::nn::AnyModuleHolder` (presumably so the
/// holder can call the hooks below) and overrides three virtual hooks that the
/// enclosing class must inherit:
///  - `_forward_has_default_args()` returns `true`.
///  - `_forward_num_required_args()` returns the index of the first entry,
///    i.e. the number of leading parameters that have no default.
///  - `_forward_populate_default_args(arguments)` asserts that the number of
///    supplied arguments lies between the required count and the total
///    parameter count, then returns the supplied arguments followed by the
///    default values of every parameter that was omitted.
///
/// At least one entry is required; an empty argument list trips an internal
/// assertion when the hooks run.
///
/// Comments inside the macro body use the block form because a `//` comment on
/// a backslash-continued line would absorb the following spliced line.
#define FORWARD_HAS_DEFAULT_ARGS(...) \
  template <typename ModuleType, typename... ArgumentTypes> \
  friend struct torch::nn::AnyModuleHolder; \
  bool _forward_has_default_args() override { \
    return true; \
  } \
  unsigned int _forward_num_required_args() override { \
    std::vector<std::pair<unsigned int, torch::nn::AnyValue>> args_info = {__VA_ARGS__}; \
    TORCH_INTERNAL_ASSERT(!args_info.empty(), "FORWARD_HAS_DEFAULT_ARGS requires at least one default argument"); \
    /* The first defaulted index equals the number of required parameters. */ \
    return args_info.front().first; \
  } \
  std::vector<torch::nn::AnyValue> _forward_populate_default_args(std::vector<torch::nn::AnyValue>&& arguments) override { \
    std::vector<std::pair<unsigned int, torch::nn::AnyValue>> args_info = {__VA_ARGS__}; \
    TORCH_INTERNAL_ASSERT(!args_info.empty(), "FORWARD_HAS_DEFAULT_ARGS requires at least one default argument"); \
    /* Read the required count locally instead of calling the virtual hook, */ \
    /* which would rebuild args_info (and its AnyValues) a second time. */ \
    const unsigned int num_required_args = args_info.front().first; \
    const unsigned int num_all_args = args_info.back().first + 1; \
    TORCH_INTERNAL_ASSERT(arguments.size() >= num_required_args && arguments.size() <= num_all_args); \
    /* Keep the caller-supplied arguments as they are; only append defaults. */ \
    std::vector<torch::nn::AnyValue> ret = std::move(arguments); \
    ret.reserve(num_all_args); \
    for (auto& arg_info : args_info) { \
      /* Use `>=` rather than `> ret.size() - 1`: the subtraction wraps */ \
      /* around when no argument was supplied and would skip every default. */ \
      if (arg_info.first >= ret.size()) { \
        ret.emplace_back(std::move(arg_info.second)); \
      } \
    } \
    return ret; \
  }

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources