Program Listing for File serialize.h
↰ Return to documentation for file (torch/csrc/api/include/torch/serialize.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/serialize/archive.h>
#include <torch/serialize/tensor.h>
#include <utility>
namespace torch {
template <typename Value, typename... SaveToArgs>
void save(const Value& value, SaveToArgs&&... args) {
serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>());
archive << value;
archive.save_to(std::forward<SaveToArgs>(args)...);
}
/// Serializes a vector of tensors into a newly created
/// `serialize::OutputArchive` and then hands the archive off to `save_to`.
///
/// Each tensor is written under a key equal to its position in the vector,
/// rendered as a decimal string ("0", "1", "2", ...).
///
/// \param tensor_vec Tensors to write, in order.
/// \param args       Perfectly forwarded, unchanged, to
///        `OutputArchive::save_to`.
template <typename... SaveToArgs>
void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
  serialize::OutputArchive output(std::make_shared<jit::CompilationUnit>());
  // Running position of the current tensor; doubles as its archive key.
  size_t position = 0;
  for (const torch::Tensor& tensor : tensor_vec) {
    output.write(std::to_string(position), tensor);
    ++position;
  }
  output.save_to(std::forward<SaveToArgs>(args)...);
}
/// Converts `ivalue` into a buffer of bytes and returns it.
/// Declaration only; the definition is not part of this header.
/// NOTE(review): presumably the bytes use the pickle format, as the name
/// suggests -- confirm against the implementation.
TORCH_API std::vector<char> pickle_save(const torch::IValue& ivalue);
/// Builds a `torch::IValue` from the bytes in `data` and returns it.
/// Declaration only; the definition is not part of this header.
/// NOTE(review): presumably the inverse of `pickle_save` -- confirm against
/// the implementation.
TORCH_API torch::IValue pickle_load(const std::vector<char>& data);
template <typename Value, typename... LoadFromArgs>
void load(Value& value, LoadFromArgs&&... args) {
serialize::InputArchive archive;
archive.load_from(std::forward<LoadFromArgs>(args)...);
archive >> value;
}
/// Reads a sequence of tensors from a serialized archive and appends them to
/// `tensor_vec`.
///
/// Tensors are looked up under the keys "0", "1", "2", ... in that order.
/// Reading stops at the first key for which `try_read` reports failure.
/// `tensor_vec` is not cleared first: loaded tensors are pushed after whatever
/// elements it already holds.
///
/// \param tensor_vec Destination vector that receives the loaded tensors.
/// \param args       Perfectly forwarded, unchanged, to
///        `InputArchive::load_from`.
template <typename... LoadFromArgs>
void load(std::vector<torch::Tensor>& tensor_vec, LoadFromArgs&&... args) {
  serialize::InputArchive input;
  input.load_from(std::forward<LoadFromArgs>(args)...);

  // The archive does not tell us up front how many tensors it holds, so we
  // probe consecutive keys and let a failed `try_read` mark the end.
  for (size_t position = 0;; ++position) {
    // Start every probe from a default-constructed tensor.
    torch::Tensor loaded;
    if (!input.try_read(std::to_string(position), loaded)) {
      break;
    }
    tensor_vec.push_back(std::move(loaded));
  }
}
} // namespace torch