Class Optimizer

Class Documentation

class Optimizer

Subclassed by torch::optim::Adagrad, torch::optim::Adam, torch::optim::AdamW, torch::optim::LBFGS, torch::optim::RMSprop, torch::optim::SGD

Public Types

using LossClosure = std::function<Tensor()>
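
In practice a LossClosure is a lambda that recomputes and returns the loss. A minimal sketch; model, input, and target are hypothetical objects assumed to be defined elsewhere:

// Any callable taking no arguments and returning a Tensor matches
// LossClosure. `model`, `input`, and `target` are hypothetical.
torch::optim::Optimizer::LossClosure closure = [&]() -> torch::Tensor {
  return torch::mse_loss(model->forward(input), target);
};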

Public Functions

Optimizer(const Optimizer &optimizer) = delete
Optimizer(Optimizer &&optimizer) = default
inline explicit Optimizer(const std::vector<OptimizerParamGroup> &param_groups, std::unique_ptr<OptimizerOptions> defaults)
inline explicit Optimizer(std::vector<Tensor> parameters, std::unique_ptr<OptimizerOptions> defaults)

Constructs the Optimizer from a vector of parameters.
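
Concrete optimizers such as SGD (one of the derived types above) forward a module's parameters and their options to this constructor. A minimal sketch, assuming the SGDOptions helper from the C++ frontend:

#include <torch/torch.h>

// Build an optimizer over a module's parameter vector; the options
// object becomes the `defaults` the base class stores.
torch::nn::Linear model(4, 2);
torch::optim::SGD optimizer(
    model->parameters(),  // std::vector<Tensor>
    torch::optim::SGDOptions(/*lr=*/0.01).momentum(0.9));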

void add_param_group(const OptimizerParamGroup &param_group)

Adds the given param_group to the optimizer’s param_group list.
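
For instance, a second parameter group with its own options can be registered after construction. A sketch, using AdamOptions as the concrete OptimizerOptions subclass:

torch::nn::Linear backbone(8, 4), head(4, 1);
torch::optim::Adam optimizer(backbone->parameters(),
                             torch::optim::AdamOptions(1e-3));

// The head gets its own learning rate; a group constructed without
// options would inherit the optimizer's defaults instead.
optimizer.add_param_group(torch::optim::OptimizerParamGroup(
    head->parameters(),
    std::make_unique<torch::optim::AdamOptions>(1e-4)));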

virtual ~Optimizer() = default
virtual Tensor step(LossClosure closure = nullptr) = 0

Performs a single optimization step. The optional closure is a loss function that re-evaluates the model and returns the loss value; some optimizers, such as LBFGS, require one.
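
Optimizers like SGD or Adam are usually stepped without a closure after a manual backward pass, while LBFGS evaluates the closure itself, possibly several times per step. A sketch of the closure form:

torch::nn::Linear model(4, 1);
torch::optim::LBFGS optimizer(model->parameters(),
                              torch::optim::LBFGSOptions(/*lr=*/0.1));

auto input  = torch::randn({8, 4});
auto target = torch::randn({8, 1});

// The closure must zero gradients, recompute the loss, and call
// backward() itself, since LBFGS may invoke it more than once.
auto loss = optimizer.step([&]() -> torch::Tensor {
  optimizer.zero_grad();
  auto out = torch::mse_loss(model->forward(input), target);
  out.backward();
  return out;
});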

void add_parameters(const std::vector<Tensor> &parameters)

Adds the given vector of parameters to the optimizer’s parameter list.
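
A sketch of appending a free tensor to an existing optimizer after construction:

// A standalone leaf tensor with requires_grad can be optimized too;
// it is appended to the optimizer's parameter list.
auto extra = torch::randn({3, 3}, torch::requires_grad());
optimizer.add_parameters({extra});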

void zero_grad(bool set_to_none = true)

Resets the gradients of all parameters. When set_to_none is true (the default), gradients are reset to undefined tensors rather than zero-filled, which avoids an extra memory write.
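
In a plain training loop, gradients are reset before every backward pass; a self-contained sketch using SGD:

torch::nn::Linear model(4, 1);
torch::optim::SGD optimizer(model->parameters(),
                            torch::optim::SGDOptions(0.01));
auto input  = torch::randn({8, 4});
auto target = torch::randn({8, 1});

for (int epoch = 0; epoch < 10; ++epoch) {
  optimizer.zero_grad();  // reset grads; set_to_none=true by default
  auto loss = torch::mse_loss(model->forward(input), target);
  loss.backward();        // accumulate fresh gradients
  optimizer.step();       // apply the update
}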

const std::vector<Tensor> &parameters() const noexcept

Provides a const reference to the parameters in the first param_group this optimizer holds.

std::vector<Tensor> &parameters() noexcept

Provides a reference to the parameters in the first param_group this optimizer holds.

size_t size() const noexcept

Returns the number of parameters referenced by the optimizer.
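
A sketch of inspecting both accessors on an existing optimizer; note that parameters() covers only the first param_group, while size() counts parameters across every group:

#include <iostream>

const auto& first_group_params = optimizer.parameters();
std::cout << "first group: " << first_group_params.size() << " tensors, "
          << "total referenced: " << optimizer.size() << std::endl;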

OptimizerOptions &defaults() noexcept

Provides a reference to the default options this optimizer holds.

const OptimizerOptions &defaults() const noexcept

Provides a const reference to the default options this optimizer holds.

std::vector<OptimizerParamGroup> &param_groups() noexcept

Provides a reference to the param_groups this optimizer holds.

const std::vector<OptimizerParamGroup> &param_groups() const noexcept

Provides a const reference to the param_groups this optimizer holds.
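
A common use is adjusting per-group hyperparameters mid-training, e.g. a manual learning-rate decay. A sketch that assumes every group carries SGDOptions:

for (auto& group : optimizer.param_groups()) {
  // options() returns the base OptimizerOptions; cast to the concrete
  // options type to reach per-group hyperparameters such as lr.
  auto& options = static_cast<torch::optim::SGDOptions&>(group.options());
  options.lr(options.lr() * 0.1);  // decay the learning rate 10x
}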

ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() noexcept

Provides a reference to the state this optimizer holds.

const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() const noexcept

Provides a const reference to the state this optimizer holds.
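
The map is keyed by each parameter's TensorImpl pointer (hence the void* key). A sketch of looking up one parameter's state, assuming an Adam optimizer that has already taken at least one step:

auto& param = optimizer.param_groups()[0].params()[0];
auto it = optimizer.state().find(param.unsafeGetTensorImpl());
if (it != optimizer.state().end()) {
  // For Adam, the per-parameter state carries the step count and
  // the moment estimates.
  auto& s = static_cast<torch::optim::AdamParamState&>(*it->second);
  std::cout << "steps taken: " << s.step() << std::endl;
}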

virtual void save(serialize::OutputArchive &archive) const

Serializes the optimizer state into the given archive.

virtual void load(serialize::InputArchive &archive)

Deserializes the optimizer state from the given archive.
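
These hooks are what torch::save and torch::load route through for optimizers, so state can round-trip through a file. A sketch, assuming the optimizer is constructed over the same parameters before loading:

// Serializes param_groups and per-parameter state via OutputArchive.
torch::save(optimizer, "optimizer.pt");

// Restores them via InputArchive.
torch::load(optimizer, "optimizer.pt");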

Protected Attributes

std::vector<OptimizerParamGroup> param_groups_
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_
std::unique_ptr<OptimizerOptions> defaults_
