Program Listing for File variable.h
Return to documentation for file (torch/csrc/autograd/variable.h)
#pragma once
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/VariableHooksInterface.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
namespace torch::autograd {
// `Variable` is simply another name for `at::Tensor`; the two denote the same
// type, so a Variable can be used wherever an at::Tensor is expected.
using Variable = at::Tensor;
} // namespace torch::autograd
// The following are all internal APIs and should not be shown in libtorch docs.
// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
// ... #endif`
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace torch::autograd {
// Tells whether `t` is a dtype that may require gradients: true for floating
// point and complex scalar types, false for everything else.
static inline bool isDifferentiableType(at::ScalarType t) {
  if (isFloatingType(t)) {
    return true;
  }
  return isComplexType(t);
}
// Forward declarations so the declarations below can refer to these types by
// pointer/smart pointer. AutogradMeta and DifferentiableViewMeta are defined
// further down in this header; Node is not defined in this header.
struct Node;
struct AutogradMeta;
struct DifferentiableViewMeta;
// Private-ish functions for manipulating variables; we don't want to put them
// on Tensor proper.
// Everything in this namespace is only declared here; the definitions live
// out of line.
namespace impl {
// WARNING: This may return a nullptr. If you require AutogradMeta to return
// a materialized structure, use materialize_autograd_meta instead.
TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);
// WARNING: This will return a nullptr if the Tensor is not a view.
TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);
// Returns the current autograd meta, materializing it if it was previously
// none. This counts as a *mutating* operation, so do not call it on
// "read-only" operators; in particular, this is NOT thread safe
TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);
// Gradient accumulator accessors. The accumulator is handed in as a weak_ptr
// and handed back out as a shared_ptr.
// NOTE(review): judging by the names, try_get_grad_accumulator only looks up
// an existing accumulator while grad_accumulator may create one -- confirm
// against the definitions.
TORCH_API void set_grad_accumulator(
const Variable&,
std::weak_ptr<Node> grad_accumulator);
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);
TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);
// Getter / setter for the variable's gradient Edge (returned and taken by
// value).
TORCH_API Edge gradient_edge(const Variable&);
TORCH_API void set_gradient_edge(const Variable&, Edge edge);
// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TORCH_API void rebase_history(const Variable&, Edge gradient_edge);
// Returns a raw, non-owning Node pointer.
// NOTE(review): "unsafe" presumably refers to the missing ownership/locking
// guarantees -- confirm before relying on the pointer's lifetime.
TORCH_API Node* grad_fn_unsafe(const Variable&);
// Version counter manipulation (c10::VariableVersion).
TORCH_API void bump_version(const Variable&);
TORCH_API void set_version_counter(
const Variable&,
const c10::VariableVersion& version_counter);
TORCH_API const c10::VariableVersion& version_counter(const Variable&);
TORCH_API void set_name(const Variable&, const std::string& name);
// Pre-hook management. add_hook takes ownership of the hook (unique_ptr by
// value); hooks() exposes the stored vector by mutable reference.
TORCH_API void add_hook(
const at::TensorBase&,
std::unique_ptr<FunctionPreHook> hook);
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
TORCH_API void clear_hooks(const at::TensorBase&);
// Post-accumulate-grad hook management. The setter takes ownership of the
// hook object; the getter returns the owning unique_ptr by mutable reference.
TORCH_API void set_post_acc_grad_hooks(
const at::TensorBase&,
std::unique_ptr<PostAccumulateGradHook> dict);
TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
const Variable&);
// `is_retains_grad_hooks` defaults to false.
// NOTE(review): presumably this sets up the cpp hook list described on
// AutogradMeta::cpp_hooks_list_ -- confirm against the definition.
TORCH_API void create_cpp_hook(
const at::TensorBase&,
bool is_retains_grad_hooks = false);
} // namespace impl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Autograd bookkeeping attached to a tensor. Implements
// c10::AutogradMetaInterface.
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  // Name string stored for the variable.
  std::string name_;

  // Gradient tensor, exposed through grad() / mutable_grad().
  Variable grad_;
  // Gradient function; initialised from the `function` member of the gradient
  // edge passed to the constructor.
  std::shared_ptr<Node> grad_fn_;
  // Weakly held gradient accumulator node.
  std::weak_ptr<Node> grad_accumulator_;

  // Holds every forward AD gradient associated with this AutogradMeta (and
  // therefore with the Tensor it belongs to). Semantically AutogradMeta and
  // ForwardGrad are in 1:1 correspondence, with two caveats:
  //   - the field is populated lazily;
  //   - although it is a shared_ptr, it must never be shared between several
  //     Tensors. See Note [ Using ForwardGrad ].
  // Every not_initialized -> initialized transition must happen under mutex_.
  mutable std::shared_ptr<ForwardGrad> fw_grad_;

  // hooks_ is shared between the python and the cpp hook machinery. Each side
  // keeps its own canonical container -- cpp_hooks_list_ for cpp, a dict for
  // python -- and registers exactly one wrapper hook in hooks_ that runs
  // everything in that container. In both cases, when the tensor has a
  // grad_fn, one extra hook is registered on the grad_fn as well.
  //
  // The cpp and python paths do not know about each other, so mixing them on
  // the same tensor is not defined behavior.
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // Python-only hooks (PyFunctionTensorPostAccGradHooks) that run once the
  // .grad field has been accumulated into. Much simpler than hooks_, which
  // covers a lot more ground.
  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_{nullptr};

  // Leaf variables only; has to stay false on non-leaf variables.
  bool requires_grad_{false};

  // Non-leaf variables only; has to stay false on leaf variables.
  bool retains_grad_{false};

  bool is_view_{false};

  // Zero-based "output number" of this variable: the second output of a
  // function has output_nr == 1. Needed to wire up the backwards trace
  // correctly when the variable is fed into another function.
  uint32_t output_nr_;

  // Keeps concurrent "read" operations that nevertheless touch internal state
  // thread-safe. Taken by grad_fn(), grad_accumulator(), fw_grad() and
  // set_fw_grad(). Declared mutable so that the const overloads of those
  // functions can lock it.
  mutable std::mutex mutex_;

  // Stores the requires_grad flag. Turning it on is rejected with a
  // TORCH_CHECK failure unless the dtype of `self_impl` is floating point or
  // complex; `self_impl` is only dereferenced when `requires_grad` is true.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {
    TORCH_CHECK(
        !requires_grad ||
            isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point and complex dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  // True when the flag is set explicitly or when a grad_fn is present.
  bool requires_grad() const override {
    return requires_grad_ || (grad_fn_ != nullptr);
  }

  // Mutable access to the stored gradient.
  Variable& mutable_grad() override {
    return grad_;
  }

  // Read-only access to the stored gradient.
  const Variable& grad() const override {
    return grad_;
  }

  // Forward-mode gradient accessors; both are defined out of line.
  const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
      const override;

  void set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op) override;

  // Builds the meta from an optional gradient edge. `self_impl` must be
  // non-null whenever `requires_grad` is true. A TORCH_CHECK failure is
  // raised if both a grad_fn and requires_grad end up set.
  AutogradMeta(
      at::TensorImpl* self_impl = nullptr,
      bool requires_grad = false,
      Edge gradient_edge = Edge())
      : grad_fn_(std::move(gradient_edge.function)),
        output_nr_(gradient_edge.input_nr) {
    if (requires_grad) {
      TORCH_INTERNAL_ASSERT(self_impl);
      // Goes through the setter because it also validates the dtype.
      set_requires_grad(/*requires_grad=*/true, self_impl);
    }
    TORCH_CHECK(
        !grad_fn_ || !requires_grad_,
        "requires_grad should be false if grad_fn is set");
  }

  ~AutogradMeta() override {
    // Reaching the destructor means nothing else references the owning
    // Tensor, hence no other thread can be using this object and fw_grad_ can
    // be inspected without taking mutex_.
    if (fw_grad_ != nullptr) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }
};
// Abstract callable that maps a tensor to a tensor. Besides the call operator
// it exposes SymInt / Tensor state accessors; the base class reports no such
// state (empty vectors, zero counts) and ignores the protected setters.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() = default;

  // SymInt state; none in the base class.
  virtual std::vector<c10::SymInt> get_symints() const {
    return std::vector<c10::SymInt>();
  }

  // Number of SymInts; zero in the base class.
  virtual size_t num_symints() const {
    return size_t{0};
  }

  // Tensor state; none in the base class.
  virtual std::vector<at::Tensor> get_tensors() const {
    return std::vector<at::Tensor>();
  }

  // Number of tensors; zero in the base class.
  virtual size_t num_tensors() const {
    return size_t{0};
  }

  // Applies the function to the given tensor.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;

  // Produces a new ViewFunc, optionally given replacement SymInts and/or
  // Tensors; both arguments default to std::nullopt.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = std::nullopt,
      std::optional<std::vector<at::Tensor>> = std::nullopt) const = 0;

 protected:
  // State setters for subclasses; the base versions do nothing.
  virtual void set_symints(std::vector<c10::SymInt> /*symints*/) {}

  virtual void set_tensors(std::vector<at::Tensor> /*tensors*/) {}
};
struct ChainedViewFunc : public ViewFunc {
ChainedViewFunc(
std::unique_ptr<ViewFunc> first,
std::unique_ptr<ViewFunc> second)
: first(std::move(first)), second(std::move(second)) {}
~ChainedViewFunc() override = default;
std::vector<c10::SymInt> get_symints() const override;
size_t num_symints() const override {
return first->num_symints() + second->num_symints();
}
std::vector<at::Tensor> get_tensors() const override;
size_t num_tensors() const override {
return first->num_tensors() + second->num_tensors();
}
at::Tensor operator()(const at::Tensor&) const override;
std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = std::nullopt,
std::optional<std::vector<at::Tensor>> = std::nullopt) const override;
private:
std::unique_ptr<ViewFunc> first;
std::unique_ptr<ViewFunc> second;
};
struct ErroringViewFunc : public ViewFunc {
ErroringViewFunc(std::string error_msg) : error_msg(std::move(error_msg)) {}
~ErroringViewFunc() override = default;
at::Tensor operator()(const at::Tensor&) const override {
TORCH_CHECK(false, error_msg);
}
std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = std::nullopt,
std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
return std::make_unique<ErroringViewFunc>(error_msg);
}
private:
std::string error_msg;
};
// Describes a view: the tensor it is a view of (`base_`) plus an optional pair
// of functions, a ViewFunc and a reverse function on Variables.
struct TORCH_API ViewInfo {
  // The base tensor; the constructor requires it to be defined.
  Variable base_;

  // View function; may be null.
  std::unique_ptr<ViewFunc> view_fn_;

  // Reverse view function; may be empty.
  std::function<Variable(const Variable&)> rev_view_fn_;

  // Whether the view functions are available. Only view_fn_ is inspected: the
  // assumption is that view_fn_ and rev_view_fn_ are either BOTH present or
  // BOTH absent.
  bool has_view_fn() const {
    return static_cast<bool>(view_fn_);
  }

  // Returns the view function; fails a TORCH_CHECK when there is none.
  const ViewFunc& view_fn() const {
    TORCH_CHECK(
        has_view_fn(), "Can only access the view function if it exists.");
    return *view_fn_;
  }

  // Returns a copy of the reverse view function; fails a TORCH_CHECK when
  // has_view_fn() is false.
  std::function<Variable(const Variable&)> rev_view_fn() const {
    TORCH_CHECK(
        has_view_fn(),
        "Can only access the reverse view function if it exists.");
    return rev_view_fn_;
  }

  // Produces the ViewInfo for a further view; defined out of line. Both
  // function arguments default to nullptr.
  ViewInfo chain(
      const Variable& base,
      const Variable& tensor,
      std::unique_ptr<ViewFunc> view_func = nullptr,
      std::function<Variable(const Variable&)> rev_view_func = nullptr) const;

  // Takes ownership of all three arguments. Fails a TORCH_CHECK when `base`
  // is an undefined tensor.
  ViewInfo(
      Variable base,
      std::unique_ptr<ViewFunc> view_fn,
      std::function<Variable(const Variable&)> rev_view_fn)
      : base_(std::move(base)),
        view_fn_(std::move(view_fn)),
        rev_view_fn_(std::move(rev_view_fn)) {
    TORCH_CHECK(base_.defined(), "base is undefined");
  }
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// DifferentiableViewMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Tag stored on a differentiable view (see DifferentiableViewMeta).
// NOTE(review): the enumerators presumably record the context in which the
// view was created (custom Function, multi-output node, no-grad mode,
// inference mode) -- confirm against the code that sets them.
enum class CreationMeta : uint8_t {
  DEFAULT,
  IN_CUSTOM_FUNCTION,
  MULTI_OUTPUT_NODE,
  NO_GRAD_MODE,
  INFERENCE_MODE
};

// Combines the creation meta of an existing view with the one of a new view
// taken on top of it. The previous value is kept when the new value is
// DEFAULT or when the previous value is INFERENCE_MODE; in every other case
// the new value is returned.
inline CreationMeta propagate_creation_meta(
    CreationMeta prev_view_creation_meta,
    CreationMeta new_view_creation_meta) {
  const bool new_is_default =
      new_view_creation_meta == CreationMeta::DEFAULT;
  const bool prev_is_inference =
      prev_view_creation_meta == CreationMeta::INFERENCE_MODE;
  if (new_is_default || prev_is_inference) {
    return prev_view_creation_meta;
  }
  return new_view_creation_meta;
}
// Declared here, defined out of line. Takes the DifferentiableViewMeta of a
// view and an `indirect` flag that defaults to false.
// NOTE(review): presumably called when a view's history is rebased (compare
// impl::rebase_history) -- confirm against the definition.
TORCH_API void handle_view_on_rebase(
DifferentiableViewMeta* diff_view_meta,
bool indirect = false);
// AutogradMeta for differentiable views. On top of the base class it carries
// optional backward and forward ViewInfo plus a few backward-view-only
// attributes.
struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
 private:
  std::optional<ViewInfo> backward_info_;
  std::optional<ViewInfo> forward_info_;

  // Saves ViewInfo allocations. In the (very common) situation where
  // backward_info_ == forward_info_, only backward_info_ is filled in -- it
  // then serves as both the forward and the backward view information -- and
  // shared_view_info_ is set to true. Invariants:
  //   - shared_view_info_ == false: backward_info_ and forward_info_ are
  //     unconstrained.
  //   - shared_view_info_ == true:
  //       - backward_info_.has_value() == true
  //       - forward_info_.has_value() == false
  bool shared_view_info_;

  uint32_t attr_version_;
  CreationMeta creation_meta_;

 public:
  // Besides the base-class conditions, a backward view also requires grad
  // whenever its base tensor does.
  bool requires_grad() const override {
    if (requires_grad_ || grad_fn_) {
      return true;
    }
    return has_bw_view() && get_backward_view().base_.requires_grad();
  }

  bool shared_view_info() const {
    return shared_view_info_;
  }

  // True when backward view information is present.
  bool has_bw_view() const {
    return backward_info_.has_value();
  }

  // Backward view information; fails a TORCH_CHECK when there is none.
  const ViewInfo& get_backward_view() const {
    TORCH_CHECK(
        has_bw_view(), "backward view info can only exist for backward views.");
    return backward_info_.value();
  }

  // attr_version accessors; both fail a TORCH_CHECK unless this is a backward
  // view.
  uint32_t get_attr_version() const {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    return attr_version_;
  }

  void set_attr_version(uint32_t new_attr_version) {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    attr_version_ = new_attr_version;
  }

  // creation_meta accessors; both fail a TORCH_CHECK unless this is a
  // backward view.
  CreationMeta get_creation_meta() const {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    return creation_meta_;
  }

  void set_creation_meta(CreationMeta new_creation_meta) {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    creation_meta_ = new_creation_meta;
  }

  // True when forward view information is available, either in its own slot
  // or shared with the backward view.
  bool has_fw_view() const {
    return shared_view_info_ || forward_info_.has_value();
  }

  // Forward view information: the shared backward info when shared_view_info_
  // is set, the dedicated forward info otherwise. Fails a TORCH_CHECK when no
  // forward view exists, or when sharing is flagged without backward info.
  const ViewInfo& get_forward_view() const {
    TORCH_CHECK(
        has_fw_view(), "forward view info can only exist for forward views.");
    TORCH_CHECK(
        !shared_view_info_ || has_bw_view(),
        "forward view info can only exist for forward views.");
    if (shared_view_info_) {
      return backward_info_.value();
    }
    return forward_info_.value();
  }

  // Defined out of line. `creation_meta` defaults to CreationMeta::DEFAULT.
  DifferentiableViewMeta(
      at::TensorImpl* self_impl,
      std::optional<ViewInfo> backward_info,
      std::optional<ViewInfo> forward_info,
      bool shared_view_info,
      CreationMeta creation_meta = CreationMeta::DEFAULT);
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// See NOTE [ Autograd View Variables ] for details.
// Differentiable view: history is tracked through a DifferentiableViewMeta
// installed on `data`'s TensorImpl.
//
// Returns `data` itself (with the new meta attached), or an undefined Variable
// when `data` is undefined. Fails a TORCH_CHECK if `data` already carries
// autograd metadata.
inline Variable make_variable_differentiable_view(
    const at::Tensor& data,
    std::optional<ViewInfo> backward_info,
    std::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }

  TORCH_CHECK(
      data.getIntrusivePtr()->autograd_meta() == nullptr,
      "Attempted to make a tensor into a differentiable view, but the "
      "tensor already had autograd metadata associated with it.  If you are "
      "using a __torch_dispatch__ mode, the most common cause for this "
      "problem is that you used torch.overrides.enable_reentrant_dispatch() "
      "improperly; tensors created within the extent of reentrant dispatch "
      "MUST NOT be directly returned from __torch_dispatch__; instead, they "
      "must be wrapped into fresh tensors that serve as the output.  If you "
      "are not using wrappers, you probably don't need reentrant dispatch.  "
      "If this doesn't seem applicable, please file a bug to PyTorch.");

  at::TensorImpl* tensor_impl = data.unsafeGetTensorImpl();
  tensor_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);

  auto view_meta = std::make_unique<DifferentiableViewMeta>(
      tensor_impl,
      std::move(backward_info),
      std::move(forward_info),
      shared_view_info,
      creation_meta);
  tensor_impl->set_autograd_meta(std::move(view_meta));
  return data;
}
// See NOTE [ Autograd View Variables ] for details.
// Non-differentiable view. Just share version counter.
//
// Returns a Variable backed by a shallow, detached copy of `data`'s
// TensorImpl that uses `base`'s version counter and carries no autograd
// metadata. Returns an undefined Variable when `data` is undefined.
inline Variable make_variable_non_differentiable_view(
    const Variable& base,
    const at::Tensor& data,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }
  // Currently all of non-differentiable view ops(detach/_indices/_values)
  // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
  // allocation here is required.
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/impl::version_counter(base),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(nullptr);
  // The local handle is not used again, so hand it over instead of copying it
  // (a copy costs a reference count increment plus a matching decrement).
  return Variable(std::move(data_impl_copy));
}
// Creates a Variable from `data`, optionally marking it as requiring grad.
//
// `data` is taken by value. When it holds the only reference to its
// TensorImpl (use_count() == 1) and unique_version() reports true, that
// TensorImpl is reused directly; otherwise a shallow, detached copy with a
// version counter of 0 is created. In both cases an AutogradMeta is attached
// only when `requires_grad` is true; otherwise the autograd meta is reset to
// nullptr. Returns an undefined Variable when `data` is undefined.
inline Variable make_variable(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }
  if (data.getIntrusivePtr().use_count() == 1 &&
      data.getIntrusivePtr()->unique_version()) {
    // Fast path: take over the existing TensorImpl.
    auto data_impl = data.unsafeReleaseIntrusivePtr();
    data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    if (requires_grad) {
      data_impl->set_autograd_meta(
          std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
    } else {
      data_impl->set_autograd_meta(nullptr);
    }
    return Variable(std::move(data_impl));
  }
  // Slow path: the TensorImpl is shared, so work on a detached shallow copy.
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  if (requires_grad) {
    data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
        data_impl_copy.get(), requires_grad));
  } else {
    data_impl_copy->set_autograd_meta(nullptr);
  }
  // Moved, like in the fast path: the local handle is dead after this point,
  // so copying it would only add a reference count increment and decrement.
  return Variable(std::move(data_impl_copy));
}
// Creates a Variable from `data` whose AutogradMeta is initialised from
// `gradient_edge` (with requires_grad passed as false).
//
// Always works on a shallow, detached copy of `data`'s TensorImpl with a
// version counter of 0. Returns an undefined Variable when `data` is
// undefined.
inline Variable make_variable(
    const at::Tensor& data,
    Edge gradient_edge,
    bool allow_tensor_metadata_change = true) {
  if (!data.defined()) {
    return Variable();
  }
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
      data_impl_copy.get(), false, std::move(gradient_edge)));
  // The local handle is not used again, so move it into the result rather
  // than paying for a reference count increment and decrement.
  return Variable(std::move(data_impl_copy));
}
// Implementation of at::impl::VariableHooksInterface. Every member below
// overrides a virtual function of that interface and is only declared here;
// the definitions live out of line. All members are const.
struct VariableHooks final : at::impl::VariableHooksInterface {
// Data accessors returning a TensorBase by value.
at::TensorBase tensor_data(const at::TensorBase&) const override;
at::TensorBase variable_data(const at::TensorBase&) const override;
// Returns a reference to the shared_ptr holding the grad_fn node.
const std::shared_ptr<torch::autograd::Node>& grad_fn(
const at::TensorBase&) const override;
// Hook registration / removal. _register_hook returns an unsigned value and
// remove_hook takes an unsigned `pos`.
// NOTE(review): presumably the returned value is the position to pass to
// remove_hook later -- confirm against the definitions.
unsigned _register_hook(
const at::TensorBase&,
std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
void remove_hook(const at::TensorBase&, unsigned pos) const override;
// Simple queries on the tensor's autograd state.
bool is_view(const at::TensorBase&) const override;
const at::TensorBase& base(const at::TensorBase&) const override;
const std::string& name(const at::TensorBase&) const override;
bool is_leaf(const at::TensorBase&) const override;
int64_t output_nr(const at::TensorBase&) const override;
void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
const override;
at::TensorBase data(const at::TensorBase& self) const override;
int64_t _version(const at::TensorBase& self) const override;
void retain_grad(const at::TensorBase& self) const override;
bool retains_grad(const at::TensorBase& self) const override;
// Entry point for running backward from `self`; `gradient` and `keep_graph`
// are optional.
void _backward(
const at::Tensor& self,
at::TensorList inputs,
const std::optional<at::Tensor>& gradient,
std::optional<bool> keep_graph,
bool create_graph) const override;
void requires_grad_(const at::TensorBase& self, bool _requires_grad)
const override;
// Fallback operating on an operator handle, a dispatch key set and a
// torch::jit::Stack.
void basic_autograd_not_implemented_fallback(
const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
torch::jit::Stack* stack) const override;
};
namespace utils {
// Compares two Variables and returns a bool; defined out of line.
// NOTE(review): which pieces of metadata take part in the comparison is not
// visible from this header -- check the definition before relying on it.
TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
} // namespace utils
} // namespace torch::autograd
#endif /* DOXYGEN_SHOULD_SKIP_THIS */