Program Listing for File cloneable.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/cloneable.h)
#pragma once
#include <torch/nn/module.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <memory>
#include <utility>
namespace torch::nn {
/// CRTP base that provides `clone()` for a module of concrete type `Derived`.
///
/// `Derived` is copy-constructed and copy-assigned by this class and must
/// implement `reset()`.
template <typename Derived>
// NOLINTNEXTLINE(bugprone-exception-escape)
class Cloneable : public Module {
 public:
  using Module::Module;

  /// Implemented by `Derived`. `clone()` calls it on the fresh copy right
  /// after emptying the copy's parameters, buffers and children, and then
  /// requires the copy to hold as many of each as the original does.
  virtual void reset() = 0;

  /// Builds a new `Derived` from `*this` and returns it as a `Module`.
  ///
  /// The copy is made with `Derived`'s copy constructor, its parameters,
  /// buffers and children are cleared, and `reset()` is invoked to register
  /// them again. Every non-recursive named parameter and buffer of the
  /// original is then copied into the entry of the same name in the clone,
  /// and every child of the clone is replaced through `clone_()` with a clone
  /// of the original's child of the same name.
  ///
  /// \param device When set, tensors whose device differs from it are copied
  /// with `to(*device)`; all other tensors are copied with `clone()`. The same
  /// value is forwarded to the children.
  /// \throws Whatever `TORCH_CHECK` raises when, after `reset()`, the clone's
  /// number of parameters, buffers or children differs from the original's.
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    NoGradGuard no_grad;

    // Shared by the parameter and buffer loops below: pick `to(*device)` only
    // when a device was requested and the tensor is not already on it.
    const auto duplicate = [&device](const auto& tensor) {
      if (device && tensor.device() != *device) {
        return tensor.to(*device);
      }
      return tensor.clone();
    };

    auto cloned = std::make_shared<Derived>(static_cast<const Derived&>(*this));

    // Drop everything the copy constructor carried over, then let `reset()`
    // register fresh parameters, buffers and submodules on the clone.
    cloned->parameters_.clear();
    cloned->buffers_.clear();
    cloned->children_.clear();
    cloned->reset();

    TORCH_CHECK(
        cloned->parameters_.size() == parameters_.size(),
        "The cloned module does not have the same number of "
        "parameters as the original module after calling reset(). "
        "Are you sure you called register_parameter() inside reset() "
        "and not the constructor?");
    for (const auto& entry : named_parameters(/*recurse=*/false)) {
      auto replacement = duplicate(entry.value());
      cloned->parameters_[entry.key()].set_data(replacement);
    }

    TORCH_CHECK(
        cloned->buffers_.size() == buffers_.size(),
        "The cloned module does not have the same number of "
        "buffers as the original module after calling reset(). "
        "Are you sure you called register_buffer() inside reset() "
        "and not the constructor?");
    for (const auto& entry : named_buffers(/*recurse=*/false)) {
      auto replacement = duplicate(entry.value());
      cloned->buffers_[entry.key()].set_data(replacement);
    }

    TORCH_CHECK(
        cloned->children_.size() == children_.size(),
        "The cloned module does not have the same number of "
        "child modules as the original module after calling reset(). "
        "Are you sure you called register_module() inside reset() "
        "and not the constructor?");
    for (const auto& entry : children_) {
      // The clone's child of the same name overwrites itself with a clone of
      // the original's child.
      auto& target = cloned->children_[entry.key()];
      target->clone_(*entry.value(), device);
    }

    return cloned;
  }

 private:
  /// Overwrites `*this` (viewed as `Derived`) with a clone of `other`.
  ///
  /// \param other Module to clone; its clone must be a `Derived`.
  /// \param device Forwarded unchanged to `other.clone()`.
  /// \throws Whatever `TORCH_CHECK` raises when the clone of `other` is not a
  /// `Derived`.
  void clone_(Module& other, const std::optional<Device>& device) final {
    // `other` was registered under the same name as `this`, so it is expected
    // to be a `Derived`; `reset()` could still have registered something
    // else, hence the checked cast instead of a static one.
    const auto typed = std::dynamic_pointer_cast<Derived>(other.clone(device));
    TORCH_CHECK(
        typed != nullptr,
        "Attempted to clone submodule, but it is of a "
        "different type than the submodule it was to be cloned into");
    static_cast<Derived&>(*this) = *typed;
  }
};
} // namespace torch::nn