Shortcuts

Program Listing for File variable.h

Return to documentation for file (torch/csrc/autograd/variable.h)

#pragma once

#include <torch/csrc/utils/python_stub.h>

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>

#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/VariableHooksInterface.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

namespace torch::autograd {

// `Variable` is an alias for `at::Tensor`: both names denote exactly the same
// type, so the two can be used interchangeably.
using Variable = at::Tensor;

} // namespace torch::autograd

// The following are all internal APIs and should not be shown in libtorch docs.
// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
// ... #endif`

#ifndef DOXYGEN_SHOULD_SKIP_THIS

namespace torch::autograd {

// Returns true when `t` is a floating point scalar type or a complex scalar
// type, and false for every other scalar type.
static inline bool isDifferentiableType(at::ScalarType t) {
  if (isFloatingType(t)) {
    return true;
  }
  return isComplexType(t);
}

// Forward declaration; this header only refers to `Node` through raw and
// smart pointers.
struct Node;


// Forward declarations so that the accessor functions declared below can name
// these types before their definitions later in this header.
struct AutogradMeta;
struct DifferentiableViewMeta;

// Private-ish functions for manipulating variables; we don't want to put them
// on Tensor proper
// All functions in this namespace are only declared here; their definitions
// live outside this header, so the notes marked NOTE(review) are assumptions
// derived from names and signatures and should be confirmed.
namespace impl {

// WARNING: This may return a nullptr.  If you require AutogradMeta to return
// a materialized structure, use materialize_autograd_meta instead.
TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);

// WARNING: This will return a nullptr if the Tensor is not a view.
TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);

// Returns the current autograd meta, materializing it if it was previously
// none.  This counts as a *mutating* operation, so do not call it on
// "read-only" operators; in particular, this is NOT thread safe
TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);

// Associates `grad_accumulator` (passed as a weak_ptr) with the variable.
// NOTE(review): presumably stored in AutogradMeta::grad_accumulator_ —
// confirm against the definition.
TORCH_API void set_grad_accumulator(
    const Variable&,
    std::weak_ptr<Node> grad_accumulator);

// NOTE(review): judging by the name, this looks up an already existing
// gradient accumulator without creating one, so the returned shared_ptr may
// be empty — confirm against the definition.
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);

// Returns the gradient accumulator of the variable as a shared_ptr.
// NOTE(review): presumably creates the accumulator on demand, unlike
// try_get_grad_accumulator — confirm against the definition.
TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);

// Returns the gradient `Edge` of the variable, by value.
TORCH_API Edge gradient_edge(const Variable&);

// Replaces the gradient edge of the variable with `edge`.
// NOTE(review): exact fields updated are not visible here — confirm.
TORCH_API void set_gradient_edge(const Variable&, Edge edge);

// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOTE(review): presumably re-points the variable's history at
// `gradient_edge` (e.g. after an in-place operation) — confirm.
TORCH_API void rebase_history(const Variable&, Edge gradient_edge);

// Returns the variable's grad_fn as a raw `Node*`.
// NOTE(review): "unsafe" presumably refers to the pointer being non-owning
// and/or unsynchronized — confirm before relying on its lifetime.
TORCH_API Node* grad_fn_unsafe(const Variable&);

// NOTE(review): presumably increments the variable's version counter —
// confirm.
TORCH_API void bump_version(const Variable&);
// Sets the version counter of the variable to `version_counter`.
TORCH_API void set_version_counter(
    const Variable&,
    const c10::VariableVersion& version_counter);

// Returns a const reference to the variable's version counter.
TORCH_API const c10::VariableVersion& version_counter(const Variable&);

// Sets the name of the variable.
// NOTE(review): presumably stored in AutogradMeta::name_ — confirm.
TORCH_API void set_name(const Variable&, const std::string& name);

// Registers `hook` on the tensor; ownership of the hook is transferred via
// the unique_ptr argument.
TORCH_API void add_hook(
    const at::TensorBase&,
    std::unique_ptr<FunctionPreHook> hook);
// Returns a mutable reference to the list of hooks owned for the variable.
// NOTE(review): presumably AutogradMeta::hooks_ — confirm.
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
// NOTE(review): presumably removes every hook registered through add_hook —
// confirm.
TORCH_API void clear_hooks(const at::TensorBase&);

// Installs the post-accumulate-grad hook object for the tensor; ownership is
// transferred via the unique_ptr argument.
TORCH_API void set_post_acc_grad_hooks(
    const at::TensorBase&,
    std::unique_ptr<PostAccumulateGradHook> dict);
// Returns a mutable reference to the owning pointer of the tensor's
// post-accumulate-grad hook object.
TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable&);

// NOTE(review): presumably sets up the C++ hook list for the tensor;
// `is_retains_grad_hooks` (default false) appears to select the retains-grad
// variant — confirm against the definition.
TORCH_API void create_cpp_hook(
    const at::TensorBase&,
    bool is_retains_grad_hooks = false);
} // namespace impl

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                            AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


// Autograd state attached to a tensor. Implements c10::AutogradMetaInterface
// and holds the gradient, the gradient function, the gradient accumulator,
// forward-mode gradients, hooks and the bookkeeping flags declared below.
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  // NOTE(review): presumably a user-visible name for the tensor, set through
  // impl::set_name — confirm.
  std::string name_;

  // The gradient exposed through grad() / mutable_grad().
  Variable grad_;
  // Owning pointer to the gradient function; a non-null value makes
  // requires_grad() return true. Initialized from the constructor's edge.
  std::shared_ptr<Node> grad_fn_;
  // Weak (non-owning) pointer to the gradient accumulator node.
  std::weak_ptr<Node> grad_accumulator_;

  // This field is used to store all the forward AD gradients
  // associated with this AutogradMeta (and the Tensor it corresponds to)
  // There is a semantic 1:1 correspondence between AutogradMeta and
  // ForwardGrad but:
  //   - This field is lazily populated.
  //   - This field is a shared_ptr but it must never be
  //     shared by multiple Tensors. See Note [ Using ForwardGrad ]
  // Any transition from not_initialized to initialized
  // must be protected by mutex_
  mutable std::shared_ptr<ForwardGrad> fw_grad_;

  // The hooks_ field is actually reused by both python and cpp logic
  // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
  // or dict (python) which is the canonical copy.
  // Then, for both cases, we always register a single hook to
  // hooks_ which wraps all the hooks in the list/dict.
  // And, again in both cases, if the grad_fn exists on that tensor
  // we will additionally register a single hook to the grad_fn.
  //
  // Note that the cpp and python use cases aren't actually aware of
  // each other, so using both is not defined behavior.
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // The post_acc_grad_hooks_ field stores only Python hooks
  // (PyFunctionTensorPostAccGradHooks) that are called after the
  // .grad field has been accumulated into. This is less complicated
  // than the hooks_ field, which encapsulates a lot more.
  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_{false};

  // Only meaningful on non-leaf variables (must be false otherwise)
  bool retains_grad_{false};

  // NOTE(review): presumably true when this meta belongs to a view tensor
  // (see DifferentiableViewMeta) — confirm where it is set.
  bool is_view_{false};

  // The "output number" of this variable; e.g., if this variable
  // was the second output of a function, then output_nr == 1.
  // We use this to make sure we can setup the backwards trace
  // correctly when this variable is passed to another function.
  // Has no in-class default; the constructor always initializes it from the
  // gradient edge's input_nr.
  uint32_t output_nr_;

  // Mutex to ensure that concurrent read operations that modify internal
  // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
  // fw_grad() and set_fw_grad()
  // This is mutable because we need to be able to acquire this from const
  // version of this class for the functions above
  mutable std::mutex mutex_;

  // Sets requires_grad_. Enabling it is rejected with a TORCH_CHECK failure
  // unless the dtype of `self_impl` is floating point or complex. Because of
  // short-circuit evaluation, `self_impl` is only dereferenced when
  // `requires_grad` is true.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {
    TORCH_CHECK(
        !requires_grad ||
            isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point and complex dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  // True when the flag is set explicitly or when a grad_fn is present.
  bool requires_grad() const override {
    return requires_grad_ || grad_fn_;
  }

  // Mutable access to the stored gradient.
  Variable& mutable_grad() override {
    return grad_;
  }

  // Read-only access to the stored gradient.
  const Variable& grad() const override {
    return grad_;
  }

  // Forward-mode gradient accessor for the given `level`; defined out-of-line.
  // According to the comment on mutex_, it acquires that mutex.
  const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
      const override;

  // Forward-mode gradient setter for the given `level`; defined out-of-line.
  // According to the comment on mutex_, it acquires that mutex.
  void set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op) override;

  // Builds the meta from an (optionally empty) gradient edge: grad_fn_ takes
  // over the edge's function and output_nr_ copies the edge's input_nr.
  // When `requires_grad` is true, `self_impl` must be non-null (internal
  // assert) and the flag is applied through set_requires_grad, which also
  // validates the dtype. Having both a grad_fn and requires_grad_ set is
  // rejected by the final TORCH_CHECK.
  AutogradMeta(
      at::TensorImpl* self_impl = nullptr,
      bool requires_grad = false,
      Edge gradient_edge = Edge())
      : grad_fn_(std::move(gradient_edge.function)),

        output_nr_(gradient_edge.input_nr) {
    // set_requires_grad also checks error conditions.
    if (requires_grad) {
      TORCH_INTERNAL_ASSERT(self_impl);
      set_requires_grad(requires_grad, self_impl);
    }
    TORCH_CHECK(
        !grad_fn_ || !requires_grad_,
        "requires_grad should be false if grad_fn is set");
  }

  // Clears the forward gradients, if any were ever populated.
  ~AutogradMeta() override {
    // If AutogradMeta is being destroyed, it means that there is no other
    // reference to its corresponding Tensor. It implies that no other thread
    // can be using this object and so there is no need to lock mutex_ here to
    // guard the check if fw_grad_ is populated.
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }
};

// Abstract interface for a callable that maps a tensor to a tensor through
// operator(). Implementations may carry SymInt and Tensor state, which is
// exposed through the accessors below; the base implementations report that
// no such state exists.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() = default;
  // Returns the SymInts held by this function; empty in the base class.
  virtual std::vector<c10::SymInt> get_symints() const {
    return {};
  }
  // Number of SymInts held by this function; 0 in the base class.
  virtual size_t num_symints() const {
    return 0;
  }
  // Returns the tensors held by this function; empty in the base class.
  virtual std::vector<at::Tensor> get_tensors() const {
    return {};
  }
  // Number of tensors held by this function; 0 in the base class.
  virtual size_t num_tensors() const {
    return 0;
  }
  // Applies the function to the given tensor and returns the result.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  // Returns a newly allocated ViewFunc.
  // NOTE(review): judging by the name, the clone uses the supplied SymInts /
  // tensors in place of the current ones when an argument is provided —
  // confirm against the implementations.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = std::nullopt,
      std::optional<std::vector<at::Tensor>> = std::nullopt) const = 0;

 protected:
  // Setters for the held state; no-ops in the base class, which holds none.
  virtual void set_symints(std::vector<c10::SymInt>) {}
  virtual void set_tensors(std::vector<at::Tensor>) {}
};

struct ChainedViewFunc : public ViewFunc {
  ChainedViewFunc(
      std::unique_ptr<ViewFunc> first,
      std::unique_ptr<ViewFunc> second)
      : first(std::move(first)), second(std::move(second)) {}
  ~ChainedViewFunc() override = default;
  std::vector<c10::SymInt> get_symints() const override;
  size_t num_symints() const override {
    return first->num_symints() + second->num_symints();
  }
  std::vector<at::Tensor> get_tensors() const override;
  size_t num_tensors() const override {
    return first->num_tensors() + second->num_tensors();
  }
  at::Tensor operator()(const at::Tensor&) const override;
  std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = std::nullopt,
      std::optional<std::vector<at::Tensor>> = std::nullopt) const override;

 private:
  std::unique_ptr<ViewFunc> first;
  std::unique_ptr<ViewFunc> second;
};

struct ErroringViewFunc : public ViewFunc {
  ErroringViewFunc(std::string error_msg) : error_msg(std::move(error_msg)) {}
  ~ErroringViewFunc() override = default;
  at::Tensor operator()(const at::Tensor&) const override {
    TORCH_CHECK(false, error_msg);
  }
  std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = std::nullopt,
      std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
    return std::make_unique<ErroringViewFunc>(error_msg);
  }

 private:
  std::string error_msg;
};

// Bundles a base tensor with an optional pair of functions: a ViewFunc and a
// reverse function. The constructor rejects an undefined base.
// NOTE(review): presumably `view_fn_` maps the base to the view and
// `rev_view_fn_` maps the view back to the base — confirm with the callers.
struct TORCH_API ViewInfo {
  Variable base_;

  std::unique_ptr<ViewFunc> view_fn_;

  std::function<Variable(const Variable&)> rev_view_fn_;

  // Takes ownership of all three arguments; `base` must be defined.
  ViewInfo(
      Variable base,
      std::unique_ptr<ViewFunc> view_fn,
      std::function<Variable(const Variable&)> rev_view_fn)
      : base_(std::move(base)),
        view_fn_(std::move(view_fn)),
        rev_view_fn_(std::move(rev_view_fn)) {
    TORCH_CHECK(base_.defined(), "base is undefined");
  }

  // Whether a view function is stored. Only view_fn_ is inspected: the
  // struct assumes view_fn_ and rev_view_fn_ are either both set or both
  // unset.
  bool has_view_fn() const {
    return static_cast<bool>(view_fn_);
  }

  // Returns the stored view function; fails when none is stored.
  const ViewFunc& view_fn() const {
    TORCH_CHECK(
        has_view_fn(), "Can only access the view function if it exists.");
    return *view_fn_;
  }

  // Returns a copy of the stored reverse function; fails when no view
  // function is stored.
  std::function<Variable(const Variable&)> rev_view_fn() const {
    TORCH_CHECK(
        has_view_fn(),
        "Can only access the reverse view function if it exists.");
    return rev_view_fn_;
  }

  // Defined out-of-line.
  // NOTE(review): presumably produces the ViewInfo for a view taken of this
  // view — confirm against the definition.
  ViewInfo chain(
      const Variable& base,
      const Variable& tensor,
      std::unique_ptr<ViewFunc> view_func = nullptr,
      std::function<Variable(const Variable&)> rev_view_func = nullptr) const;
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                     DifferentiableViewMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



// Tag attached to a view that records the context it was created in.
// NOTE(review): the meaning of the individual values is inferred from the
// enumerator names only — confirm against the code that assigns them.
enum class CreationMeta : uint8_t {
  DEFAULT,
  IN_CUSTOM_FUNCTION,
  MULTI_OUTPUT_NODE,
  NO_GRAD_MODE,
  INFERENCE_MODE
};

// Combines the tag of an existing view with the tag of a new view taken of
// it and returns the tag to use for the new view:
//   - a DEFAULT new tag keeps the previous tag;
//   - otherwise a previous INFERENCE_MODE tag is kept;
//   - otherwise the new tag is used.
inline CreationMeta propagate_creation_meta(
    CreationMeta prev_view_creation_meta,
    CreationMeta new_view_creation_meta) {
  if (new_view_creation_meta == CreationMeta::DEFAULT) {
    return prev_view_creation_meta;
  }
  if (prev_view_creation_meta == CreationMeta::INFERENCE_MODE) {
    return CreationMeta::INFERENCE_MODE;
  }
  return new_view_creation_meta;
}

// Declared here, defined elsewhere.
// NOTE(review): presumably validates (or reports an error for) a view whose
// history is being rebased, based on its creation meta; `indirect` defaults
// to false and its exact effect is not visible here — confirm against the
// definition.
TORCH_API void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect = false);

// AutogradMeta specialization that additionally carries view information
// for the backward direction, the forward direction, or both.
struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
 private:
  std::optional<ViewInfo> backward_info_;
  std::optional<ViewInfo> forward_info_;

  // Optimization to reduce the number of ViewInfo we create.
  // In the (very common) case where backward_info_ == forward_info_, we only
  // populate backward_info_ (that should be used as both the forward and
  // backward view information) and set shared_view_info_ = true. Invariants:
  //   - If shared_view_info_ is false, there is no special constraints on
  //     backward_info_ and forward_info_
  //   - If shared_view_info_ is true, we must have:
  //      - backward_info_.has_value() == true
  //      - forward_info_.has_value() == false
  bool shared_view_info_;

  // Only accessible through the checked getters/setters below, which require
  // backward view information to be present.
  uint32_t attr_version_;
  CreationMeta creation_meta_;

 public:
  // True when the flag is set, a grad_fn exists, or the base of the backward
  // view requires grad.
  bool requires_grad() const override {
    if (requires_grad_ || grad_fn_) {
      return true;
    }
    return has_bw_view() && get_backward_view().base_.requires_grad();
  }

  // Whether backward_info_ doubles as the forward view information.
  bool shared_view_info() const {
    return shared_view_info_;
  }

  // Whether backward view information is present.
  bool has_bw_view() const {
    return backward_info_.has_value();
  }

  // Returns the backward view information; fails when there is none.
  const ViewInfo& get_backward_view() const {
    TORCH_CHECK(
        has_bw_view(), "backward view info can only exist for backward views.");
    return backward_info_.value();
  }

  // Reads attr_version_; fails when there is no backward view.
  uint32_t get_attr_version() const {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    return attr_version_;
  }

  // Overwrites attr_version_; fails when there is no backward view.
  void set_attr_version(uint32_t new_attr_version) {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    attr_version_ = new_attr_version;
  }

  // Reads creation_meta_; fails when there is no backward view.
  CreationMeta get_creation_meta() const {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    return creation_meta_;
  }

  // Overwrites creation_meta_; fails when there is no backward view.
  void set_creation_meta(CreationMeta new_creation_meta) {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    creation_meta_ = new_creation_meta;
  }

  // Forward view information exists either as its own ViewInfo or, when
  // shared, through backward_info_.
  bool has_fw_view() const {
    if (shared_view_info_) {
      return true;
    }
    return forward_info_.has_value();
  }

  // Returns the forward view information: backward_info_ when the info is
  // shared, forward_info_ otherwise. Fails when neither is available.
  const ViewInfo& get_forward_view() const {
    TORCH_CHECK(
        has_fw_view(), "forward view info can only exist for forward views.");
    TORCH_CHECK(
        !shared_view_info_ || has_bw_view(),
        "forward view info can only exist for forward views.");
    if (shared_view_info_) {
      return backward_info_.value();
    }
    return forward_info_.value();
  }

  // Defined out-of-line.
  DifferentiableViewMeta(
      at::TensorImpl* self_impl,
      std::optional<ViewInfo> backward_info,
      std::optional<ViewInfo> forward_info,
      bool shared_view_info,
      CreationMeta creation_meta = CreationMeta::DEFAULT);
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                        Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



// See NOTE [ Autograd View Variables ] for details.
// Differentiable view. Track history with DifferentiableViewMeta.
// Attaches a freshly created DifferentiableViewMeta to `data` and returns
// `data`. An undefined `data` yields a default-constructed Variable and is
// otherwise left untouched. `data` must not already carry autograd metadata.
inline Variable make_variable_differentiable_view(
    const at::Tensor& data,
    std::optional<ViewInfo> backward_info,
    std::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }

  TORCH_CHECK(
      data.getIntrusivePtr()->autograd_meta() == nullptr,
      "Attempted to make a tensor into a differentiable view, but the "
      "tensor already had autograd metadata associated with it.  If you are "
      "using a __torch_dispatch__ mode, the most common cause for this "
      "problem is that you used torch.overrides.enable_reentrant_dispatch() "
      "improperly; tensors created within the extent of reentrant dispatch "
      "MUST NOT be directly returned from __torch_dispatch__; instead, they "
      "must be wrapped into fresh tensors that serve as the output.  If you "
      "are not using wrappers, you probably don't need reentrant dispatch.  "
      "If this doesn't seem applicable, please file a bug to PyTorch.");

  // The metadata is installed on the TensorImpl that `data` already refers
  // to; no new TensorImpl is created here.
  at::TensorImpl* const impl_ptr = data.unsafeGetTensorImpl();
  impl_ptr->set_allow_tensor_metadata_change(allow_tensor_metadata_change);

  auto view_meta = std::make_unique<DifferentiableViewMeta>(
      impl_ptr,
      std::move(backward_info),
      std::move(forward_info),
      shared_view_info,
      creation_meta);
  impl_ptr->set_autograd_meta(std::move(view_meta));
  return data;
}

// See NOTE [ Autograd View Variables ] for details.
// Non-differentiable view. Just share version counter.
// Creates a Variable from a shallow, detached copy of `data`'s TensorImpl
// that uses the version counter of `base` and carries no autograd metadata.
// An undefined `data` yields a default-constructed Variable.
inline Variable make_variable_non_differentiable_view(
    const Variable& base,
    const at::Tensor& data,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }
  // Currently all of non-differentiable view ops(detach/_indices/_values)
  // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
  // allocation here is required.
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/impl::version_counter(base),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(nullptr);
  // The local pointer is not used afterwards, so hand it over instead of
  // copying it (a copy costs an extra reference count increment/decrement).
  return Variable(std::move(data_impl_copy));
}

// Creates a Variable from `data`, optionally marking it as requiring grad.
//
// - An undefined `data` yields a default-constructed Variable.
// - When `data` holds the only reference to its TensorImpl and
//   unique_version() is true, that TensorImpl is reused directly.
// - Otherwise a shallow, detached copy of the TensorImpl with a version
//   counter built from 0 is used.
//
// In both cases the resulting TensorImpl gets a fresh AutogradMeta when
// `requires_grad` is true, and no autograd metadata otherwise.
inline Variable make_variable(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }

  // Shared by both branches: install (or clear) the autograd metadata.
  const auto install_autograd_meta = [requires_grad](auto& impl) {
    if (requires_grad) {
      impl->set_autograd_meta(
          std::make_unique<AutogradMeta>(impl.get(), requires_grad));
    } else {
      impl->set_autograd_meta(nullptr);
    }
  };

  if (data.getIntrusivePtr().use_count() == 1 &&
      data.getIntrusivePtr()->unique_version()) {
    auto data_impl = data.unsafeReleaseIntrusivePtr();
    data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    install_autograd_meta(data_impl);
    return Variable(std::move(data_impl));
  }

  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  install_autograd_meta(data_impl_copy);
  // Moved, like in the branch above: the local pointer is dead after this
  // statement, so a copy would only add a reference count round trip.
  return Variable(std::move(data_impl_copy));
}

// Creates a Variable from a shallow, detached copy of `data`'s TensorImpl
// (version counter built from 0) whose AutogradMeta is constructed from
// `gradient_edge` with requires_grad set to false. An undefined `data`
// yields a default-constructed Variable.
inline Variable make_variable(
    const at::Tensor& data,
    Edge gradient_edge,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
      data_impl_copy.get(), false, std::move(gradient_edge)));
  // The local pointer is not used afterwards, so hand it over instead of
  // copying it (a copy costs an extra reference count increment/decrement).
  return Variable(std::move(data_impl_copy));
}

// Final implementation of at::impl::VariableHooksInterface. Every member is
// an override that is only declared here; the definitions live outside this
// header, so the notes marked NOTE(review) are inferred from names and
// signatures and should be confirmed against the definitions.
struct VariableHooks final : at::impl::VariableHooksInterface {
  // Both return a TensorBase by value.
  // NOTE(review): the difference between the two is not visible here.
  at::TensorBase tensor_data(const at::TensorBase&) const override;
  at::TensorBase variable_data(const at::TensorBase&) const override;
  // Returns a const reference to the tensor's grad_fn shared_ptr.
  const std::shared_ptr<torch::autograd::Node>& grad_fn(
      const at::TensorBase&) const override;
  // Registers `hook` on the tensor and returns an unsigned value.
  // NOTE(review): presumably the position later passed to remove_hook.
  unsigned _register_hook(
      const at::TensorBase&,
      std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
  // Removes the hook identified by `pos`.
  void remove_hook(const at::TensorBase&, unsigned pos) const override;
  // Queries about the tensor's autograd state.
  bool is_view(const at::TensorBase&) const override;
  const at::TensorBase& base(const at::TensorBase&) const override;
  const std::string& name(const at::TensorBase&) const override;
  bool is_leaf(const at::TensorBase&) const override;
  int64_t output_nr(const at::TensorBase&) const override;
  // Replaces / reads the data of `self`.
  void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
      const override;
  at::TensorBase data(const at::TensorBase& self) const override;
  // Returns the version of `self` as a signed 64-bit integer.
  int64_t _version(const at::TensorBase& self) const override;
  // NOTE(review): presumably toggles / reads AutogradMeta::retains_grad_.
  void retain_grad(const at::TensorBase& self) const override;
  bool retains_grad(const at::TensorBase& self) const override;
  // Entry point for running backward from `self`.
  // NOTE(review): parameter semantics (inputs, gradient, keep_graph,
  // create_graph) are defined by the out-of-line implementation.
  void _backward(
      const at::Tensor& self,
      at::TensorList inputs,
      const std::optional<at::Tensor>& gradient,
      std::optional<bool> keep_graph,
      bool create_graph) const override;
  // Sets the requires_grad state of `self` to `_requires_grad`.
  void requires_grad_(const at::TensorBase& self, bool _requires_grad)
      const override;
  // Fallback invoked with an operator handle, its dispatch key set and the
  // JIT stack.
  // NOTE(review): presumably used for operators without an autograd
  // implementation.
  void basic_autograd_not_implemented_fallback(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet dispatch_keys,
      torch::jit::Stack* stack) const override;
};

namespace utils {

// Declared here, defined elsewhere.
// NOTE(review): presumably reports whether `base` and `other` agree on their
// tensor metadata; which properties are compared is not visible in this
// header — confirm against the definition.
TORCH_API bool has_same_meta(const Variable& base, const Variable& other);

} // namespace utils
} // namespace torch::autograd

#endif /* DOXYGEN_SHOULD_SKIP_THIS */

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources