
Program Listing for File data_parallel.h

Return to documentation for file (torch/csrc/api/include/torch/nn/parallel/data_parallel.h)

#pragma once

#include <torch/cuda.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <ATen/Device.h>
#include <ATen/Parallel.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <vector>

namespace torch {
namespace nn {

namespace {

// Note [Replicating Modules]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// Module replication is implemented in the following two steps:
// 1) create a module replica on each destination device using Module.clone().
// 2) manually add a gradient edge pointing from every parameter X in every
//    module replica to the same parameter X in the original module, using
//    ReduceAdd as the grad_fn.
// ReduceAdd can ONLY be used during the backward pass of data parallel. The
// forward pass cannot use this function because it does not set up a gradient
// function or history at all. Do NOT try to use ReduceAdd for any other purpose.
// NB: An alternative is to add Broadcast and ReduceAddCoalesce to
// torch/csrc/autograd/functions/comm.cpp as normal autograd functions,
// implement a Replicatable (like cloneable) class and add it as a friend class
// in Module.h. In the forward pass, the Replicatable could use the Broadcast
// function to replicate every module parameter and set gradient functions using
// ReduceAddCoalesce (like how it is implemented in Python). However, unlike in
// Python, where changes to Linear._parameters["weight"] would also apply to
// Linear.weight (using Linear as an example), Linear.weight and
// Linear.parameters_["weight"] are two tensor objects pointing to the same
// TensorImpl. Assigning a new tensor to Linear.parameters_["weight"] will not
// change Linear.weight. To make this work, we will have to:
// 1) force every module to also inherit from Replicatable
// 2) force every module to implement an additional function, e.g.,
//    Replicatable::load_params(), to pick up changes from parameters_ to their
//    own member fields.
// This would be overkill, as Replicatable would only be used in data_parallel,
// not even in DDP.

// Autograd function for the replicate step in data parallel. This is only used
// in data parallel, and should not be exposed as a user API.
struct ReduceAdd : public autograd::Node {
  explicit ReduceAdd(const at::Device& destination_device)
      : destination_device_(destination_device){};
  ~ReduceAdd() override {}

  autograd::variable_list apply(autograd::variable_list&& inputs) override {
        "ReduceAdd can only be used during the backward pass of data parallel.");

    Tensor output = torch::zeros_like(inputs[0], {destination_device_});

    for (auto& input : inputs) {
          input.sizes() == inputs[0].sizes(),
          "All inputs of ReduceAdd must have the same size, but got ",
          " and ",

          input.dtype() == inputs[0].dtype(),
          "All inputs of ReduceAdd must have the same dtype, but got ",
          " and ",

      // TODO: use nccl reduce

    return {output};

  at::Device destination_device_;

} // namespace

// A friend function to Module, it recursively sets gradient edges pointing from
// every parameter X in every module replica to the same parameter X in the
// original module. See [Replicating Modules]
template <typename ModuleType>
void replicate_grad_edges(
    const std::shared_ptr<Module>& module,
    const std::vector<std::shared_ptr<ModuleType>>& replicas,
    const std::vector<Device>& devices) {
  for (auto& parameter : module->named_parameters(/*recurse=*/false)) {
    auto grad_fn = std::make_shared<ReduceAdd>((*parameter).device());

    for (const auto i : c10::irange(devices.size())) {
      autograd::set_history(replicas[i]->parameters_[parameter.key()], grad_fn);

  for (auto& buffer : module->named_buffers(/*recurse=*/false)) {
    if (buffer.value().requires_grad()) {
      auto grad_fn = std::make_shared<ReduceAdd>((*buffer).device());

      for (const auto i : c10::irange(devices.size())) {
        autograd::set_history(replicas[i]->buffers_[buffer.key()], grad_fn);

  for (auto& child : module->children_) {
    std::vector<std::shared_ptr<Module>> child_replicas;
    for (auto& replica : replicas) {

    // recursively set gradient edges for all children
    replicate_grad_edges(*child, child_replicas, devices);

namespace parallel {

template <typename ModuleType>
std::vector<std::shared_ptr<ModuleType>> replicate(
    const std::shared_ptr<ModuleType>& module,
    const std::vector<Device>& devices) {
  std::vector<std::shared_ptr<ModuleType>> replicas;
  for (const auto& device : devices) {
  // Configure gradient edges to point from replcia parameters to original
  // module parameters. See [Replicating Modules]
  replicate_grad_edges(module, replicas, devices);
  return replicas;

// ModuleHolder overload of replicate(). It replicates the held module onto
// each of `devices` and wraps every replica back into a
// ModuleHolder<ModuleType>, preserving the order of `devices`.
template <typename ModuleType>
std::vector<ModuleHolder<ModuleType>> replicate(
    const ModuleHolder<ModuleType>& module,
    const std::vector<Device>& devices) {
  auto ptrs = replicate(module.ptr(), devices);
  return std::vector<ModuleHolder<ModuleType>>(ptrs.begin(), ptrs.end());
}
template <typename ModuleType>
std::vector<Tensor> parallel_apply(
    std::vector<ModuleType>& modules,
    const std::vector<Tensor>& inputs,
    const optional<std::vector<Device>>& devices = nullopt) {
      modules.size() == inputs.size(), "Must have as many inputs as modules");
  if (devices) {
        modules.size() == devices->size(),
        "Must have as many devices as modules");

  std::vector<Tensor> outputs(modules.size());
  std::mutex mutex;

  // std::exception_ptr can be passed between threads:
  // > An instance of std::exception_ptr may be passed to another function,
  // > possibly on another thread, where the exception may be rethrown [...].
  std::exception_ptr exception;

      [&modules, &inputs, &devices, &outputs, &mutex, &exception](
          int64_t index, int64_t stop) {
        for (; index < stop; ++index) {
          try {
            auto output = modules[index]->forward(inputs[index]);
            output =
       ? (*devices)[index] : inputs[index].device());
            std::lock_guard<std::mutex> lock(mutex);
            outputs[index] = output;
          } catch (...) {
            std::lock_guard<std::mutex> lock(mutex);
            if (!exception) {
              exception = std::current_exception();

  if (exception) {

  return outputs;

// Evaluates `module(input)` in parallel across `devices`.
//
// If `devices` is not supplied, all available CUDA devices are used.
// The combined result is placed on `output_device`, which defaults to the
// first entry of `devices`.
//
// With a single device, `module` and `input` are moved to it and evaluated
// directly. Otherwise the function performs four steps:
//  1. scatter `input` along `dim` across the devices,
//  2. replicate `module` onto each device that received a chunk,
//  3. run each replica on its chunk (parallel_apply),
//  4. gather the outputs along `dim` onto `output_device`.
//
// Throws (via TORCH_CHECK) if no CUDA device is available when `devices` is
// omitted, or if an empty `devices` list is supplied.
template <typename ModuleType>
Tensor data_parallel(
    ModuleType module,
    Tensor input,
    optional<std::vector<Device>> devices = nullopt,
    optional<Device> output_device = nullopt,
    int64_t dim = 0) {
  if (!devices) {
    const auto device_count = torch::cuda::device_count();
    TORCH_CHECK(
        device_count > 0, "Expected at least one CUDA device to be available");
    devices = std::vector<Device>();
    devices->reserve(device_count);
    for (const auto index : c10::irange(device_count)) {
      devices->emplace_back(kCUDA, static_cast<torch::DeviceIndex>(index));
    }
  }
  // A caller-supplied empty list would make devices->front() undefined.
  TORCH_CHECK(!devices->empty(), "Expected at least one device");
  if (!output_device) {
    output_device = devices->front();
  }

  if (devices->size() == 1) {
    module->to(devices->front());
    input = input.to(devices->front());
    return module->forward(std::move(input)).to(*output_device);
  }

  autograd::Scatter scatter(*devices, /*chunk_sizes=*/nullopt, dim);
  auto scattered_inputs = fmap<Tensor>(scatter.apply({std::move(input)}));
  // Input tensor might not be big enough to scale across all available devices
  if (scattered_inputs.size() < devices->size()) {
    // Keep only the devices that actually received a chunk.
    devices->erase(
        devices->begin() +
            static_cast<std::ptrdiff_t>(scattered_inputs.size()),
        devices->end());
  }

  auto replicas = replicate(module, *devices);
  auto outputs = parallel_apply(replicas, scattered_inputs, *devices);
  return autograd::Gather(*output_device, dim)
      .apply(fmap<autograd::Variable>(std::move(outputs)))
      .front();
}
} // namespace parallel
} // namespace nn
} // namespace torch


