Template Struct Function

Struct Documentation

template<class T>
struct torch::autograd::Function

To define a custom autograd operation, implement a subclass of Function (passing the subclass itself as the template argument T) with static forward and backward functions:

forward can take as many arguments as you want and should return either a variable list or a single Variable. Its first argument must be a pointer to torch::autograd::AutogradContext. Any Variable passed directly as an argument is registered in the graph, but vectors, sets, and other data structures are not traversed. You can also pass a c10::optional<Tensor> argument; it is registered as a variable in the graph if it holds a value. Variables can be saved in the ctx using ctx->save_for_backward (see torch::autograd::AutogradContext::save_for_backward), and other data can be saved in the ctx->saved_data map (see torch::autograd::AutogradContext::saved_data) as <std::string, at::IValue> pairs.

backward should take two arguments: a pointer to torch::autograd::AutogradContext and a variable list containing as many Variables as there were outputs from forward. It should return as many Variables as there were inputs to forward (not counting the context), each containing the gradient w.r.t. its corresponding input; for an input that is not a Variable, return an undefined Variable (Variable()). Variables saved in forward can be accessed with ctx->get_saved_variables (see torch::autograd::AutogradContext::get_saved_variables), and other saved data can be read from ctx->saved_data.

For example:

class MyFunction : public Function<MyFunction> {
 public:
  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save non-tensor data for backward in the context
    ctx->saved_data["n"] = n;
    var.mul_(n);
    // Mark var as modified by an in-place operation
    ctx->mark_dirty({var});
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    // Return one gradient per forward input: an undefined Variable for the
    // non-tensor input n, then d(n * var)/d(var) = n for var
    return {Variable(), grad_output[0] * n};
  }
};

To use MyFunction:

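// x must not be a leaf that requires grad, since forward modifies it in place
auto a = torch::ones({2, 2}, torch::requires_grad());
auto x = a.clone();
auto y = MyFunction::apply(6, x);
// Example backward call; the gradient accumulates in a.grad()
y[0].sum().backward();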

Public Static Functions

template<typename X = T, typename ...Args>
auto apply(Args&&... args) -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>>