Shortcuts

Program Listing for File custom_function.h

Return to documentation for file (torch/csrc/autograd/custom_function.h)

#pragma once

#include <ATen/core/ivalue.h>
#include <c10/core/SymInt.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/variable_info.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <vector>

namespace torch::autograd {

// A list of Variables in which each entry may be absent (std::nullopt).
using optional_variable_list = std::vector<std::optional<Variable>>;
// Callable type for a jvp callback: receives two variable_lists and returns a
// variable_list. Used as the type of the `jvp_user_function` argument of
// _wrap_outputs.
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
// Callable type mapping a tensor to a tensor. Used as the type of the
// `view_as_self_fn` argument of _wrap_outputs.
using _view_as_self_fn_t = std::function<at::Tensor(at::Tensor)>;

// Declared here, defined out of line. Function<T>::apply calls this right
// after running T::forward and hands the returned list to to_output_type.
// NOTE(review): presumably this is where the raw forward outputs get attached
// to the autograd graph -- confirm against the definition.
//
// Arguments, as supplied by Function<T>::apply in this header:
//   input_vars               - tensors extracted from the forward arguments
//   non_differentiable       - AutogradContext::get_non_differentiable()
//   dirty_inputs             - AutogradContext::get_and_bump_dirty()
//   raw_outputs              - the forward result wrapped by to_optional()
//   cdata                    - the CppNode for this call, or nullptr when the
//                              call is not executable (grad mode disabled or
//                              no input requires grad)
//   jvp_user_function        - jvp callback; the C++ API passes one that
//                              always raises
//   to_save_if_setup_context - passed as an empty set by the C++ API
//   view_as_self_fn          - callback returning x.view_as(x)
//
// Returns one optional Variable per entry (see optional_variable_list).
TORCH_API std::vector<std::optional<Variable>> _wrap_outputs(
    const variable_list& input_vars,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    const _jvp_fn_t& jvp_user_function,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    const _view_as_self_fn_t& view_as_self_fn);

// Declared here, defined out of line; it is not called anywhere in this
// header.
// NOTE(review): judging by the signature this validates `result` against
// `original` on behalf of the hook called `hook_name` -- confirm against the
// definition before relying on the exact checks performed.
TORCH_API void check_variable_result(
    const at::TensorBase& original,
    const at::TensorBase& result,
    const std::string& hook_name);

// Get the return type of the forward function of the custom Function class X
// when it is called with a null context pointer followed by arguments of
// types Args... . The expression is only inspected inside decltype; it is
// never evaluated.
template <typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));

// CRTP base class for custom autograd functions written in C++. The derived
// class T is expected to provide static `forward` and `backward` functions:
// Function<T>::apply calls T::forward(&ctx, args...) and CppNode<T>::apply
// calls T::backward(&ctx, grads).
template <class T>
struct TORCH_API Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  //
  // Runs T::forward on `args` and returns its result converted back to the
  // forward return type (see the out-of-line definition further down).
  template <typename X = T, typename... Args>
  static auto apply(Args&&... args)
      -> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>>;

  // This flag is for an experimental feature: compiled autograd. Not all
  // built-in APIs are supported at the moment e.g. mark_dirty and
  // mark_non_differentiable. Before setting this flag to enable tracing for
  // your custom function <T>, you need to ensure that the backward function is
  // traceable i.e. any variables accessed in the backward other than the input
  // arguments must be handled in a similar manner to built-ins in
  // CppNode::compiled_args and CppNode::apply_with_saved.
  //
  // CppNode<T>::compiled_args throws std::runtime_error when T::is_traceable
  // is false; a derived class opts in by declaring its own `is_traceable`.
  static constexpr bool is_traceable = false;
};

// Per-call state handed to T::forward and T::backward of a custom Function.
// One instance lives inside each CppNode (member `ctx_`) and its address is
// what the user's forward/backward receive as their first argument.
struct TORCH_API AutogradContext {
  AutogradContext() = default;
  // Neither copyable nor movable.
  AutogradContext(const AutogradContext& other) = delete;
  AutogradContext& operator=(const AutogradContext& other) = delete;
  AutogradContext(AutogradContext&& other) = delete;
  AutogradContext& operator=(AutogradContext&& other) = delete;
  ~AutogradContext() = default;

  // Public, string-keyed map of IValues.
  ska::flat_hash_map<std::string, at::IValue> saved_data;

  // NOTE(review): the three methods below are defined out of line; by name
  // they presumably fill to_save_, dirty_inputs_ and non_differentiable_
  // respectively -- confirm against the definitions.
  void save_for_backward(variable_list to_save);
  void mark_dirty(const variable_list& inputs);
  void mark_non_differentiable(const variable_list& outputs);
  // Sets whether undefined output grad tensors should be expanded to tensors
  // full of zeros before calling backward function. Default value is true.
  void set_materialize_grads(bool value);

  // NOTE(review): defined out of line; presumably returns the tensors held in
  // saved_variables_ -- confirm.
  variable_list get_saved_variables() const;
  // Both getters are used by Function<T>::apply to build the arguments of
  // _wrap_outputs. NOTE(review): what get_and_bump_dirty "bumps" is not
  // visible from this header.
  const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
  const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

  // NOTE(review): defined out of line; presumably report whether a gradient is
  // required for the given output edge index / index ranges -- confirm.
  bool needs_input_grad(size_t output_edge_index) const;
  bool needs_input_grad(std::initializer_list<IndexRange> idxs) const;

 private:
  std::unordered_set<at::TensorImpl*> non_differentiable_;
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
  // Cleared by CppNode::release_variables.
  std::vector<torch::autograd::SavedVariable> saved_variables_;
  variable_list to_save_;
  // When true, CppNode::apply replaces undefined incoming gradients with
  // zeros before calling T::backward.
  bool materialize_grads_{true};

  // The CppNode in the autograd graph that owns this AutogradContext. We need a
  // weak_ptr to avoid a refcycle. Since grad_fn_ owns this AutogradContext, it
  // will always be alive when we want to use it.
  std::weak_ptr<Node> grad_fn_;
  // Set to true by CppNode::release_variables.
  bool has_freed_buffers_{false};

  // Called through CppNode::save_variables_to_ctx once the forward outputs
  // have been wrapped. NOTE(review): presumably moves to_save_ into
  // saved_variables_ -- confirm against the definition.
  void save_variables();

  // CppNode reads and writes the private state above directly.
  template <class T>
  friend struct CppNode;
};

// CppNode<T> is the Node in the autograd graph that represents the user defined
// backward function for Function<T>. Calls to CppNode::apply are forwarded to
// T::backward().
template <class T>
struct CppNode : public Node {
  // Runs T::backward on the incoming gradients; defined out of line below.
  variable_list apply(variable_list&& inputs) override;
  // Context shared between T::forward and T::backward for this call.
  AutogradContext ctx_;
  // One flag per forward argument: true when the argument was recorded as a
  // tensor by extract_vars.
  std::vector<bool> is_variable_input_;
  // One VariableInfo per extracted input tensor.
  std::vector<VariableInfo> input_info_;
  // One VariableInfo per wrapped forward output; only filled in when the call
  // is executable (see Function<T>::apply).
  std::vector<VariableInfo> output_info_;

  void release_variables() override;

  void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
  void save_variables_to_ctx();

  // Feeds the state of this node into `args`.
  // Throws std::runtime_error unless T::is_traceable is true.
  void compiled_args(CompiledNodeArgs& args) override {
    // T::is_traceable must be a (possibly cv-qualified) bool.
    static_assert(
        std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
    if (!T::is_traceable) {
      throw std::runtime_error(
          std::string(
              "Attempting to trace a potentially unsafe C++ autograd function: ") +
          name() +
          ". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
    }

    // although neither of the 2 methods below have uniqueness guarantees
    // it is unlikely for them to collide at the same time
    args.collect(static_cast<uint64_t>(typeid(T).hash_code()));
    args.collect(std::string(typeid(T).name()));

    // The non-differentiable, dirty and to-save sets are asserted to be empty
    // here rather than collected.
    args.collect(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    args.collect(
        ctx_.saved_variables_, true); // always unpacked as output in eager
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    args.collect(ctx_.materialize_grads_);
    args.collect(ctx_.has_freed_buffers_);
    args.collect(is_variable_input_);
    args.collect(input_info_);
    args.collect(output_info_);
  }

  // Runs apply() on a copy of `inputs`, bracketed by saved.before(...) and
  // saved.after(...) calls for each piece of state. The after() calls are
  // issued for the same members, in the same order, as the before() calls.
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override {
    saved.before(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    saved.before(ctx_.saved_variables_);
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    saved.before(ctx_.materialize_grads_);
    saved.before(ctx_.has_freed_buffers_);
    saved.before(input_info_);
    saved.before(output_info_);
    // apply() takes an rvalue, so hand it a fresh copy of the inputs.
    auto results = apply(variable_list(inputs));
    saved.after(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    saved.after(ctx_.saved_variables_);
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    saved.after(ctx_.materialize_grads_);
    saved.after(ctx_.has_freed_buffers_);
    saved.after(input_info_);
    saved.after(output_info_);
    return results;
  }
};

// Visitor (driven through IterArgs) that classifies forward arguments. Every
// visited argument appends exactly one flag to `is_var_`, except a TensorList,
// which appends one flag per element. Arguments recorded as tensors are also
// appended to `list_`.
struct ExtractVariables : IterArgs<ExtractVariables> {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  std::vector<bool>& is_var_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  variable_list& list_;

  ExtractVariables(std::vector<bool>& is_var, variable_list& list)
      : is_var_(is_var), list_(list) {}

  // Optional tensor: recorded as a tensor only when it holds a defined tensor.
  void operator()(const std::optional<at::Tensor>& x) {
    const bool holds_defined_tensor = x.has_value() && x->defined();
    is_var_.push_back(holds_defined_tensor);
    if (holds_defined_tensor) {
      list_.emplace_back(*x);
    }
  }

  // Plain tensor: always recorded, whether or not it is defined.
  void operator()(const at::Tensor& x) {
    record_tensor(x);
  }

  // Tensor list: every element is recorded individually.
  void operator()(const at::TensorList& list) {
    for (const at::Tensor& element : list) {
      record_tensor(element);
    }
  }

  // Anything else is flagged as a non-tensor argument.
  template <typename T>
  void operator()(const T& /*x*/) {
    is_var_.push_back(false);
  }

 private:
  // Flags the argument as a tensor and stores it in the output list.
  void record_tensor(const at::Tensor& tensor) {
    is_var_.push_back(true);
    list_.emplace_back(tensor);
  }
};

template <typename... Args>
inline void extract_vars(
    std::vector<bool>& is_var,
    variable_list& list,
    Args&&... args) {
  ExtractVariables(is_var, list).apply(std::forward<Args>(args)...);
}

template <typename T>
std::enable_if_t<std::is_same_v<T, variable_list>, T> to_output_type(
    std::vector<std::optional<Variable>>& output_list) {
  variable_list result;
  std::transform(
      output_list.begin(),
      output_list.end(),
      std::back_inserter(result),
      [](const std::optional<Variable>& var) { return *var; });
  return result;
}

template <typename T>
std::enable_if_t<std::is_same_v<T, Variable>, T> to_output_type(
    std::vector<std::optional<Variable>>& output_list) {
  return *output_list[0];
}

// Wraps a single Variable into a one-element list of optionals.
inline std::vector<std::optional<Variable>> to_optional(Variable& output) {
  std::vector<std::optional<Variable>> wrapped;
  wrapped.emplace_back(output);
  return wrapped;
}

// Wraps every Variable of `output` into an engaged optional, preserving
// order.
inline std::vector<std::optional<Variable>> to_optional(variable_list& output) {
  std::vector<std::optional<Variable>> wrapped;
  wrapped.reserve(output.size());
  for (const Variable& var : output) {
    wrapped.emplace_back(var);
  }
  return wrapped;
}

// Entry point for calling the custom function T.
//
// Steps:
//   1. If a functorch TLS is active, let it check that C++ autograd Functions
//      are supported.
//   2. Create the CppNode<T> for this call and extract the tensor arguments.
//   3. Decide whether the call is "executable": grad mode is enabled and at
//      least one extracted tensor requires grad.
//   4. Run T::forward with grad mode disabled.
//   5. Pass the raw outputs through _wrap_outputs (the node is handed over
//      only when the call is executable).
//   6. When executable, record output metadata and save the variables
//      registered on the context.
//   7. Convert the wrapped outputs back to the return type of T::forward.
//
// Returns whatever T::forward returns (a Variable or a variable_list).
// Exceptions raised by T::forward or by the helpers called here propagate to
// the caller.
template <class T>
template <typename X, typename... Args>
auto Function<T>::apply(Args&&... args)
    -> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>> {
  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    // Function support for functorch is handled in Python.
    // Here we are dealing with a (C++) Function, which is not supported.
    // Let's raise an error instead of being silently incorrect.
    functorch_tls->checkSupportsCppAutogradFunction();
  }

  // The node is created with the `deleteNode` deleter, so make_shared cannot
  // be used here.
  std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
  variable_list input_vars;

  const size_t num_inputs = sizeof...(Args);
  input_vars.reserve(num_inputs);
  node->is_variable_input_.reserve(num_inputs);
  // TODO Add tracing here
  // The arguments are passed as lvalues on purpose: they are forwarded to
  // T::forward further down and must not be moved-from before that.
  extract_vars(node->is_variable_input_, input_vars, args...);

  bool is_executable =
      GradMode::is_enabled() && any_variable_requires_grad(input_vars);
  auto next_edges =
      (is_executable ? collect_next_edges(input_vars) : edge_list());
  node->set_ctx_grad_fn(node);
  node->set_next_edges(std::move(next_edges));
  node->clear_input_metadata();

  node->input_info_.reserve(input_vars.size());
  for (auto& var : input_vars) {
    node->input_info_.emplace_back(var);
  }

  using forward_return_t = forward_t<X, Args...>;
  forward_return_t outputs;
  {
    // The user's forward runs with grad mode disabled for this scope only.
    AutoGradMode grad_mode(false);
    outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
  }

  // jvp is not available through the C++ API: the callback always raises.
  // The two message fragments are concatenated verbatim, hence the trailing
  // space after the first sentence.
  _jvp_fn_t jvp_fn = [](const variable_list& /*inputs*/,
                        const variable_list& /*gI*/) -> variable_list {
    TORCH_CHECK(
        false,
        "jvp is not implemented for the c++ API of custom Function yet. ",
        "Please open a feature request on GitHub if you need this.");
  };

  auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
    return x.view_as(x);
  };

  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      node->ctx_.get_non_differentiable(),
      node->ctx_.get_and_bump_dirty(),
      to_optional(outputs),
      is_executable ? node : nullptr,
      jvp_fn,
      {},
      view_as_self_fn);

  node->output_info_.reserve(wrapped_outputs.size());
  if (is_executable) {
    // One entry per wrapped output; a missing output gets a
    // default-constructed VariableInfo so that indices stay aligned.
    for (auto& output : wrapped_outputs) {
      if (output.has_value()) {
        node->output_info_.emplace_back(output.value());
      } else {
        node->output_info_.emplace_back();
      }
    }
    node->save_variables_to_ctx();
  }

  // wrapped_outputs will be a variable_list so, convert it to the correct
  // return type. Only Variable and variable_list are accepted as return types.
  return to_output_type<forward_return_t>(wrapped_outputs);
}

// The logic here is the same as PyNode::apply, so changes to it should be done
// in both the places
//
// Runs the user's T::backward on the incoming gradients.
//
// `inputs` holds the gradients with respect to the forward outputs. Undefined
// entries are replaced by zeros built from output_info_ when the context's
// materialize_grads_ flag is set; otherwise they are passed through as is.
//
// Returns the gradients produced by T::backward, restricted to the forward
// arguments that were recorded as tensors (is_variable_input_).
//
// Throws std::runtime_error when T::backward returns the wrong number of
// gradients, or a defined gradient for a forward argument that was not a
// tensor.
template <class T>
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
variable_list CppNode<T>::apply(variable_list&& inputs) {
  at::OptionalDeviceGuard _device_guard;

  auto num_inputs = inputs.size();
  variable_list backward_inputs;
  backward_inputs.reserve(num_inputs);
  for (const auto i : c10::irange(num_inputs)) {
    if (inputs[i].defined() || !ctx_.materialize_grads_) {
      backward_inputs.emplace_back(std::move(inputs[i]));
    } else {
      backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
    }
  }

  // Acquire the lock here to protect thread safety on custom C++ Autograd Node
  // This is needed for the custom Autograd Node since we don't know if the
  // user defined Node will write to the shared data during backward.
  // see Note [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);

  auto outputs = T::backward(&ctx_, backward_inputs);

  const auto num_forward_inputs =
      static_cast<int64_t>(is_variable_input_.size());
  auto num_outputs = static_cast<int64_t>(outputs.size());
  // Returning too many results is ok, but only as long as they're all
  // undefined. Truncate the result vector in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_undef = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_undef &= (!outputs[i].defined());
    }
    if (all_undef) {
      outputs.resize(num_forward_inputs);
      num_outputs = num_forward_inputs;
    }
  }

  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  variable_list results;
  results.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    if (!is_variable_input_[i]) {
      // Non-tensor forward arguments must not receive a gradient. The
      // position reported in the message is 1-based.
      if (outputs[i].defined()) {
        std::string msg("function ");
        msg += name() + " returned a gradient that is defined at position ";
        msg += std::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    results.emplace_back(outputs[i]);
  }
  return results;
}

// Drops the variables saved on the context and marks its buffers as freed.
template <class T>
void CppNode<T>::release_variables() {
  // Hold the node mutex for the whole update, see
  // [Thread Safety on Autograd Node]
  const std::lock_guard<std::mutex> guard(mutex_);
  AutogradContext& context = ctx_;
  context.saved_variables_.clear();
  context.has_freed_buffers_ = true;
}

// Forwards to the private AutogradContext::save_variables() of this node's
// context. Function<T>::apply calls this after the forward outputs have been
// wrapped, and only when the call is executable.
template <class T>
void CppNode<T>::save_variables_to_ctx() {
  ctx_.save_variables();
}

// Stores `node` in the context's grad_fn_ member. That member is a weak_ptr,
// so the context does not keep the node alive. Function<T>::apply passes the
// node that owns this context.
template <class T>
void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node>& node) {
  ctx_.grad_fn_ = node;
}

} // namespace torch::autograd

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources