Program Listing for File convert_parameters.h

Return to documentation for file (torch/csrc/api/include/torch/nn/utils/convert_parameters.h)

#pragma once

#include <torch/csrc/Export.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace utils {

// This helper function checks whether the parameters are located on
// the same device. Currently, converting between model parameters and
// a single vector form is not supported for multiple allocations,
// e.g. parameters on different GPUs or a mixture of CPU/GPU.
inline c10::optional<int64_t> _check_param_device(
    const torch::Tensor& param,
    c10::optional<int64_t> old_param_device) {
  // First parameter seen: record its device (-1 denotes CPU)
  if (old_param_device == c10::nullopt) {
    old_param_device = param.is_cuda() ? param.get_device() : -1;
  } else {
    bool warn = false;
    if (param.is_cuda()) { // Check whether it is on the same GPU
      warn = (param.get_device() != old_param_device.value());
    } else { // Check whether it is on the CPU
      warn = (old_param_device.value() != -1);
    }
    if (warn) {
      TORCH_CHECK(
          false,
          "Found two parameters on different devices, ",
          "this is currently not supported.");
    }
  }

  return old_param_device;
}

// Convert the parameters into a single flat vector
inline torch::Tensor parameters_to_vector(
    const std::vector<torch::Tensor>& parameters) {
  c10::optional<int64_t> param_device;

  std::vector<torch::Tensor> vec;
  vec.reserve(parameters.size());

  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located on the same device
    param_device = _check_param_device(param, param_device);

    vec.push_back(param.view(-1));
  }

  return torch::cat(vec);
}

// Convert a single flat vector back into the parameters
inline void vector_to_parameters(
    const torch::Tensor& vec,
    const std::vector<torch::Tensor>& parameters) {
  // Device index shared by the parameters seen so far
  c10::optional<int64_t> param_device;

  // Pointer for slicing the vector for each parameter
  int64_t pointer = 0;
  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located on the same device
    param_device = _check_param_device(param, param_device);

    // The length of the parameter
    auto num_param = param.numel();
    // Slice the vector, reshape it, and replace the old data of the parameter
    param.set_data(
        vec.slice(0, pointer, pointer + num_param).view_as(param).data());

    // Increment the pointer
    pointer += num_param;
  }
}

} // namespace utils
} // namespace nn
} // namespace torch
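
For reference, the sketch below shows one way these helpers can be used. It is a minimal example, not part of the file above: it assumes a translation unit built against LibTorch that includes <torch/torch.h>, and it uses a small torch::nn::Linear module purely for illustration. The module's parameters are flattened into a single 1-D tensor, an updated vector is computed, and the result is written back into the parameters.

#include <torch/torch.h>

#include <iostream>

int main() {
  // A small module whose parameters we flatten and restore (illustrative choice).
  torch::nn::Linear linear(4, 3);

  // Flatten all parameters into a single 1-D tensor:
  // 4 * 3 weight entries + 3 bias entries = 15 elements.
  torch::Tensor flat =
      torch::nn::utils::parameters_to_vector(linear->parameters());
  std::cout << "flattened size: " << flat.numel() << std::endl;

  // Produce an updated flat vector (e.g. as an optimizer working in
  // vector space might) and write it back into the module's parameters.
  torch::Tensor updated = flat * 0.5;
  torch::nn::utils::vector_to_parameters(updated, linear->parameters());
  return 0;
}

Because both helpers call _check_param_device internally, this round trip requires all of the module's parameters to live on a single device (all on the CPU, or all on the same GPU).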
