Class Optimizer¶
Defined in File optimizer.h
Page Contents
Inheritance Relationships¶
Derived Types¶
public torch::optim::Adagrad (Class Adagrad)
public torch::optim::Adam (Class Adam)
public torch::optim::AdamW (Class AdamW)
public torch::optim::LBFGS (Class LBFGS)
public torch::optim::RMSprop (Class RMSprop)
public torch::optim::SGD (Class SGD)
Class Documentation¶
-
class Optimizer¶
Subclassed by torch::optim::Adagrad, torch::optim::Adam, torch::optim::AdamW, torch::optim::LBFGS, torch::optim::RMSprop, torch::optim::SGD
Public Types
-
using LossClosure = std::function<Tensor()>¶
Public Functions
-
inline explicit Optimizer(const std::vector<OptimizerParamGroup> &param_groups, std::unique_ptr<OptimizerOptions> defaults)¶
-
inline explicit Optimizer(std::vector<Tensor> parameters, std::unique_ptr<OptimizerOptions> defaults)¶
Constructs the Optimizer from a vector of parameters.
-
void add_param_group(const OptimizerParamGroup &param_group)¶
Adds the given param_group to the optimizer’s param_group list.
-
virtual ~Optimizer() = default¶
-
virtual Tensor step(LossClosure closure = nullptr) = 0¶
The closure argument is a loss function closure, which is expected to return the loss value. It defaults to nullptr.
-
void add_parameters(const std::vector<Tensor> &parameters)¶
Adds the given vector of parameters to the optimizer’s parameter list.
-
void zero_grad(bool set_to_none = true)¶
Zeros out the gradients of all parameters.
-
const std::vector<Tensor> &parameters() const noexcept¶
Provides a const reference to the parameters in the first param_group this optimizer holds.
-
std::vector<Tensor> &parameters() noexcept¶
Provides a reference to the parameters in the first param_group this optimizer holds.
-
size_t size() const noexcept¶
Returns the number of parameters referenced by the optimizer.
-
OptimizerOptions &defaults() noexcept¶
-
const OptimizerOptions &defaults() const noexcept¶
-
std::vector<OptimizerParamGroup> &param_groups() noexcept¶
Provides a reference to the param_groups this optimizer holds.
-
const std::vector<OptimizerParamGroup> &param_groups() const noexcept¶
Provides a const reference to the param_groups this optimizer holds.
-
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() noexcept¶
Provides a reference to the state this optimizer holds.
-
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() const noexcept¶
Provides a const reference to the state this optimizer holds.
Protected Attributes
-
std::vector<OptimizerParamGroup> param_groups_¶
-
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_¶
-
std::unique_ptr<OptimizerOptions> defaults_¶
-
using LossClosure = std::function<Tensor()>¶