Shortcuts

Program Listing for File any_value.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/any_value.h)

#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/variadic.h>

#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A type-erased, copyable container that owns a single value of an arbitrary
/// (decayed) type together with the `std::type_info` of that type.
///
/// The value is stored on the heap inside a `Holder<T>` and accessed through
/// the polymorphic `Placeholder` base. An `AnyValue` that has been moved from
/// is *empty*: it owns no value, `type_info()` reports `typeid(void)`,
/// `try_get<T>()` returns `nullptr` and `get<T>()` reports an error.
class AnyValue {
 public:
  /// Move construction / assignment transfer ownership of the stored value
  /// and leave the source `AnyValue` empty.
  AnyValue(AnyValue&&) = default;
  AnyValue& operator=(AnyValue&&) = default;

  /// Copy construction deep-copies the stored value through
  /// `Placeholder::clone()`. Copying an empty (moved-from) `AnyValue` yields
  /// another empty `AnyValue` rather than dereferencing a null pointer.
  AnyValue(const AnyValue& other)
      : content_(other.content_ ? other.content_->clone() : nullptr) {}

  /// Copy assignment; same semantics as copy construction.
  AnyValue& operator=(const AnyValue& other) {
    // The clone is fully built before the previous content is released, so
    // self-assignment is safe.
    content_ = other.content_ ? other.content_->clone() : nullptr;
    return *this;
  }

  /// Constructs an `AnyValue` that stores a decayed copy of `value`
  /// (moved from it when `value` is an rvalue).
  ///
  /// The overload is disabled when the decayed argument type is `AnyValue`
  /// itself. Without that constraint this forwarding constructor would be a
  /// better match than the copy constructor for a non-const lvalue `AnyValue`
  /// and would recurse endlessly while trying to build a `Holder<AnyValue>`.
  template <
      typename T,
      typename = std::enable_if_t<
          !std::is_same<std::decay_t<T>, AnyValue>::value>>
  explicit AnyValue(T&& value)
      : content_(std::make_unique<Holder<std::decay_t<T>>>(
            std::forward<T>(value))) {}

  /// Returns a pointer to the stored value if its type matches `T`, and
  /// `nullptr` otherwise (including when this `AnyValue` is empty).
  /// `T` must be neither a reference nor an array type.
  template <typename T>
  T* try_get() {
    static_assert(
        !std::is_reference<T>::value,
        "AnyValue stores decayed types, you cannot cast it to a reference type");
    static_assert(
        !std::is_array<T>::value,
        "AnyValue stores decayed types, you must cast it to T* instead of T[]");
    // NOTE(review): the match is decided by comparing `hash_code()` values,
    // not the `type_info` objects themselves; the C++ standard does not
    // guarantee that distinct types have distinct hash codes. Left unchanged
    // because callers may depend on this exact matching behavior.
    if (typeid(T).hash_code() == type_info().hash_code()) {
      return &static_cast<Holder<T>&>(*content_).value;
    }
    return nullptr;
  }

  /// Returns a copy of the stored value if its type matches `T`; otherwise
  /// reports an error through `AT_ERROR` naming both the requested and the
  /// actual type.
  template <typename T>
  T get() {
    if (auto* maybe_value = try_get<T>()) {
      return *maybe_value;
    }
    AT_ERROR(
        "Attempted to cast AnyValue to ",
        c10::demangle(typeid(T).name()),
        ", but its actual type is ",
        c10::demangle(type_info().name()));
  }

  /// Returns the `type_info` of the stored value, or `typeid(void)` when this
  /// `AnyValue` is empty (moved-from).
  const std::type_info& type_info() const noexcept {
    return content_ ? content_->type_info : typeid(void);
  }

 private:
  friend struct AnyModulePlaceholder;
  friend struct TestAnyValue;

  /// Polymorphic base of the stored content. It only records the
  /// `type_info` of the held value; the value itself lives in `Holder<T>`.
  struct Placeholder {
    explicit Placeholder(const std::type_info& type_info_) noexcept
        : type_info(type_info_) {}
    Placeholder(const Placeholder&) = default;
    Placeholder(Placeholder&&) = default;
    virtual ~Placeholder() = default;
    /// Base implementation always fails; `Holder<T>` overrides it to produce
    /// a real copy.
    virtual std::unique_ptr<Placeholder> clone() const {
      TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
    }
    const std::type_info& type_info;
  };

  /// Concrete content: owns a value of type `T`.
  template <typename T>
  struct Holder : public Placeholder {
    /// Constructs the held value from `value_` by perfect forwarding.
    /// `noexcept` only when constructing `T` from `U&&` cannot throw, so an
    /// exception from `T`'s constructor propagates to the caller instead of
    /// terminating the program.
    template <typename U>
    // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
    explicit Holder(U&& value_) noexcept(
        std::is_nothrow_constructible<T, U&&>::value)
        : Placeholder(typeid(T)), value(std::forward<U>(value_)) {}

    /// Returns a new `Holder<T>` holding a copy of `value`.
    std::unique_ptr<Placeholder> clone() const override {
      return std::make_unique<Holder<T>>(value);
    }
    T value;
  };

  /// Owned content; null when this `AnyValue` has been moved from.
  std::unique_ptr<Placeholder> content_;
};

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources