Shortcuts

Program Listing for File data_parallel.h

Return to documentation for file (torch/csrc/api/include/torch/nn/parallel/data_parallel.h)

#pragma once

#include <torch/cuda.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <ATen/Device.h>
#include <ATen/Parallel.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <vector>

namespace torch {
namespace nn {

namespace {

// Note [Replicating Modules]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// Module replication is implemented in the following two steps:
// 1) create a module replica on each destination device using Module.clone().
// 2) manually add a gradient edge pointing from every parameter X in every
//    module replica to the same parameter X in the original module, using
//    ReduceAdd as the grad_fn.
//
// ReduceAdd can ONLY be used during the backward pass of data parallel. The
// forward pass cannot use this function because it does not set up a gradient
// function or history at all. Do NOT try to use ReduceAdd for any other purpose.
//
// NB: An alternative is to add Broadcast and ReduceAddCoalesce to
// torch/csrc/autograd/functions/comm.cpp as normal autograd functions,
// implement a Replicatable (like cloneable) class and add it as a friend class
// in Module.h. In the forward pass, the Replicatable could use the Broadcast
// function to replicate every module parameter and set gradient functions using
// ReduceAddCoalesce (like how it is implemented in Python). However, unlike in
// Python, where changes to Linear._parameters["weight"] would also apply to
// Linear.weight (using Linear as an example), Linear.weight and
// Linear.parameters_["weight"] are two tensor objects pointing to the same
// TensorImpl. Assigning a new tensor to Linear.parameters_["weight"] will not
// change Linear.weight. To make this work, we will have to:
// 1) force every module to also inherit from Replicatable
// 2) force every module to implement an additional function, e.g.,
//    Replicatable::load_params(), to pick up changes from parameters_ to their
//    own member fields.
// This would be overkill, as Replicatable would only be used in data_parallel,
// not even in DDP.

// Autograd function for the replicate step in data parallel. This is only used
// in data parallel, and should not be exposed as a user API.
struct ReduceAdd : public autograd::Node {
  explicit ReduceAdd(const at::Device& destination_device)
      : destination_device_(destination_device){};
  ~ReduceAdd() override {}

  autograd::variable_list apply(autograd::variable_list&& inputs) override {
    TORCH_CHECK(
        !torch::autograd::compute_requires_grad(inputs),
        "ReduceAdd can only be used during the backward pass of data parallel.");

    Tensor output = torch::zeros_like(inputs[0], {destination_device_});

    for (auto& input : inputs) {
      TORCH_CHECK(
          input.sizes() == inputs[0].sizes(),
          "All inputs of ReduceAdd must have the same size, but got ",
          input.sizes(),
          " and ",
          inputs[0].sizes());

      TORCH_CHECK(
          input.dtype() == inputs[0].dtype(),
          "All inputs of ReduceAdd must have the same dtype, but got ",
          input.dtype(),
          " and ",
          inputs[0].dtype());

      // TODO: use nccl reduce
      output.add_(input.to(destination_device_));
    }

    return {output};
  }

 private:
  at::Device destination_device_;
};

} // namespace

// Friend function of Module that recursively sets gradient edges pointing from
// every parameter X in every module replica to the same parameter X in the
// original module. See [Replicating Modules]
//
// The following applies to each parameter, and each buffer that requires grad,
// registered directly on `module`:
// - A ReduceAdd node is created on that tensor's device, with next edges
//   collected from the original tensor.
// - The node is then installed via set_history on the corresponding tensor of
//   every replica.
// Child modules are handled recursively.
// `replicas` must contain exactly one entry per element of `devices`.
template <typename ModuleType>
void replicate_grad_edges(
    const std::shared_ptr<Module>& module,
    const std::vector<std::shared_ptr<ModuleType>>& replicas,
    const std::vector<Device>& devices) {
  // Fail loudly instead of reading past the end of `replicas` if the two
  // lists disagree in length.
  TORCH_CHECK(
      replicas.size() == devices.size(),
      "replicate_grad_edges: expected one replica per device, but got ",
      replicas.size(),
      " replicas for ",
      devices.size(),
      " devices");

  for (auto& parameter : module->named_parameters(/*recurse=*/false)) {
    auto grad_fn = std::make_shared<ReduceAdd>((*parameter).device());
    grad_fn->set_next_edges(autograd::collect_next_edges(*parameter));

    for (const auto& replica : replicas) {
      autograd::set_history(replica->parameters_[parameter.key()], grad_fn);
    }
  }

  for (auto& buffer : module->named_buffers(/*recurse=*/false)) {
    // Only buffers that require grad get a gradient edge.
    if (buffer.value().requires_grad()) {
      auto grad_fn = std::make_shared<ReduceAdd>((*buffer).device());
      grad_fn->set_next_edges(autograd::collect_next_edges(*buffer));

      for (const auto& replica : replicas) {
        autograd::set_history(replica->buffers_[buffer.key()], grad_fn);
      }
    }
  }

  for (auto& child : module->children_) {
    // Gather the same child from every replica, preserving replica order.
    std::vector<std::shared_ptr<Module>> child_replicas;
    child_replicas.reserve(replicas.size());
    for (const auto& replica : replicas) {
      child_replicas.push_back(replica->children_[child.key()]);
    }

    // recursively set gradient edges for all children
    replicate_grad_edges(*child, child_replicas, devices);
  }
}

namespace parallel {

/// Clones `module` once per entry in `devices` (via `Module::clone(device)`).
/// It then wires every replica parameter's gradient back to the corresponding
/// parameter of `module`. See [Replicating Modules]
///
/// Returns one replica per device, in the same order as `devices`.
/// Throws if a clone cannot be cast back to `ModuleType`.
template <typename ModuleType>
std::vector<std::shared_ptr<ModuleType>> replicate(
    const std::shared_ptr<ModuleType>& module,
    const std::vector<Device>& devices) {
  std::vector<std::shared_ptr<ModuleType>> replicas;
  replicas.reserve(devices.size());
  for (const auto& device : devices) {
    replicas.push_back(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
    // dynamic_pointer_cast yields nullptr on a type mismatch. Fail here
    // instead of dereferencing null in replicate_grad_edges.
    TORCH_CHECK(
        replicas.back() != nullptr,
        "replicate: clone() did not return a module of the expected type");
  }
  // Configure gradient edges to point from replica parameters to original
  // module parameters. See [Replicating Modules]
  replicate_grad_edges(module, replicas, devices);
  return replicas;
}

/// `ModuleHolder` overload of `replicate`: replicates the wrapped module and
/// wraps each resulting replica in its own `ModuleHolder`, keeping the order.
template <typename ModuleType>
std::vector<ModuleHolder<ModuleType>> replicate(
    const ModuleHolder<ModuleType>& module,
    const std::vector<Device>& devices) {
  auto replicated = replicate(module.ptr(), devices);
  std::vector<ModuleHolder<ModuleType>> holders;
  holders.reserve(replicated.size());
  for (auto& replica : replicated) {
    holders.emplace_back(replica);
  }
  return holders;
}

/// Calls `modules[i]->forward(inputs[i])` for every index i, distributing the
/// calls with `at::parallel_for`, and returns the results in index order.
///
/// Result i is converted with `.to()` to `(*devices)[i]` when `devices` is
/// set, and to `inputs[i].device()` otherwise. Exceptions thrown by any call
/// are captured. The first one recorded is rethrown after `at::parallel_for`
/// returns.
template <typename ModuleType>
std::vector<Tensor> parallel_apply(
    std::vector<ModuleType>& modules,
    const std::vector<Tensor>& inputs,
    const optional<std::vector<Device>>& devices = nullopt) {
  TORCH_CHECK(
      modules.size() == inputs.size(), "Must have as many inputs as modules");
  TORCH_CHECK(
      !devices || modules.size() == devices->size(),
      "Must have as many devices as modules");

  std::vector<Tensor> outputs(modules.size());
  // Serializes writes to `outputs` and `first_error` from worker threads.
  std::mutex mutex;
  // std::exception_ptr may be handed to another thread and rethrown there
  // (https://en.cppreference.com/w/cpp/error/exception_ptr). This lets a
  // worker report its failure back to the calling thread.
  std::exception_ptr first_error;

  // Runs module i on input i and records either its output or its exception.
  auto run_one = [&](int64_t i) {
    try {
      auto result = modules[i]->forward(inputs[i]);
      result = result.to(devices ? (*devices)[i] : inputs[i].device());
      std::lock_guard<std::mutex> guard(mutex);
      outputs[i] = result;
    } catch (...) {
      std::lock_guard<std::mutex> guard(mutex);
      if (!first_error) {
        first_error = std::current_exception();
      }
    }
  };

  at::parallel_for(
      /*begin=*/0,
      /*end=*/modules.size(),
      /*grain_size=*/1,
      [&run_one](int64_t begin, int64_t end) {
        for (auto i = begin; i < end; ++i) {
          run_one(i);
        }
      });

  if (first_error) {
    std::rethrow_exception(first_error);
  }
  return outputs;
}

/// Evaluates `module->forward(input)` data-parallel across `devices`.
///
/// When `devices` is omitted, every CUDA device reported by
/// `torch::cuda::device_count()` is used. `output_device` defaults to the first
/// device.
///
/// With exactly one device, no replication happens. `module->to()` is called
/// with that device, the input is moved there, and `module->forward` is called
/// directly.
///
/// Otherwise the call proceeds in four steps:
/// 1. `input` is scattered along `dim`.
/// 2. `module` is replicated onto each device that received a chunk.
/// 3. Each replica runs on its chunk (see `parallel_apply`).
/// 4. The results are gathered along `dim` onto `output_device`.
///
/// Throws if `devices` is omitted and no CUDA device is available, or if
/// `devices` is given but empty.
template <typename ModuleType>
Tensor data_parallel(
    ModuleType module,
    Tensor input,
    optional<std::vector<Device>> devices = nullopt,
    optional<Device> output_device = nullopt,
    int64_t dim = 0) {
  if (!devices) {
    const auto device_count = torch::cuda::device_count();
    TORCH_CHECK(
        device_count > 0, "Expected at least one CUDA device to be available");
    devices = std::vector<Device>();
    devices->reserve(device_count);
    for (const auto index : c10::irange(device_count)) {
      devices->emplace_back(kCUDA, static_cast<torch::DeviceIndex>(index));
    }
  }
  // A caller-supplied empty list would otherwise reach devices->front() below,
  // which is undefined behavior on an empty vector.
  TORCH_CHECK(!devices->empty(), "data_parallel requires at least one device");
  if (!output_device) {
    output_device = devices->front();
  }

  if (devices->size() == 1) {
    module->to(devices->front());
    input = input.to(devices->front());
    return module->forward(std::move(input)).to(*output_device);
  }

  autograd::Scatter scatter(*devices, /*chunk_sizes=*/nullopt, dim);
  auto scattered_inputs = fmap<Tensor>(scatter.apply({std::move(input)}));
  // Input tensor might not be big enough to scale across all available
  // devices. Drop the trailing devices that received no chunk. A range erase
  // avoids needing a placeholder Device, which resize() would require.
  if (scattered_inputs.size() < devices->size()) {
    devices->erase(
        devices->begin() +
            static_cast<std::ptrdiff_t>(scattered_inputs.size()),
        devices->end());
  }

  auto replicas = replicate(module, *devices);
  auto outputs = parallel_apply(replicas, scattered_inputs, *devices);
  return autograd::Gather(*output_device, dim)
      .apply(fmap<autograd::Variable>(std::move(outputs)))
      .front();
}

} // namespace parallel
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources