Define FORWARD_HAS_DEFAULT_ARGS
Defined in File common.h
Define Documentation
FORWARD_HAS_DEFAULT_ARGS(...)
This macro allows a module whose `forward` method has default arguments to be used inside a `torch::nn::Sequential` module.
Example usage:
Let’s say we have a module declared like this:
```cpp
#include <torch/torch.h>

struct MImpl : torch::nn::Module {
 public:
  explicit MImpl(int value_) : value(value_) {}

  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }

 private:
  int value;
};
TORCH_MODULE(M);
```
If we try to use it in a Sequential module and run forward:
```cpp
torch::nn::Sequential seq(M(1));
seq->forward(1);
```
We will receive the following error message:
MImpl's forward() method expects 3 argument(s), but received 1. If MImpl's forward() method has default arguments, please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.
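The error arises because C++ default arguments are attached to a function's declaration and applied at the call site; they are not part of the function's type. `Sequential` stores its children behind a type-erased wrapper (`torch::nn::AnyModule`), so when it passes arguments along it knows how many parameters `forward` takes, but not their defaults. The standalone sketch below illustrates the same effect without libtorch. The helper `call_erased` is hypothetical: it receives arguments as `std::any` values and, like the error above, rejects a call with fewer than three. It is an illustration of the idea under these assumptions, not the library's actual code.

```cpp
#include <any>
#include <stdexcept>
#include <string>
#include <vector>

// Mirrors the signature of MImpl::forward, returning a plain double instead of a tensor.
double forward(int a, int b = 2, double c = 3.0) { return a + b + c; }

// The defaults are NOT part of the type: this is double(*)(int, int, double),
// so every call made through such a pointer must supply all 3 arguments.
using ForwardFn = double (*)(int, int, double);

// Hypothetical type-erased call: arguments arrive as a vector of std::any,
// loosely like a container that does not know the concrete module type.
double call_erased(ForwardFn fn, const std::vector<std::any>& args) {
  if (args.size() != 3) {
    throw std::runtime_error("forward() expects 3 argument(s), but received " +
                             std::to_string(args.size()));
  }
  // std::any_cast throws std::bad_any_cast if a stored type does not match.
  return fn(std::any_cast<int>(args[0]), std::any_cast<int>(args[1]),
            std::any_cast<double>(args[2]));
}
```

For example, `call_erased(&forward, {1, 2, 3.0})` returns `6.0`, while `call_erased(&forward, {1})` throws, even though the direct call `forward(1)` compiles and uses the defaults.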
The right way to fix this error is to use the `FORWARD_HAS_DEFAULT_ARGS` macro when declaring the module:

```cpp
#include <torch/torch.h>

struct MImpl : torch::nn::Module {
 public:
  explicit MImpl(int value_) : value(value_) {}

  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }

 protected:
  /*
   * NOTE: looking at the argument list of `forward`:
   *   `forward(int a, int b = 2, double c = 3.0)`
   * we see the following default arguments:
   *
   *   0-based index of default arg | Default value of arg
   *   in forward arg list          | (wrapped by `torch::nn::AnyValue()`)
   *   -----------------------------+-------------------------------------
   *   1                            | torch::nn::AnyValue(2)
   *   2                            | torch::nn::AnyValue(3.0)
   *
   * Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS` macro:
   */
  FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2, torch::nn::AnyValue(3.0)})

 private:
  int value;
};
TORCH_MODULE(M);
```

Now, running the following works:
```cpp
torch::nn::Sequential seq(M(1));
seq->forward(1);  // This correctly populates the default arguments for `MImpl::forward`
```
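Conceptually, the `{index, value}` pairs give the type-erased wrapper enough information to pad a short argument list before dispatching to `forward`. The standalone sketch below mimics that idea with `std::any` standing in for `torch::nn::AnyValue`. The names `DefaultArg`, `populate_defaults`, and `call_with_defaults` are hypothetical, and this is a simplified illustration under those assumptions rather than the macro's actual expansion.

```cpp
#include <any>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Mirrors MImpl::forward; the table passed below repeats these defaults,
// just as the macro call in the module repeats them.
double forward(int a, int b = 2, double c = 3.0) { return a + b + c; }

// Hypothetical analogue of one macro argument: {0-based index, default value}.
using DefaultArg = std::pair<std::size_t, std::any>;

// Appends the default for every trailing position the caller left out.
// Assumes `defaults` is non-empty, sorted by index, and covers every
// parameter after the required ones, e.g. {1, 2}, {2, 3.0}.
std::vector<std::any> populate_defaults(std::vector<std::any> args,
                                        const std::vector<DefaultArg>& defaults) {
  const std::size_t num_required = defaults.front().first;
  const std::size_t num_all = defaults.back().first + 1;
  if (args.size() < num_required || args.size() > num_all) {
    throw std::invalid_argument("wrong number of arguments");
  }
  // Because the table is sorted and contiguous, missing slots fill in order.
  for (const auto& [index, value] : defaults) {
    if (index >= args.size()) {
      args.push_back(value);
    }
  }
  return args;
}

// Pads the argument list, then unpacks it into the three-parameter function.
double call_with_defaults(std::vector<std::any> args,
                          const std::vector<DefaultArg>& defaults) {
  std::vector<std::any> full = populate_defaults(std::move(args), defaults);
  return forward(std::any_cast<int>(full[0]), std::any_cast<int>(full[1]),
                 std::any_cast<double>(full[2]));
}
```

With the table from the example, `call_with_defaults({1}, {{1, 2}, {2, 3.0}})` returns `6.0`, the same value as the direct call `forward(1)`. In this sketch each default's stored type must match the parameter type exactly (`2` as `int`, `3.0` as `double`), because `std::any_cast` performs no conversions.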