Shortcuts

Program Listing for File any_value.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/any_value.h)

#pragma once

#include <torch/types.h>

#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>

namespace torch::nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A copyable and movable type-erased container for a single value.
///
/// The value passed to the constructor is stored on the heap inside a
/// `Holder<std::decay_t<T>>`. It is retrieved with `get<T>()` or
/// `try_get<T>()`, which compare the requested type against the
/// `std::type_info` recorded when the value was stored.
///
/// An `AnyValue` that has been moved from is "empty": it owns no value,
/// `try_get<T>()` returns `nullptr` and `type_info()` returns `typeid(void)`.
class AnyValue {
 public:
  /// Move operations transfer ownership of the stored value and leave the
  /// source empty.
  AnyValue(AnyValue&&) = default;
  AnyValue& operator=(AnyValue&&) = default;

  /// Deep-copies the value stored in `other` via `Placeholder::clone()`.
  /// Copying an empty (moved-from) `AnyValue` produces another empty one
  /// instead of dereferencing a null pointer.
  AnyValue(const AnyValue& other)
      : content_(
            other.content_ ? other.content_->clone()
                           : std::unique_ptr<Placeholder>{}) {}

  /// Replaces the stored value with a deep copy of the one in `other`.
  /// The clone is created before the old content is released, so
  /// self-assignment is safe.
  AnyValue& operator=(const AnyValue& other) {
    content_ = other.content_ ? other.content_->clone()
                              : std::unique_ptr<Placeholder>{};
    return *this;
  }

  /// Stores `value` (decayed) in a newly allocated `Holder`.
  ///
  /// The constraint is checked on the *decayed* type: for a non-const lvalue
  /// `AnyValue` argument `T` deduces to `AnyValue&`, and without the decay
  /// this forwarding constructor would be preferred over the copy constructor
  /// and try to wrap an `AnyValue` inside itself.
  template <
      typename T,
      typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, AnyValue>>>
  explicit AnyValue(T&& value)
      : content_(
            std::make_unique<Holder<std::decay_t<T>>>(std::forward<T>(value))) {
  }

  /// Returns a pointer to the stored value if its type matches `T`,
  /// otherwise `nullptr`. Also returns `nullptr` when this object is empty.
  ///
  /// `T` may be cv-qualified; the stored object is always of the unqualified
  /// type, so the cast below targets `Holder<std::remove_cv_t<T>>`.
  template <typename T>
  T* try_get() {
    static_assert(
        !std::is_reference_v<T>,
        "AnyValue stores decayed types, you cannot cast it to a reference type");
    static_assert(
        !std::is_array_v<T>,
        "AnyValue stores decayed types, you must cast it to T* instead of T[]");
    if (!content_) {
      return nullptr;
    }
    // NOTE(review): types are matched by `hash_code()` as in the original
    // code. The C++ standard does not guarantee distinct hash codes for
    // distinct types; kept unchanged to preserve existing matching behavior.
    if (typeid(T).hash_code() == type_info().hash_code()) {
      // `typeid` ignores top-level cv-qualifiers, so strip them here as well
      // to cast to the Holder type that was actually instantiated.
      return &static_cast<Holder<std::remove_cv_t<T>>&>(*content_).value;
    }
    return nullptr;
  }

  /// Returns a copy of the stored value if its type matches `T`.
  /// Otherwise fails via `TORCH_CHECK` with a message naming both the
  /// requested and the actual type.
  template <typename T>
  T get() {
    if (auto* maybe_value = try_get<T>()) {
      return *maybe_value;
    }
    TORCH_CHECK(
        false,
        "Attempted to cast AnyValue to ",
        c10::demangle(typeid(T).name()),
        ", but its actual type is ",
        c10::demangle(type_info().name()));
  }

  /// Returns the `std::type_info` recorded for the stored value, or
  /// `typeid(void)` if this object is empty (moved-from).
  const std::type_info& type_info() const noexcept {
    return content_ ? content_->type_info : typeid(void);
  }

 private:
  // These types are granted access to the private members below.
  friend struct AnyModulePlaceholder;
  friend struct TestAnyValue;

  /// Type-erased base of `Holder<T>`; records the `std::type_info` of the
  /// held value and exposes a virtual `clone()`.
  struct Placeholder {
    explicit Placeholder(const std::type_info& type_info_) noexcept
        : type_info(type_info_) {}
    Placeholder(const Placeholder&) = default;
    Placeholder(Placeholder&&) = default;
    virtual ~Placeholder() = default;
    /// Base implementation always fails; `Holder<T>` overrides it.
    virtual std::unique_ptr<Placeholder> clone() const {
      TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
    }
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::type_info& type_info;
  };

  /// Concrete storage for a value of (decayed) type `T`.
  template <typename T>
  struct Holder : public Placeholder {
    /// Constructs the held value from `value_` by perfect forwarding.
    ///
    /// The constraint uses the decayed type so that copying or moving a
    /// `Holder` itself is never routed through this constructor. The
    /// constructor is `noexcept` only when constructing `T` from `U&&` is,
    /// so a throwing copy/move of `T` propagates instead of terminating.
    template <
        typename U,
        typename = std::enable_if_t<!std::is_same_v<std::decay_t<U>, Holder>>>
    explicit Holder(U&& value_) noexcept(
        std::is_nothrow_constructible_v<T, U&&>)
        : Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
    /// Returns a new `Holder<T>` containing a copy of `value`.
    std::unique_ptr<Placeholder> clone() const override {
      return std::make_unique<Holder<T>>(value);
    }
    T value;
  };

  /// Owns the stored value; null only after this object has been moved from.
  std::unique_ptr<Placeholder> content_;
};

} // namespace torch::nn

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources