Template Struct Function

Struct Documentation

template<class T>
struct Function

To use custom autograd operations, implement a Function subclass with static forward and backward functions:

forward can take as many arguments as you want and should return either a variable list or a Variable. Its first argument must be a pointer to torch::autograd::AutogradContext. Any Variable passed directly as an argument is registered in the graph, but Variables held inside vectors, sets, or other data structures are not traversed. You can use c10::optional<Tensor> as one of the arguments; it is registered as a variable in the graph if it has a value. Variables can be saved in the ctx using ctx->save_for_backward (see torch::autograd::AutogradContext::save_for_backward), and other data can be saved in the ctx->saved_data map (see torch::autograd::AutogradContext::saved_data) in the form of <std::string, at::IValue> pairs.
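For instance, a minimal sketch of a forward that takes an optional tensor (the bias argument and its semantics here are illustrative assumptions, not part of the API):

static Variable forward(AutogradContext *ctx, Variable x,
                        c10::optional<Tensor> bias) {
  // The optional is registered in the graph only when it holds a value.
  // Record its presence so backward can return a matching gradient list.
  ctx->saved_data["has_bias"] = bias.has_value();
  return bias.has_value() ? x + *bias : x;
}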

backward should take a pointer to torch::autograd::AutogradContext as its first argument and a variable list containing as many Variables as there were outputs from forward. It should return as many Variables as there were inputs to forward, each containing the gradient w.r.t. its corresponding input. Variables saved in forward can be accessed with ctx->get_saved_variables (see torch::autograd::AutogradContext::get_saved_variables), and other saved data can be accessed from ctx->saved_data. To enable compiled autograd support (torch.compile for backward) for your custom autograd operation, set MyFunction::is_traceable (see the Function::is_traceable notes below).

For example:

#include <torch/torch.h>

using namespace torch::autograd;

class MyFunction : public Function<MyFunction> {
 public:
  // Opt in to compiled autograd (torch.compile for backward)
  static constexpr bool is_traceable = true;

  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save non-tensor data for backward in the context
    ctx->saved_data["n"] = n;
    var.mul_(2);
    // Mark var as modified by an in-place operation
    ctx->mark_dirty({var});
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    return {grad_output[0] * n};
  }
};

To use MyFunction:

auto x = torch::randn({2, 2}, torch::requires_grad());
auto y = MyFunction::apply(6, x);
// Example backward call
y[0].sum().backward();
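The example above keeps a plain int in ctx->saved_data; tensors should instead go through save_for_backward. A minimal sketch of that pattern (the MySquare op itself is hypothetical, reusing the declarations above):

class MySquare : public Function<MySquare> {
 public:
  static Variable forward(AutogradContext *ctx, Variable x) {
    // Save the input so backward can compute d(x*x)/dx = 2*x
    ctx->save_for_backward({x});
    return x * x;
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Retrieve the tensors saved in forward
    auto x = ctx->get_saved_variables()[0];
    // One gradient per forward input
    return {grad_output[0] * 2 * x};
  }
};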

Public Static Functions

template<typename X = T, typename ...Args>
static auto apply(Args&&... args) -> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>>
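apply is the entry point shown in the usage example above: it creates the AutogradContext, invokes T::forward with it followed by the caller's arguments, and wires the outputs into the autograd graph. Its return type, forward_t<X, Args...>, is the return type of forward.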

Public Static Attributes

static constexpr bool is_traceable = false
