Program Listing for File stateful.h
Return to documentation for file (torch/csrc/api/include/torch/data/datasets/stateful.h)
#pragma once
#include <torch/data/datasets/base.h>
#include <torch/data/example.h>
#include <cstddef>
#include <vector>
// Forward declarations only: in this header the archive types appear solely as
// reference parameters, so their complete definitions are not needed here.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
namespace torch {
namespace data {
namespace datasets {
/// Abstract base for datasets that keep internal state between batch requests.
///
/// Template parameters:
///  - `Self`: forwarded unchanged to `BatchDataset` (presumably the concrete
///    derived type, CRTP-style — confirm against `BatchDataset`).
///  - `Batch`: the batch type; defaults to `std::vector<Example<>>`.
///  - `BatchRequest`: the request type; defaults to `std::size_t`.
///
/// Batches are produced as `optional<Batch>` (see the base class below).
/// NOTE(review): an empty optional presumably tells the consumer that the
/// dataset is exhausted for the current epoch — confirm against the loader.
template <
    typename Self,
    typename Batch = std::vector<Example<>>,
    // Qualified on purpose: <cstddef> is only guaranteed to declare `size_t`
    // inside namespace std, not in the global namespace.
    typename BatchRequest = std::size_t>
class StatefulDataset
    : public BatchDataset<Self, optional<Batch>, BatchRequest> {
 public:
  /// Resets the dataset's internal state. Pure virtual: every concrete
  /// subclass must implement it.
  virtual void reset() = 0;

  /// Writes the dataset's state into `archive`. Declared `const`, so
  /// implementations must not modify the dataset.
  virtual void save(serialize::OutputArchive& archive) const = 0;

  /// Restores the dataset's state from `archive` (presumably an archive that
  /// was previously filled by `save()` — confirm).
  virtual void load(serialize::InputArchive& archive) = 0;
};
template <typename... Args>
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const StatefulDataset<Args...>& statefulDataset) {
statefulDataset.save(archive);
return archive;
}
template <typename... Args>
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
StatefulDataset<Args...>& statefulDataset) {
statefulDataset.load(archive);
return archive;
}
} // namespace datasets
} // namespace data
} // namespace torch