Program Listing for File serialize.h
↰ Return to documentation for file (torch/csrc/api/include/torch/optim/serialize.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/optim/optimizer.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <string>
#include <vector>
namespace torch::optim {
namespace detail {
// Utility function to save state
template <typename DerivedOptimizerParamState>
void serialize(
serialize::OutputArchive& archive,
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
state) {
for (const auto& item : state) {
serialize::OutputArchive param_state_archive(archive.compilation_unit());
std::string tensorimpl_key =
std::to_string(reinterpret_cast<size_t>(item.first));
const DerivedOptimizerParamState& curr_state =
static_cast<const DerivedOptimizerParamState&>(*(item.second));
curr_state.serialize(param_state_archive);
archive.write(tensorimpl_key, param_state_archive);
}
}
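// Each entry above is written under a key that is the decimal string of the
// parameter's TensorImpl address (e.g. an address of 140243 would yield the
// key "140243"); the load overload below turns that string back into the
// void* key with std::stoull.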
// Utility function to load state
template <typename DerivedOptimizerParamState>
void serialize(
serialize::InputArchive& archive,
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state) {
std::vector<std::string> tensorimpl_keys = archive.keys();
for (const std::string& tensorimpl_key : tensorimpl_keys) {
serialize::InputArchive param_state_archive;
archive.read(tensorimpl_key, param_state_archive);
DerivedOptimizerParamState param_state;
param_state.serialize(param_state_archive);
// NOLINTNEXTLINE(performance-no-int-to-ptr)
state[reinterpret_cast<void*>(std::stoull(tensorimpl_key))] =
std::make_unique<DerivedOptimizerParamState>(param_state);
}
}
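// Hedged usage sketch (AdamParamState is the derived state type shipped with
// torch::optim::Adam; the surrounding `archive` is assumed to exist):
//
//   serialize::InputArchive state_archive;
//   archive.read("state", state_archive);
//   ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> saved;
//   detail::serialize<AdamParamState>(state_archive, saved);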
// Utility function to save param_groups
template <typename DerivedOptimizerParamOptions>
void serialize(
serialize::OutputArchive& archive,
const std::vector<OptimizerParamGroup>& param_groups) {
archive.write(
"param_groups/size",
torch::tensor(static_cast<int64_t>(param_groups.size())));
for (const auto i : c10::irange(param_groups.size())) {
serialize::OutputArchive param_group_archive(archive.compilation_unit());
std::vector<Tensor> params = param_groups[i].params();
param_group_archive.write(
"params/size", torch::tensor(static_cast<int64_t>(params.size())));
for (const auto index : c10::irange(params.size())) {
param_group_archive.write(
"params/" + std::to_string(index),
IValue(std::to_string(
reinterpret_cast<size_t>(params[index].unsafeGetTensorImpl()))));
}
const DerivedOptimizerParamOptions& param_group_options =
static_cast<const DerivedOptimizerParamOptions&>(
param_groups[i].options());
serialize::OutputArchive param_group_options_archive(
param_group_archive.compilation_unit());
param_group_options.serialize(param_group_options_archive);
param_group_archive.write("options", param_group_options_archive);
archive.write("param_groups/" + std::to_string(i), param_group_archive);
}
}
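// The archive produced above is logically laid out as (indices illustrative):
//   param_groups/size            -> group count (0-dim int64 tensor)
//   param_groups/<i>/params/size -> param count of group i
//   param_groups/<i>/params/<j>  -> TensorImpl address of param j as a string
//   param_groups/<i>/options     -> the group's DerivedOptimizerParamOptions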
// Utility function to load param_groups
// We take as input a vector of (param-key list, optimizer options) pairs so
// that we can retain the state for each param: the old tensor impl keys
// (saved during serialization) let us map the new tensor impl keys to the
// correct state for each param
template <typename DerivedOptimizerParamOptions>
void serialize(
serialize::InputArchive& archive,
std::vector<
std::pair<std::vector<std::string>, std::unique_ptr<OptimizerOptions>>>&
param_groups) {
torch::Tensor param_groups_size_tensor;
archive.read("param_groups/size", param_groups_size_tensor);
const int64_t param_groups_size = param_groups_size_tensor.item<int64_t>();
for (const auto i : c10::irange(param_groups_size)) {
serialize::InputArchive param_group_archive;
archive.read("param_groups/" + std::to_string(i), param_group_archive);
torch::Tensor size_tensor;
param_group_archive.read("params/size", size_tensor);
const int64_t size = size_tensor.item<int64_t>();
std::vector<std::string> params;
for (const auto index : c10::irange(size)) {
IValue ivalue;
param_group_archive.read("params/" + std::to_string(index), ivalue);
std::string element = ivalue.toStringRef();
params.emplace_back(element);
}
serialize::InputArchive param_group_options_archive;
param_group_archive.read("options", param_group_options_archive);
DerivedOptimizerParamOptions param_group_options(0);
param_group_options.serialize(param_group_options_archive);
param_groups.emplace_back(std::make_pair(
params,
std::make_unique<DerivedOptimizerParamOptions>(param_group_options)));
}
}
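// On return, each element of `param_groups` pairs a group's old
// TensorImpl-address keys with its deserialized options; the top-level load
// overload below uses those keys to re-associate saved state with the
// optimizer's current parameters.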
} // namespace detail
// Note: These functions are all called `serialize()` so they can be called
// inside a template where the archive type is a template type and can thus be
// passed such that the appropriate overload is selected.
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
const int64_t& value);
void serialize(
serialize::InputArchive& archive,
const std::string& key,
int64_t& value);
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
const std::vector<int64_t>& steps);
void serialize(
serialize::InputArchive& archive,
const std::string& key,
std::vector<int64_t>& steps);
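// Hedged sketch: because every overload shares the name `serialize`, the same
// expression compiles against either archive direction inside a templated
// save/load body (the member name `iteration_` is illustrative); this is also
// what the _TORCH_OPTIM_SERIALIZE macro below expands to:
//
//   torch::optim::serialize(archive, "iteration", iteration_);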
// Utility function to save state and param_groups
template <
typename DerivedOptimizerParamState,
typename DerivedOptimizerParamOptions>
void serialize(serialize::OutputArchive& archive, const Optimizer& optimizer) {
archive.write("pytorch_version", IValue("1.5.0"));
serialize::OutputArchive state_archive(archive.compilation_unit());
detail::serialize<DerivedOptimizerParamState>(
state_archive, optimizer.state());
archive.write("state", state_archive);
serialize::OutputArchive param_groups_archive(archive.compilation_unit());
detail::serialize<DerivedOptimizerParamOptions>(
param_groups_archive, optimizer.param_groups());
archive.write("param_groups", param_groups_archive);
}
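// Hedged usage sketch (the MyOptimizer* names are illustrative placeholders
// for a concrete optimizer's derived state/options types):
//
//   void MyOptimizer::save(serialize::OutputArchive& archive) const {
//     serialize<MyOptimizerParamState, MyOptimizerOptions>(archive, *this);
//   }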
// Utility function to load state and param_groups and update state
template <
typename DerivedOptimizerParamState,
typename DerivedOptimizerParamOptions>
void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
IValue pytorch_version;
archive.read("pytorch_version", pytorch_version);
TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0");
serialize::InputArchive state_archive;
archive.read("state", state_archive);
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> saved_state;
detail::serialize<DerivedOptimizerParamState>(state_archive, saved_state);
serialize::InputArchive param_groups_archive;
archive.read("param_groups", param_groups_archive);
std::vector<
std::pair<std::vector<std::string>, std::unique_ptr<OptimizerOptions>>>
saved_param_groups;
detail::serialize<DerivedOptimizerParamOptions>(
param_groups_archive, saved_param_groups);
// update state and optimizer options
TORCH_CHECK(
saved_param_groups.size() == optimizer.param_groups().size(),
"loaded state dict has a different number of parameter groups");
for (const auto i : c10::irange(saved_param_groups.size())) {
std::vector<std::string> param_group_old_keys = saved_param_groups[i].first;
std::vector<Tensor> params = optimizer.param_groups()[i].params();
TORCH_CHECK(
param_group_old_keys.size() == params.size(),
"loaded state dict contains a parameter group that has a different size than the optimizer's parameter group");
for (const auto idx : c10::irange(params.size())) {
auto param_group_old_key =
// NOLINTNEXTLINE(performance-no-int-to-ptr)
reinterpret_cast<void*>(std::stoull(param_group_old_keys[idx]));
if (saved_state.find(param_group_old_key) != saved_state.end()) {
optimizer.state()[params[idx].unsafeGetTensorImpl()] =
std::move(saved_state[param_group_old_key]);
}
}
auto& saved_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
*saved_param_groups[i].second);
auto& current_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
optimizer.param_groups()[i].options());
current_options = saved_options;
}
}
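// In short: saved state, keyed by the old TensorImpl addresses, is re-keyed
// position-by-position onto the optimizer's current parameters within each
// group, and each group's options are overwritten with the loaded values.
// The load entry point mirrors the save sketch above, e.g.
//   serialize<MyOptimizerParamState, MyOptimizerOptions>(archive, *this);
// inside a hypothetical MyOptimizer::load(serialize::InputArchive&).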
template <typename BufferContainer>
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
const BufferContainer& buffers) {
archive.write(
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
for (const auto index : c10::irange(buffers.size())) {
archive.write(
key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
}
}
template <typename BufferContainer>
void serialize(
serialize::InputArchive& archive,
const std::string& key,
BufferContainer& buffers) {
buffers.clear();
torch::Tensor size_tensor;
archive.read(key + "/size", size_tensor);
const size_t size = size_tensor.item<int64_t>();
for (const auto index : c10::irange(size)) {
buffers.emplace_back();
archive.read(
key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
}
}
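// Hedged sketch: these two overloads round-trip any container with size(),
// operator[], clear(), emplace_back() and back() (std::vector<Tensor> or
// std::deque<Tensor> both qualify); the name `momentum_buffers` is
// illustrative:
//
//   std::vector<torch::Tensor> momentum_buffers;
//   serialize(output_archive, "momentum_buffers", momentum_buffers); // save
//   serialize(input_archive, "momentum_buffers", momentum_buffers);  // load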
template <typename T>
c10::List<T> deque_to_list(const std::deque<T>& dq) {
c10::List<T> list;
list.reserve(dq.size());
for (const auto& e : dq) {
list.emplace_back(e);
}
return list;
}
template <typename T>
std::deque<T> list_to_deque(const c10::List<T>& list) {
std::deque<T> dq;
for (const auto& e : list) {
dq.emplace_back(e);
}
return dq;
}
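// These helpers convert between std::deque and c10::List, the container form
// that can be wrapped in an IValue; the *_DEQUE macros below rely on them to
// push deque-valued members through the archive.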
#define _TORCH_OPTIM_SERIALIZE(name) \
torch::optim::serialize(archive, #name, self.name)
#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName) \
torch::optim::serialize<OptimizerName##ParamState, OptimizerName##Options>( \
archive, self)
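// E.g. _TORCH_OPTIM_SERIALIZE(momentum_buffers) expands to
//   torch::optim::serialize(archive, "momentum_buffers", self.momentum_buffers);
// and _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adam) expands to
//   torch::optim::serialize<AdamParamState, AdamOptions>(archive, self);
// (member/optimizer names are illustrative).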
#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name) \
{ \
auto ivalue = torch::IValue(name()); \
/* do not serialize if name is an undefined tensor*/ \
if (!(ivalue.isTensor() && \
ivalue.unsafeToTensorImpl() == \
at::UndefinedTensorImpl::singleton())) { \
archive.write(#name, ivalue); \
} \
}
#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name) \
{ \
c10::IValue ivalue = torch::IValue(deque_to_list(name())); \
archive.write(#name, ivalue); \
}
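// E.g. _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr) wraps the value returned by the
// lr() accessor in an IValue and writes it under the key "lr", skipping the
// write only when the accessor returns an undefined tensor;
// _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(step_buffers) does the same for a
// deque-returning accessor via deque_to_list (accessor names illustrative).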
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) \
{ \
c10::IValue ivalue; \
bool exists = archive.try_read(#name, ivalue); \
if (exists) { \
name(ivalue.to<T>()); \
} else { \
constexpr bool is_tensor_type = std::is_base_of_v<torch::Tensor, T>; \
TORCH_INTERNAL_ASSERT(is_tensor_type); \
} \
}
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \
{ \
c10::IValue ivalue; \
bool exists = archive.try_read(#name, ivalue); \
if (exists) { \
name(ivalue.toOptional<T>()); \
} \
}
#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, name) \
{ \
c10::IValue ivalue; \
archive.read(#name, ivalue); \
auto list = ivalue.to<c10::List<T::value_type>>(); \
name(list_to_deque(list)); \
}
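// E.g. _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr) reads the key "lr" if
// present and forwards it to the lr(double) setter; a missing key is only
// tolerated for tensor-typed arguments, which the save macro may have skipped
// when the tensor was undefined. The _OPTIONAL and _DEQUE variants handle
// optional-valued and deque-valued arguments in the same style
// (accessor names illustrative).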
} // namespace torch::optim