
Program Listing for File cloneable.h

Return to documentation for file (torch/csrc/api/include/torch/nn/cloneable.h)

#pragma once

#include <torch/nn/module.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

#include <memory>
#include <utility>

namespace torch {
namespace nn {
template <typename Derived>
// NOLINTNEXTLINE(bugprone-exception-escape)
class Cloneable : public Module {
 public:
  using Module::Module;

  virtual void reset() = 0;

  std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const override {
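    // Run the whole copy under no-grad so the tensor copies below are
    // not tracked by autograd.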
    NoGradGuard no_grad;

    const auto& self = static_cast<const Derived&>(*this);
    auto copy = std::make_shared<Derived>(self);
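    // The copy constructor shares tensor storage with the original, so
    // drop the copied registrations and let reset() re-create fresh
    // parameters, buffers, and children on the copy.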
    copy->parameters_.clear();
    copy->buffers_.clear();
    copy->children_.clear();
    copy->reset();
    TORCH_CHECK(
        copy->parameters_.size() == parameters_.size(),
        "The cloned module does not have the same number of "
        "parameters as the original module after calling reset(). "
        "Are you sure you called register_parameter() inside reset() "
        "and not the constructor?");
    for (const auto& parameter : named_parameters(/*recurse=*/false)) {
      auto& tensor = *parameter;
      auto data = device && tensor.device() != *device
          ? tensor.to(*device)
          : autograd::Variable(tensor).clone();
      copy->parameters_[parameter.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->buffers_.size() == buffers_.size(),
        "The cloned module does not have the same number of "
        "buffers as the original module after calling reset(). "
        "Are you sure you called register_buffer() inside reset() "
        "and not the constructor?");
    for (const auto& buffer : named_buffers(/*recurse=*/false)) {
      auto& tensor = *buffer;
      auto data = device && tensor.device() != *device
          ? tensor.to(*device)
          : autograd::Variable(tensor).clone();
      copy->buffers_[buffer.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->children_.size() == children_.size(),
        "The cloned module does not have the same number of "
        "child modules as the original module after calling reset(). "
        "Are you sure you called register_module() inside reset() "
        "and not the constructor?");
    for (const auto& child : children_) {
      copy->children_[child.key()]->clone_(*child.value(), device);
    }
    return copy;
  }

 private:
  void clone_(Module& other, const optional<Device>& device) final {
    // Here we are *pretty* certain that `other's` type is `Derived` (because it
    // was registered under the same name as `this`), but you never know what
    // crazy things `reset()` does, so `dynamic_cast` just to be safe.
    auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
    TORCH_CHECK(
        clone != nullptr,
        "Attempted to clone submodule, but it is of a "
        "different type than the submodule it was to be cloned into");
    static_cast<Derived&>(*this) = *clone;
  }
};

} // namespace nn
} // namespace torch
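
The listing above assumes a concrete module derives from Cloneable<Derived> and performs all of its registration inside reset(). A minimal sketch of that pattern follows; MyModule and its members are illustrative names, not part of this file:

#include <torch/torch.h>

#include <memory>

// Hypothetical module for illustration only.
struct MyModule : torch::nn::Cloneable<MyModule> {
  explicit MyModule(int64_t features) : features_(features) {
    reset();
  }

  // Register parameters here, not in the constructor, so that clone()
  // can rebuild the copy's parameters before copying tensor data.
  void reset() override {
    weight = register_parameter("weight", torch::randn({features_}));
  }

  torch::Tensor forward(const torch::Tensor& input) {
    return input * weight;
  }

  int64_t features_;
  torch::Tensor weight;
};

int main() {
  auto module = std::make_shared<MyModule>(8);
  // clone() returns std::shared_ptr<Module>; downcast to reach the
  // derived type. The copy owns fresh tensors, so mutating one module's
  // parameters leaves the other untouched. A target device may also be
  // passed, e.g. module->clone(torch::kCUDA).
  auto copy = std::dynamic_pointer_cast<MyModule>(module->clone());
  TORCH_CHECK(copy != nullptr);
}

Because reset() is the single source of truth for registration, clone() can clear the copied module, call reset() to rebuild its parameter, buffer, and child slots, and then copy tensor data into them, which is exactly what the TORCH_CHECK calls above enforce.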
