#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace utils {

// Helper that verifies all parameters seen so far live on the same device.
// Converting between model parameters and a single flat vector is not
// supported across multiple allocations, e.g. parameters on different GPUs
// or a mixture of CPU and GPU tensors.
//
// param:            the parameter currently being inspected.
// old_param_device: device index recorded from the previously inspected
//                   parameters, or c10::nullopt if `param` is the first one.
//                   A CUDA parameter is recorded by its `get_device()` index;
//                   any non-CUDA parameter is recorded as -1.
//
// Returns the device index to carry over to the next call. Fails through
// TORCH_CHECK when `param` is on a different device than the recorded one.
inline c10::optional<int64_t> _check_param_device(const torch::Tensor& param, c10::optional<int64_t> old_param_device) {
  // -1 is the sentinel used for "not on a CUDA device".
  const int64_t param_device = param.is_cuda() ? param.get_device() : -1;

  // First parameter: nothing to compare against yet, just record its device.
  if (!old_param_device.has_value()) {
    return param_device;
  }

  // Every subsequent parameter must match the recorded device
  // (same GPU index, or both non-CUDA).
  TORCH_CHECK(param_device == old_param_device.value(),
              "Found two parameters on different devices, ",
              "this is currently not supported.");

  return old_param_device;
}

// Convert parameters to one vector.
//
// Each parameter is flattened to 1-D and the flattened pieces are
// concatenated, in the order given, into a single tensor.
//
// parameters: the tensors to flatten; all of them must be on the same
//             device (enforced via _check_param_device).
//
// Returns the concatenation of the flattened parameters.
inline torch::Tensor parameters_to_vector(const std::vector<torch::Tensor>& parameters) {
  // Device of the parameters seen so far; empty until the first one is checked.
  c10::optional<int64_t> param_device;

  std::vector<torch::Tensor> vec;
  vec.reserve(parameters.size());

  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located in the same device
    param_device = _check_param_device(param, param_device);

    // Flatten the parameter so that all pieces can be concatenated along dim 0
    vec.push_back(param.view(-1));
  }

  return torch::cat(vec);
}

// Convert one vector to the parameters.
//
// Walks `parameters` in order; for each one, takes the next `numel()`
// elements of `vec` along dim 0, views them with the parameter's shape and
// installs the result as the parameter's data via `set_data`.
//
// vec:        the flat tensor holding the values for all parameters.
// parameters: the tensors to overwrite; all of them must be on the same
//             device (enforced via _check_param_device).
//
// NOTE(review): `parameters` is taken by value; this presumably relies on
// copies of torch::Tensor referring to the same underlying tensor so that
// `set_data` is visible to the caller - confirm.
inline void vector_to_parameters(const torch::Tensor& vec, std::vector<torch::Tensor> parameters) {
  // Flag for the device where the parameter is located
  c10::optional<int64_t> param_device;

  // Offset into `vec` of the first element belonging to the current parameter
  int64_t pointer = 0;
  for (torch::Tensor& param : parameters) {
    // Ensure the parameters are located in the same device
    param_device = _check_param_device(param, param_device);

    // The length of the parameter
    const int64_t num_param = param.numel();
    // Slice the vector, reshape it, and replace the old data of the parameter
    param.set_data(vec.slice(0, pointer, pointer + num_param).view_as(param).data());

    // Advance past the elements consumed by this parameter
    pointer += num_param;
  }
}

} // namespace utils
} // namespace nn
} // namespace torch


