
Define FORWARD_HAS_DEFAULT_ARGS

Define Documentation

FORWARD_HAS_DEFAULT_ARGS(...)

This macro enables a module whose forward() method has default arguments to be used in a Sequential module, with the defaults filled in for any trailing arguments the caller omits.

Example usage:

Let’s say we have a module declared like this:

#include <torch/torch.h>

struct MImpl : torch::nn::Module {
 public:
  explicit MImpl(int value_) : value(value_) {}
  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }
 private:
  int value;
};
TORCH_MODULE(M);

If we try to use it in a Sequential module and run forward:

torch::nn::Sequential seq(M(1));
seq->forward(1);

We will receive the following error message:

MImpl's forward() method expects 3 argument(s), but received 1.
If MImpl's forward() method has default arguments, please make sure
the forward() method is declared with a corresponding
`FORWARD_HAS_DEFAULT_ARGS` macro.

To fix this error, declare the module with the FORWARD_HAS_DEFAULT_ARGS macro, passing one {index, default value} pair per default argument:

struct MImpl : torch::nn::Module {
 public:
  explicit MImpl(int value_) : value(value_) {}
  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }
 protected:
  /*
  NOTE: looking at the argument list of `forward`:
  `forward(int a, int b = 2, double c = 3.0)`
  we see the following default arguments:
  ----------------------------------------------------------------
  0-based index of default |         Default value of arg
  arg in forward arg list  |  (wrapped by `torch::nn::AnyValue()`)
  ----------------------------------------------------------------
              1            |       torch::nn::AnyValue(2)
              2            |       torch::nn::AnyValue(3.0)
  ----------------------------------------------------------------
  Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS`
  macro:
  */
  FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)},
                           {2, torch::nn::AnyValue(3.0)})
 private:
  int value;
};
TORCH_MODULE(M);
Now, running the following works:

torch::nn::Sequential seq(M(1));
// This correctly populates the default arguments for `MImpl::forward`.
seq->forward(1);
