Class Optimizer
Defined in File optimizer.h
Page Contents
Inheritance Relationships
Derived Types
public torch::optim::Adagrad (Class Adagrad)
public torch::optim::Adam (Class Adam)
public torch::optim::AdamW (Class AdamW)
public torch::optim::LBFGS (Class LBFGS)
public torch::optim::RMSprop (Class RMSprop)
public torch::optim::SGD (Class SGD)
Class Documentation
-
class Optimizer
Subclassed by torch::optim::Adagrad, torch::optim::Adam, torch::optim::AdamW, torch::optim::LBFGS, torch::optim::RMSprop, torch::optim::SGD
Public Functions
-
Optimizer(const Optimizer &optimizer) = delete
-
Optimizer(Optimizer &&optimizer) = default
-
inline explicit Optimizer(const std::vector<OptimizerParamGroup> &param_groups, std::unique_ptr<OptimizerOptions> defaults)
-
inline explicit Optimizer(std::vector<Tensor> parameters, std::unique_ptr<OptimizerOptions> defaults)
Constructs the Optimizer from a vector of parameters.
-
void add_param_group(const OptimizerParamGroup &param_group)
Adds the given param_group to the optimizer’s param_group list.
-
virtual Tensor step(LossClosure closure = nullptr) = 0
closure: an optional loss function closure (defaults to nullptr), which is expected to return the loss value.
-
void add_parameters(const std::vector<Tensor> &parameters)
Adds the given vector of parameters to the optimizer’s parameter list.
-
const std::vector<Tensor> &parameters() const noexcept
Provides a const reference to the parameters in the first param_group this optimizer holds.
-
std::vector<Tensor> &parameters() noexcept
Provides a reference to the parameters in the first param_group this optimizer holds.
-
OptimizerOptions &defaults() noexcept
-
const OptimizerOptions &defaults() const noexcept
-
std::vector<OptimizerParamGroup> &param_groups() noexcept
Provides a reference to the param_groups this optimizer holds.
-
const std::vector<OptimizerParamGroup> &param_groups() const noexcept
Provides a const reference to the param_groups this optimizer holds.
-
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() noexcept
Provides a reference to the state this optimizer holds.
-
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() const noexcept
Provides a const reference to the state this optimizer holds.
-
virtual void save(serialize::OutputArchive &archive) const
Serializes the optimizer state into the given archive.
-
virtual void load(serialize::InputArchive &archive)
Deserializes the optimizer state from the given archive.
Protected Attributes
-
std::vector<OptimizerParamGroup> param_groups_
-
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_
-
std::unique_ptr<OptimizerOptions> defaults_
-
Optimizer(const Optimizer &optimizer) = delete