Program Listing for File optimizer.h
↰ Return to documentation for file (torch/csrc/api/include/torch/optim/optimizer.h)
#pragma once
#include <ATen/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
// Forward declarations confuse Doxygen
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace at {
class Tensor;
} // namespace at
namespace torch {
using at::Tensor;
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
#endif // DOXYGEN_SHOULD_SKIP_THIS
namespace torch::optim {
class TORCH_API OptimizerParamState {
public:
OptimizerParamState() = default;
OptimizerParamState(const OptimizerParamState&) = default;
OptimizerParamState& operator=(const OptimizerParamState&) = default;
OptimizerParamState(OptimizerParamState&&) noexcept = default;
OptimizerParamState& operator=(OptimizerParamState&&) noexcept = default;
virtual std::unique_ptr<OptimizerParamState> clone() const;
virtual void serialize(torch::serialize::InputArchive& archive);
virtual void serialize(torch::serialize::OutputArchive& archive) const;
virtual ~OptimizerParamState() = default;
};
template <typename Derived>
class OptimizerCloneableParamState : public OptimizerParamState {
std::unique_ptr<OptimizerParamState> clone() const override {
return std::make_unique<Derived>(static_cast<const Derived&>(*this));
}
};
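// NOTE: Illustrative sketch only, not part of this header. A concrete
// per-parameter state type would typically inherit from the CRTP helper
// above so that `clone()` produces a copy of the derived type. The names
// `MyParamState`, `step`, and `exp_avg` are assumptions for illustration,
// not an existing optimizer's state:
//
//   struct MyParamState
//       : public OptimizerCloneableParamState<MyParamState> {
//     int64_t step = 0;   // update count for this parameter
//     at::Tensor exp_avg; // running-average buffer
//   };
//
//   MyParamState state;
//   const OptimizerParamState& base = state;
//   std::unique_ptr<OptimizerParamState> copy =
//       base.clone(); // dispatches to OptimizerCloneableParamState::clone()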
class TORCH_API OptimizerOptions {
public:
OptimizerOptions() = default;
OptimizerOptions(const OptimizerOptions&) = default;
OptimizerOptions& operator=(const OptimizerOptions&) = default;
OptimizerOptions(OptimizerOptions&&) noexcept = default;
OptimizerOptions& operator=(OptimizerOptions&&) noexcept = default;
virtual std::unique_ptr<OptimizerOptions> clone() const;
virtual void serialize(torch::serialize::InputArchive& archive);
virtual void serialize(torch::serialize::OutputArchive& archive) const;
virtual ~OptimizerOptions() = default;
virtual double get_lr() const;
virtual void set_lr(const double lr);
};
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
private:
std::unique_ptr<OptimizerOptions> clone() const override {
return std::make_unique<Derived>(static_cast<const Derived&>(*this));
}
};
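// NOTE: Illustrative sketch only, not part of this header. Concrete options
// types in torch::optim typically derive from this helper and declare their
// hyperparameters with the TORCH_ARG macro from <torch/arg.h>. `MyOptions`
// below is a hypothetical example, not an existing class:
//
//   struct MyOptions : public OptimizerCloneableOptions<MyOptions> {
//     explicit MyOptions(double lr = 1e-3) : lr_(lr) {}
//     TORCH_ARG(double, lr);
//     TORCH_ARG(double, weight_decay) = 0.0;
//
//    public:
//     double get_lr() const override { return lr(); }
//     void set_lr(const double lr) override { this->lr(lr); }
//   };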
class TORCH_API OptimizerParamGroup {
public:
// NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to
// be copy-constructible.
OptimizerParamGroup(const OptimizerParamGroup& param_group)
: params_(param_group.params()),
options_(
param_group.has_options() ? param_group.options().clone()
: nullptr) {}
OptimizerParamGroup(std::vector<Tensor> params)
: params_(std::move(params)) {}
OptimizerParamGroup(
std::vector<Tensor> params,
std::unique_ptr<OptimizerOptions> options)
: params_(std::move(params)), options_(std::move(options)) {}
OptimizerParamGroup& operator=(const OptimizerParamGroup& param_group) =
delete;
bool has_options() const;
OptimizerOptions& options();
const OptimizerOptions& options() const;
void set_options(std::unique_ptr<OptimizerOptions> options);
std::vector<Tensor>& params();
const std::vector<Tensor>& params() const;
protected:
std::vector<Tensor> params_;
std::unique_ptr<OptimizerOptions> options_;
};
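// NOTE: Illustrative sketch only, not part of this header. A param group
// bundles a set of tensors with, optionally, its own options object, which
// takes the place of the optimizer-wide defaults for just those tensors.
// Reusing the hypothetical `MyOptions` from the sketch above and assuming
// <torch/torch.h> is available:
//
//   std::vector<at::Tensor> slow = {torch::randn({2, 2}).requires_grad_()};
//   OptimizerParamGroup group(
//       slow, std::make_unique<MyOptions>(/*lr=*/1e-4));
//   if (group.has_options()) {
//     // group.options() now refers to the MyOptions instance above.
//   }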
class TORCH_API Optimizer {
public:
// The copy constructor is deleted, because the user should use the
// `state_dict` / `load_state_dict` API to copy an optimizer instead.
Optimizer(const Optimizer& optimizer) = delete;
Optimizer(Optimizer&& optimizer) = default;
explicit Optimizer(
const std::vector<OptimizerParamGroup>& param_groups,
std::unique_ptr<OptimizerOptions> defaults)
: defaults_(std::move(defaults)) {
for (const auto& param_group : param_groups) {
add_param_group(param_group);
}
}
explicit Optimizer(
std::vector<Tensor> parameters,
std::unique_ptr<OptimizerOptions> defaults)
: Optimizer(
{OptimizerParamGroup(std::move(parameters))},
std::move(defaults)) {}
void add_param_group(const OptimizerParamGroup& param_group);
virtual ~Optimizer() = default;
using LossClosure = std::function<Tensor()>;
virtual Tensor step(LossClosure closure = nullptr) = 0;
void add_parameters(const std::vector<Tensor>& parameters);
void zero_grad(bool set_to_none = true);
const std::vector<Tensor>& parameters() const noexcept;
std::vector<Tensor>& parameters() noexcept;
size_t size() const noexcept;
OptimizerOptions& defaults() noexcept;
const OptimizerOptions& defaults() const noexcept;
std::vector<OptimizerParamGroup>& param_groups() noexcept;
const std::vector<OptimizerParamGroup>& param_groups() const noexcept;
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
state() noexcept;
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state()
const noexcept;
virtual void save(serialize::OutputArchive& archive) const;
virtual void load(serialize::InputArchive& archive);
protected:
std::vector<OptimizerParamGroup> param_groups_;
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
std::unique_ptr<OptimizerOptions> defaults_;
};
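// NOTE: Illustrative sketch only, not part of this header. `Optimizer` is an
// abstract base (step() is pure virtual), so a training loop uses a concrete
// subclass such as torch::optim::SGD from <torch/optim/sgd.h>. `model`,
// `data_loader`, and the choice of loss are assumed to exist:
//
//   torch::optim::SGD optimizer(
//       model->parameters(), torch::optim::SGDOptions(/*lr=*/0.01));
//   for (auto& batch : *data_loader) {
//     optimizer.zero_grad(); // set_to_none defaults to true
//     auto loss = torch::mse_loss(model->forward(batch.data), batch.target);
//     loss.backward();
//     optimizer.step();
//   }
//
//   // Optimizers that re-evaluate the loss (e.g. LBFGS) take a LossClosure:
//   //   optimizer.step([&] { /* recompute and */ return loss; });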
/* How do we decide whether to serialize undefined tensors or
  std::nullopt values into the output archive?
Answer: we strictly follow the behavior of Python API. To be more specific:

For optimizer options:
a) For undefined tensor: currently no tensor is used as an options argument in
Python API, so we don't need to worry about it now.
b) For std::nullopt value: we serialize std::nullopt values into the output
archive, to follow the exact same behavior as Python API.

For optimizer param state:
a) For undefined tensor: in param state, undefined tensor in C++ impl is
equivalent to missing key in Python impl. Since we don't serialize missing keys
in Python API, we skip undefined tensors when serializing the param state.
b) For std::nullopt value: in param state, std::nullopt value in C++ impl is
equivalent to missing key in Python impl. Since we don't serialize missing keys
in Python API, we skip std::nullopt values when serializing the param state. */
TORCH_API serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const Optimizer& optimizer);
TORCH_API serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
Optimizer& optimizer);
} // namespace torch::optim
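// NOTE: Illustrative sketch only, not part of the listed header. The stream
// operators declared above (together with Optimizer::save/load) write an
// optimizer's param-group options and per-parameter state to an archive and
// restore them later; the file name below is arbitrary:
//
//   torch::optim::SGD optimizer(
//       model->parameters(), torch::optim::SGDOptions(/*lr=*/0.01));
//
//   torch::serialize::OutputArchive out;
//   out << optimizer;        // operator<< declared in this header
//   out.save_to("optimizer.pt");
//
//   torch::serialize::InputArchive in;
//   in.load_from("optimizer.pt");
//   in >> optimizer;         // operator>> declared in this header
//
// The higher-level torch::save / torch::load helpers wrap this same
// archive-based path.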