Shortcuts

Program Listing for File optimizer.h

Return to documentation for file (torch/csrc/api/include/torch/optim/optimizer.h)

#pragma once

#include <ATen/Tensor.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/Exception.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/arg.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Forward declarations confuse Doxygen
// The whole section is therefore compiled out whenever the documentation
// build defines DOXYGEN_SHOULD_SKIP_THIS.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace at {
class Tensor;
} // namespace at

namespace torch {
// Makes `at::Tensor` available as `torch::Tensor`, which is how the
// declarations below refer to it.
using at::Tensor;
namespace serialize {
// Only references to the archive types appear in this header, so forward
// declarations are sufficient here.
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
#endif // DOXYGEN_SHOULD_SKIP_THIS

namespace torch {
namespace optim {

/// Polymorphic base class for the entries of an optimizer's state map
/// (`Optimizer::state()` stores them as `std::unique_ptr<OptimizerParamState>`).
///
/// All member functions are only declared here; their definitions live
/// outside this header.
class TORCH_API OptimizerParamState {
 public:
  /// Returns an owning pointer to a new state object. Intended to be a copy
  /// of `*this`; `OptimizerCloneableParamState` supplies that override for
  /// concrete state types.
  virtual std::unique_ptr<OptimizerParamState> clone() const;
  /// Non-const overload taking an input archive -- presumably loads this
  /// state from `archive` (definition not visible here).
  virtual void serialize(torch::serialize::InputArchive& archive);
  /// Const overload taking an output archive -- presumably writes this state
  /// into `archive` (definition not visible here).
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  /// Virtual so that objects can be destroyed through a base-class pointer.
  virtual ~OptimizerParamState() = default;
};

/// CRTP helper that implements `clone()` for a concrete param-state type.
///
/// `Derived` must inherit from `OptimizerCloneableParamState<Derived>` and be
/// copy-constructible; the override then produces a copy of the most-derived
/// object instead of the base sub-object.
template <typename Derived>
class OptimizerCloneableParamState : public OptimizerParamState {
 private:
  /// Copy-constructs a new `Derived` from this object and returns it behind
  /// the base-class pointer type. Private here, but still reachable through
  /// the public virtual `OptimizerParamState::clone()`.
  std::unique_ptr<OptimizerParamState> clone() const override {
    const auto& self = static_cast<const Derived&>(*this);
    return std::make_unique<Derived>(self);
  }
};

/// Polymorphic base class for optimizer option objects. Instances are held
/// by `std::unique_ptr<OptimizerOptions>` both as an optimizer's defaults and
/// as the optional per-group options of an `OptimizerParamGroup`.
///
/// All member functions are only declared here; their definitions live
/// outside this header.
class TORCH_API OptimizerOptions {
 public:
  /// Returns an owning pointer to a new options object. Intended to be a
  /// copy of `*this`; `OptimizerCloneableOptions` supplies that override for
  /// concrete option types. Used by `OptimizerParamGroup`'s copy constructor.
  virtual std::unique_ptr<OptimizerOptions> clone() const;
  /// Non-const overload taking an input archive -- presumably loads these
  /// options from `archive` (definition not visible here).
  virtual void serialize(torch::serialize::InputArchive& archive);
  /// Const overload taking an output archive -- presumably writes these
  /// options into `archive` (definition not visible here).
  virtual void serialize(torch::serialize::OutputArchive& archive) const;
  /// Virtual so that objects can be destroyed through a base-class pointer.
  virtual ~OptimizerOptions() = default;
  /// Accessor for `lr` -- presumably the learning rate; behaviour of the
  /// base implementation is not visible here.
  virtual double get_lr() const;
  /// Mutator for `lr` -- presumably the learning rate; behaviour of the
  /// base implementation is not visible here.
  virtual void set_lr(const double lr);
};

/// CRTP helper that implements `clone()` for a concrete options type.
///
/// `Derived` must inherit from `OptimizerCloneableOptions<Derived>` and be
/// copy-constructible.
template <typename Derived>
class OptimizerCloneableOptions : public OptimizerOptions {
 private:
  /// Builds a copy of the most-derived options object and hands it back as
  /// an owning base-class pointer. Callers reach this through the public
  /// virtual `OptimizerOptions::clone()`.
  std::unique_ptr<OptimizerOptions> clone() const override {
    std::unique_ptr<OptimizerOptions> duplicate =
        std::make_unique<Derived>(static_cast<const Derived&>(*this));
    return duplicate;
  }
};

class TORCH_API OptimizerParamGroup {
 public:
  // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to be copy-constructible.
  OptimizerParamGroup(const OptimizerParamGroup& param_group) : params_(param_group.params()), options_(param_group.has_options() ? param_group.options().clone() : nullptr) {}
  OptimizerParamGroup(std::vector<Tensor> params) : params_(std::move(params)) {}
  OptimizerParamGroup(std::vector<Tensor> params, std::unique_ptr<OptimizerOptions> options) : params_(std::move(params)), options_(std::move(options)) {}

  bool has_options() const;
  OptimizerOptions& options();
  const OptimizerOptions& options() const;
  void set_options(std::unique_ptr<OptimizerOptions> options);
  std::vector<Tensor>& params();
  const std::vector<Tensor>& params() const;

 protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<Tensor> params_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<OptimizerOptions> options_;
};

class TORCH_API Optimizer {
 public:
  // The copy constructor is deleted, because the user should use the
  // `state_dict` / `load_state_dict` API to copy an optimizer instead.
  Optimizer(const Optimizer& optimizer) = delete;
  Optimizer(Optimizer&& optimizer) = default;

  explicit Optimizer(std::vector<OptimizerParamGroup> param_groups, std::unique_ptr<OptimizerOptions> defaults) : defaults_(std::move(defaults)) {
    for (const auto& param_group : param_groups) {
      add_param_group(param_group);
    }
  }

  // NOLINTNEXTLINE(performance-move-const-arg)
  explicit Optimizer(std::vector<Tensor> parameters, std::unique_ptr<OptimizerOptions> defaults) : Optimizer({std::move(OptimizerParamGroup(parameters))}, std::move(defaults)) {};

  void add_param_group(const OptimizerParamGroup& param_group);

  virtual ~Optimizer() = default;

  using LossClosure = std::function<Tensor()>;
  virtual Tensor step(LossClosure closure = nullptr) = 0;

  void add_parameters(const std::vector<Tensor>& parameters);

  void zero_grad();

  const std::vector<Tensor>& parameters() const noexcept;

  std::vector<Tensor>& parameters() noexcept;

  size_t size() const noexcept;

  OptimizerOptions& defaults() noexcept;

  const OptimizerOptions& defaults() const noexcept;

  std::vector<OptimizerParamGroup>& param_groups() noexcept;

  const std::vector<OptimizerParamGroup>& param_groups() const noexcept;

  ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& state() noexcept;

  const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>& state() const noexcept;

  virtual void save(serialize::OutputArchive& archive) const;

  virtual void load(serialize::InputArchive& archive);

 protected:
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<OptimizerParamGroup> param_groups_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>> state_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::unique_ptr<OptimizerOptions> defaults_;
};

/* How do we decide whether to serialize undefined tensors or
  c10::nullopt values into the output archive?
Answer: we strictly follow the behavior of the Python API. To be more specific:

For optimizer options:
a) For undefined tensors: currently no tensor is used as an options argument in the Python API,
   so we don't need to worry about it now.
b) For c10::nullopt values: we serialize c10::nullopt values into the output archive,
   to follow the exact same behavior as the Python API.

For optimizer param state:
a) For undefined tensors: in param state, an undefined tensor in the C++ impl is equivalent to
   a missing key in the Python impl. Since we don't serialize missing keys in the Python API,
   we skip undefined tensors when serializing the param state.
b) For c10::nullopt values: in param state, a c10::nullopt value in the C++ impl is equivalent to
   a missing key in the Python impl. Since we don't serialize missing keys in the Python API,
   we skip c10::nullopt values when serializing the param state. */

/// Stream-style operator taking an output archive and a const optimizer --
/// presumably serializes `optimizer` into `archive`. Returns an
/// `OutputArchive&` (presumably `archive` itself, so calls can be chained;
/// the definition is not visible here).
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const Optimizer& optimizer);

/// Stream-style operator taking an input archive and a mutable optimizer --
/// presumably deserializes `optimizer` from `archive`. Returns an
/// `InputArchive&` (presumably `archive` itself, so calls can be chained;
/// the definition is not visible here).
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    Optimizer& optimizer);

} // namespace optim
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources