Template Struct Function

Defined in File custom_function.h

Struct Documentation
template<class T>
struct Function

To use custom autograd operations, implement a Function subclass with static forward and backward functions:

forward can take as many arguments as you want and should return either a variable list or a Variable. Any direct Variable arguments are registered in the graph, but vectors, sets, and other data structures are not traversed. You can use std::optional<Tensor> as one of the arguments, and it is registered as a variable in the graph if the argument has a value. forward should take a pointer to torch::autograd::AutogradContext as its first argument. Variables can be saved in the ctx using ctx->save_for_backward (see torch::autograd::AutogradContext::save_for_backward), and other data can be saved in the ctx->saved_data map (see torch::autograd::AutogradContext::saved_data) in the form of <std::string, at::IValue> pairs.

backward should take a pointer to torch::autograd::AutogradContext and a variable list containing as many Variables as there were outputs from forward. It should return as many Variables as there were inputs, each containing the gradient w.r.t. its corresponding input. Variables saved in forward can be accessed with ctx->get_saved_variables (see torch::autograd::AutogradContext::get_saved_variables), and other saved data can be accessed from ctx->saved_data.

To enable compiled autograd support (torch.compile for backward) for your custom autograd operation, set MyFunction::is_traceable (see the Function::is_traceable notes below).

For example:
```cpp
class MyFunction : public Function<MyFunction> {
 public:
  static constexpr bool is_traceable = true;

  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save data for backward in context
    ctx->saved_data["n"] = n;
    var.mul_(n);
    // Mark var as modified by in-place operation
    ctx->mark_dirty({var});
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    return {grad_output[0] * n};
  }
};
```
To use MyFunction:

```cpp
Variable x;
auto y = MyFunction::apply(6, x);
// Example backward call
y[0].sum().backward();
```
Public Static Functions

Public Static Attributes

static constexpr bool is_traceable = false