by Preferred Networks

Welcome to the last entry into understanding the autograd engine of PyTorch series! If you haven’t read parts 1 & 2 check them now to understand how PyTorch creates the computational graph for the backward pass!

This post is based on PyTorch v1.11, so some highlighted parts may differ across versions.

PyTorch autograd graph execution

The last post showed how PyTorch constructs the graph to calculate the outputs’ derivatives w.r.t. the inputs when executing the forward pass. Now we will see how the execution of the backward pass is coordinated and done by looking at the whole process, starting from Python down to the lower C++ level internals.

What Happens when Calling backward()/grad() from Python

Using variable.backward()

After doing all our calculations with an input set to require the gradient, we call .backward() on the result to initiate the backward pass execution.

>>> x = torch.tensor([0.5, 0.75], requires_grad=True)
>>> y = torch.exp(x).sum()
>>> y.backward()

Calling .backward() on a tensor results in a call to torch.autograd.backward().

# torch/_tensor.py

def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
    
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

torch.autograd.backward() checks the arguments and calls the autograd engine in the C++ layer.

def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensors] = None,
) -> None:
    

    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("'inputs' argument to backward() cannot be empty.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \
        tuple(inputs) if inputs is not None else tuple()

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_)
    if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors_, retain_graph, create_graph, inputs,
        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

First, whether the grad_tensors argument was specified or not, there is a call to the _make_grads function. This is used to check the provided grad_tensors or to specify the default value for them by looking at the tensors argument values’ shapes. Check the first blog post for details on the default value for the grad_tensors of the backward pass. This function just provides the vector of the vector jacobian product if it was not initially specified.

In the above code, Variable has an _execution_engine attribute that is defined in torch.autograd.variable to be of type ImperativeEngine; the C++ engine exported to python and declared in torch/csrc/autograd/python_engine.cpp. In the following sections, we explain in detail how this object executes the backward pass.

Note that the torch.autograd.backward function has an inputs optional argument. This argument is used when we want to calculate the .grad field of only a subset of input tensors in the forward pass.

>>> x = torch.tensor([0.5, 0.75], requires_grad=True)
>>> y = torch.tensor([0.1, 0.90], requires_grad=True)
>>> z = torch.exp(x * y).sum()
>>> torch.autograd.backward([z], inputs=[x])
>>> x.grad
tensor([0.1051, 1.7676])
>>> y.grad  # None
>>>

Using torch.autograd.grad

An alternative to backward() is to use torch.autograd.grad(). The main difference to backward() is that grad() returns a tuple of tensors with the gradients of the outputs w.r.t. the inputs kwargs instead of storing them in the .grad field of the tensors. As you can see, the grad() code shown below is very similar to backward.

def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False,
   is_grads_batched: bool = False
) -> Tuple[torch.Tensor, ...]:
   
    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    overridable_args = outputs + inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
        )

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
    grad_outputs_ = _make_grads(outputs, grad_outputs_)

    if retain_graph is None:
        retain_graph = create_graph

    if is_grads_batched:
        # …. It will not be covered here
    else:
        return Variable._execution_engine.run_backward(
            outputs, grad_outputs_, retain_graph, create_graph, inputs,
            allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass

Figure 1 shows the computational graph with the backward() and grad() arguments highlighted in red and blue, respectively:

Fgiure 1: Correspondence of `backward`/`grad` arguments in the graphs.

Going Inside the Autograd Engine

Refreshing Concepts: Nodes and Edges

As we saw in 2 The computational graph comprises Node and Edge objects. Please read that post if you haven’t done it yet.

Nodes

Node objects are defined in torch/csrc/autograd/function.h, and they provide an overload of operator() for the associated function and a list of edges to do the graph traversal. Note that Node is a base class that autograd functions inherit from and override the apply method to execute the backward function.

struct TORCH_API Node : std::enable_shared_from_this<Node> {
 ...
 /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {
  ...
  }

protected:
  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;
  
  edge_list next_edges_;
  uint64_t topological_nr_ = 0;
  

There is an attribute called topological_nr_ in every node object. This number is used to optimize the graph execution as it allows to discard of graph branches under certain conditions. The topological number is the longest distance between this node and any leaf node and it is shown in Figure 2. Its main property is that for any pair of nodes x, y in a directed graph topo_nr(x) < topo_nr(y) means that there is no path from x to y. So this allows for reducing the number of paths in the graph in need of traversal. Check the topological_nr ) method comment for further details.

Figure 2: Example of the Topological Number calculation

Edges

The Edge object links Nodes together, and its implementation is straightforward.

struct Edge {
  ...
  /// The function this `Edge` points to.
  std::shared_ptr<Node> function;
  /// The identifier of a particular input to the function.
  uint32_t input_nr;
};

It only requires a function pointer to the Node and an input number that is the index of the output from the forward function this edge points to. When preparing the set of gradients before calling “function”, we know that what is flowing from this edge should be accumulated in the “input_nr”th argument. Note that the input/output name is flipped here and this is the input to the backward function. Edge objects are constructed using the gradient_edge function method.

 Edge gradient_edge(const Variable& self) {
    if (const auto& gradient = self.grad_fn()) {
      return Edge(gradient, self.output_nr());
    } else {
      return Edge(grad_accumulator(self), 0);
    }
  }

Entering the C++ Realm

Once that torch.autograd.backward() has been invoked, the THPEngine_run_backward routine starts the graph traversal. Following is a schema of the function body:

PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  
  // Convert the python arguments to C++ objects
  const char *accepted_kwargs[] = { // NOLINT
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", "accumulate_grad", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad))

 // Prepare arguments
 for(const auto i : c10::irange(num_tensors)) {
   // Check that the tensors require gradients
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
     // Prepare outputs
  }

  {
      // Calls the actual autograd engine
    pybind11::gil_scoped_release no_gil;
    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }
    // Clean up and finish
}

First, we prepare the input arguments after converting the PyObject arguments to actual C++ objects. The tensors list contains the tensors from which we start the backward pass. These tensors are converted to edges using torch::autograd::impl::gradient_edge and added to a list called roots where the graph traversal starts.

 edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for(const auto i : c10::irange(num_tensors)) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
       const auto& variable = THPVariable_Unpack(_tensor);
       auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
     roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      grads.push_back(grad_var);
    } 
  }

Now, if the inputs argument was specified in backward or we used the torch.autograd.grad api, the following code creates a list of edges to accumulate the gradients in the specified tensors at the end of the computation. The engine uses this later to optimize the execution as it doesn’t add the gradients in all the leaf nodes, just the specified ones.

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      const auto& tensor = THPVariable_Unpack(input);
      const auto output_nr = tensor.output_nr();
      auto grad_fn = tensor.grad_fn();
      if (!grad_fn) {
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
      }
      if (accumulate_grad) {
        tensor.retain_grad();
      }
      if (!grad_fn) {
        output_edges.emplace_back(std::make_shared<Identity>(), 0);
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

The next step is the actual graph traversal and node function execution, and finally, the cleanup and return.

  {
    // Calls the actual autograd engine
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }
  // Clean up and finish
}

Starting the Real Execution

engine.executeis present in torch/csrc/autograd/engine.cpp

There are two differentiated steps here:

Analyze the graph to find dependencies between functions Create worker threads that traverse the graph

Data Structures Used for the Execution

GraphTask

All the execution metadata is managed by the GraphTask class in torch/csrc/autograd/engine.h

struct GraphTask: std::enable_shared_from_this<GraphTask> {
  std::atomic<uint64_t> outstanding_tasks_{0};
  //  … 
  std::unordered_map<Node*, InputBuffer> not_ready_;
  std::unordered_map<Node*, int> dependencies_;

  struct ExecInfo {
     // …
  };
  std::unordered_map<Node*, ExecInfo> exec_info_;
  std::vector<Variable> captured_vars_;
  // …
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;
};

Here we see a series of variables dedicated to maintaining the execution state. outstanding_tasks_ tracks the number of tasks left to be executed for the backward pass to complete. not_ready_ holds the input arguments for the Nodes that are not ready to be executed. dependencies_ track the number of predecessors that a Node has. As the count reaches 0, the Node is ready for execution; it is placed in a ready queue to be retrieved and executed later.

exec_info_ and the associated ExecInfo struct are used only when the inputs argument is specified or it is a call to autograd.grad(). They allow filter paths on the graph that are not needeed since only the gradients are calculated only for the variables in the inputs list.

captured_vars_ is where the results of the graph execution are temporarily stored if we used the torch.autograd.grad() api instead of torch.autograd.backward() since grad() returns the gradients as tensors instead of just filling the .grad field of the inputs.

NodeTask

The NodeTask struct is a basic class that holds an fn_ pointer to the node to execute, and an inputs_ buffer to store the input arguments to this function. Note that the functions executed by the backward pass are the derivatives specified in the derivatives.yaml file. or the user provided backward function when using custom functions as described in the second blog post.

The inputs_ buffer is also where the output gradients of the previously executed functions are aggregated, and it is defined as a std::vector<Variable> container with facilities to accumulate values at a given position.

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here.  Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
};

GraphRoot

The GraphRoot is a special function used to hold multiple input variables in a single place. The code is pretty simple as it only acts as a container of variables.

struct TORCH_API GraphRoot : public Node {
  GraphRoot(edge_list functions, variable_list inputs)
      : Node(std::move(functions)),
      outputs(std::move(inputs)) {
    for (const auto& t : outputs) {
      add_input_metadata(t);
    }
  }

  variable_list apply(variable_list&& inputs) override {
    return outputs;
  }

AccumulateGrad

This function is set during the graph creation in gradient_edge when the Variable object doesn’t have a grad_fn. This is, it is a leaf node.

    if (const auto& gradient = self.grad_fn()) {
      // …
    } else {
      return Edge(grad_accumulator(self), 0);
    }

The function body is defined in torch/csrc/autograd/functions/accumulate_grad.cpp and it essentially accumulates the input grads in the object’s .grad attribute.

auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
  check_input_variables("AccumulateGrad", grads, 1, 0);
  

  at::Tensor new_grad = callHooks(variable, std::move(grads[0]));
  std::lock_guard<std::mutex> lock(mutex_);

  at::Tensor& grad = variable.mutable_grad();
  accumulateGrad(
      variable,
      grad,
      new_grad,
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  return variable_list();
}
}} // namespace torch::autograd



accumulateGrad does several checks on the tensors format and eventually performs the variable_grad += new_grad; accumulation.

Preparing the graph for execution

Now, let’s walk through Engine::execute. The first thing to do besides arguments consistency checks is to create the actual GraphTask object we described above. This object keeps all the metadata of the graph execution.

auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {

  validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
    return msg;
  });

  // Checks

  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue);

  // If we receive a single root, skip creating extra root node
  // …
  // Prepare graph by computing dependencies
  // …
  // Queue the root 
  // …
  // launch execution
  // …
}

After creating the GraphTask, we use its associated function if we only have one root node. If we have multiple root nodes, we create a special GraphRoot object as described before.

  bool skip_dummy_node = roots.size() == 1;
  auto graph_root = skip_dummy_node ?
    roots.at(0).function :
    std::make_shared<GraphRoot>(roots, inputs);

The next step is to fill the dependencies_ map in the GraphTask object since the engine must know when it can execute a task. The outputs here is the inputs argument passed to the torch.autograd.backward() call in Python. But here, we have reversed the names since the gradients w.r.t. the inputs of the forward pass are now the outputs of the backward pass. And from now on, there is no concept of forward/backward, but only graph traversal and execution.

  auto min_topo_nr = compute_min_topological_nr(outputs);
  // Now compute the dependencies for all executable functions
  compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);

  if (!outputs.empty()) {
    graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);
  }

Here we preprocess the graph for the execution of the nodes. First, compute_min_topological_nr is called to to obtain the minimum topological number of the tensors specified in outputs (0 if no inputs kwarg was supplied to .backward or input for .grad). This computation prunes paths in the graph that lead to input variables of which we don’t want/need to calculate the grads.

Second, is the compute_dependencies call. This function is a very simple graph traversal that starts with the root Node, and for each of the edges in node.next_edges() it increments the counter in dependencies_. Figure 3 shows the result of the dependencies calculation for the example graph. Note that the number of dependencies of any node is just the number of edges arriving at it.

Figure 3: Number of dependencies for each node

Finally, the init_to_execute call, this is the one that populates the GraphTask::exec_info_ map in case that inputs were specified in the python backward call. It iterates the graph again, starting from the root, and records in the exec_info_ map the intermediate nodes needed to calculate only the given inputs gradients.

  // Queue the root
  if (skip_dummy_node) {
    InputBuffer input_buffer(roots.at(0).function->num_inputs());
    auto input = inputs.at(0);


    input_buffer.add(roots.at(0).input_nr,
                      std::move(input),
                      input_stream,
                      opt_next_stream);

    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {
    execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
  }
  // Avoid a refcount bump for the Future, since we check for refcount in
  // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
  // in dist_engine.cpp).
  auto& fut = graph_task->future_result_;
  fut->wait();
  return fut->value().toTensorVector();
}

And now, we are ready to start the actual execution by creating the InputBuffer. In case we only have one root variable, we begin by copying the value of the inputs tensor (this is the gradients passed to python backward) in position 0 of the input_buffer. This is a small optimization that avoids running the RootNode for no reason. Also, if the rest of the graph is not on the cpu, we directly start on that worker while the RootNode is always placed on the cpu ready queue. Details of the workers and ready queues are explained in the section below.

On the other hand, if we have multiple roots, the GraphRoot object also holds the inputs, so it is enough to pass it an empty InputBuffer.

Graph Traversal and Node Execution

Devices, Threads and Queues

Before diving into the actual execution, we need to see how the engine is structured.

First of all, the engine is multithreaded with one thread per device. For example, the caller thread is associated with the CPU while additional threads are created and associated with each GPU or other devices available in the system. Each thread tracks its device using thread-local storage in the worker_device variable. In addition, the threads have a queue of tasks to be executed also located in thread-local storage, the local_ready_queue. This is where work is queued for this thread to execute in the thread_main function that is explained later. You will wonder how the device where a task should be executed is decided. The InputBuffer class has a device() function that returns the first non-cpu device of all its tensors. This function is used together with Engine::ready_queue to select the queue to queue a task.

auto Engine::ready_queue(std::shared_ptr<ReadyQueue> cpu_ready_queue, at::Device device) -> std::shared_ptr<ReadyQueue>{
  if (device.type() == at::kCPU || device.type() == at::DeviceType::Meta) {
    return cpu_ready_queue;
  } else {
    // See Note [Allocating GPUs to autograd threads]
    return device_ready_queues_.at(device.index());
  }
}

The ReadyQueue object is defined in torch/csrc/autograd/engine.h and it is a simple wrapper over std::priority_queue that allows a thread to wait for a task if it’s empty. One interesting property of the ReadyQueue is that it increases the GraphTask::outstanding_tasks_ value used to determine if the execution has completed or not.

auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (incrementOutstandingTasks) {
      std::shared_ptr<GraphTask> graph_task = item.base_.lock();
      ++graph_task->outstanding_tasks_;
    }
    heap_.push(std::move(item));
  }
  not_empty_.notify_one();
}

auto ReadyQueue::pop() -> NodeTask {
  std::unique_lock<std::mutex> lock(mutex_);
  not_empty_.wait(lock, [this]{ return !heap_.empty(); });
  auto task = std::move(const_cast<NodeTask&>(heap_.top())); heap_.pop();
  return task;
}

Reentrant Backward

A reentrant backward happens when one of the tasks in a backward pass calls again backward. It is not a very common case, but it can be used to reduce memory utilization as it could potentially avoid saving intermediate results. For more information, check this PyTorch forum post.

class ReentrantBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input.sum()

    @staticmethod
    def backward(ctx, input):
        # Let's compute the backward by using autograd
        input = input.detach().requires_grad_()
        with torch.enable_grad():
            out = input.sum()
        out.backward()  # REENTRANT CALL!!
        return out.detach()

Here, we call backward() inside backward() for a user custom-defined autograd function. This situation can lead to deadlocks because the first backward needs to wait for the second one to complete. But some internal implementation details can prevent the second backward from completing as it is explained in the dedicated subsection.

Thread Initialization

execute_with_graph_task is in charge of initializing the threads taking care of the computation and placing the root node in the queue of the device that produced it.

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {

  initialize_device_threads_pool();
  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  if (worker_device == NO_DEVICE) {
    set_device(CPU_DEVICE);
    graph_task->owner_ = worker_device;
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
    lock.unlock();
    thread_main(graph_task);
    worker_device = NO_DEVICE;
  } else {
     // This deals with reentrant backwards, we will see it later.
  }
  return graph_task->future_result_;
}

First, this function initializes several threads (one per device) calling initialize_device_threads_pool() where several things happen: One ReadyQueue per device is created. One thread per non-cpu device is created. A thread local worker_device variable is set to track the current device associated with the thread. thread_main function is called, and threads wait for tasks to be put in their queues.

Then it retrieves the queue to place the root node based on the device that holds the tensors present in the input_buffer using the ready_queue function. Now, the main thread (the one also executing the Python interpreter) has its worker_device set to NO_DEVICE, and it is in charge of executing functions with all its tensors living in the cpu. If worker_device is set to any other value, the graph execution is already started, and .backward() was called inside a running Node, creating a reentrant backward call. This is explained later. For now, the main thread places the task in the queue and call thread_main.

Where the Magic Happens

It’s been a long way, but finally, we are ready to traverse the graph and execute the nodes. Each of the spawned threads, and the main thread call thread_main.

auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {

  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      NodeTask task = local_ready_queue->pop();

      if (task.isShutdownTask_) {
        break;
      }

      if (!(local_graph_task = task.base_.lock())) {
        // GraphTask for function is no longer valid, skipping further
        // execution.
        continue;
      }

      if (task.fn_ && !local_graph_task->has_error_.load()) {
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            evaluate_function(
                local_graph_task,
                task.fn_.get(),
                task.inputs_,
                local_graph_task->cpu_ready_queue_);
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }
    }

    // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();
      auto base_owner = local_graph_task->owner_;
      if (worker_device != base_owner) {
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }
  }
}

The code here is simple, given the local_ready_queue assigned to each thread in thread-local storage. The threads loop until there are no tasks left to execute in the graph. Note that for device-associated threads, the passed graph_task argument is nullptr, and they block in local_ready_queue->pop() until a task is pushed in their queue. After some consistency checks (the task type is shutdown, or the graph is still valid). We get to the actual function invocation in evaluate_function.

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            evaluate_function(
                local_graph_task,
                task.fn_.get(),
                task.inputs_,
                local_graph_task->cpu_ready_queue_);
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }

After calling evaluate_function, we check if the graph_task execution is complete by looking the outstanding_tasks_ number. This number increases when a task is pushed to a queue and is decreased in local_graph_task->completed() when a task is executed. When the execution is done, we return the results that are be in the captured_vars_ in case we called torch.autograd.grad() instead of torch.autograd.backward() as this function returns tensors instead of storing them in the .grad attribute of the inputs. Finally we wake up the main thread if it’s waiting by sending a dummy task.

   // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();
      auto base_owner = local_graph_task->owner_;
      if (worker_device != base_owner) {
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }

Calling the Function and Unlocking New Tasks

evaluate_function serves three purposes:

Run the function. Accumulate its results in the next node InputBuffers. Decrease the dependencies counter of the next nodes and enqueues the tasks reaching 0 to be executed.

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {

  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    // Checks if the function needs to be executed 
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

  auto outputs = call_function(graph_task, func, inputs);

  auto& fn = *func;
  if (!graph_task->keep_graph_) {
    fn.release_variables();
  }

Initially, we check the exec_info_ map of the GraphTask structure to determine if the current node needs to be executed. Remember that if this map is empty, all the nodes are executed because we are calculating the grads for all the inputs of the forward pass.

After this check, the function is executed by running call_function. Its implementation is very straightforward and calls the actual derivative function and registered hooks if any.

  int num_outputs = outputs.size();
  if (num_outputs == 0) {
    // Records leaf stream (if applicable)
    return;
  }

  if (AnomalyMode::is_enabled()) {
    // check for nan values in result
  }

Next, we check the outputs of the function after call_function is done. If the number of outputs is 0, there are no following nodes to be executed so we can safely return. This is the case of the AccumulateGrad node associated with the leaf nodes.

Also, the check for NaN values in the gradients is done here if requested.


  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (const auto i : c10::irange(num_outputs)) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i);

    if (!next.is_valid()) continue;

   

We have now executed a grad_fn that has returned one gradient per each of the associated forward pass function inputs. As we saw in the previous blog post, we have an Edge object per each of these input tensors, and the grad_fn of the function producing them in the forward pass. Essentially, Output[0] of the node in the backward pass, corresponds to the first argument of the forward pass associated function. Figure 4 shows how the outputs of a backward function are related to the inputs of the forward function. See that the outputs of grad_fn C are the gradients of z w.r.t. the inputs of Function C

Figure 4: Correspondence between forward and backward functions inputs and outputs

We now iterate through these edges and check if the associated functions are ready to be executed.

 // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get());

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);
      is_ready = true;
    }

    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());

For this, we check the graph_task->dependencies_ map. We decrement the counter, and if it reaches 0, we mark the function pointed by the edge ready to be executed. Following, we prepare the input buffers of the tasks indicated by the next edges.

    if (not_ready_it == not_ready.end()) {
      if (!exec_info_.empty()) {
        // Skip functions that aren't supposed to be executed
      }

      // Creates an InputBuffer and moves the output to the corresponding input position
      InputBuffer input_buffer(next.function->num_inputs());
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }

Here, we look for the task in the graph_task->not_ready_ map. If it is not present, we create a new InputBuffer object and set the current output in the input_nr position of the buffer associated with the edge. If the task is ready to be executed, we enqueue it in the appropriate device ready_queue and complete the execution. However, if the task is not ready and we have seen it before, it is present in the not_ready_map_.

    } else {
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;
      // Accumulates into buffer
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(NodeTask(graph_task, next.function, std::move(input_buffer)));
        not_ready.erase(not_ready_it);
      }
    }
  }
}

In this case, we accumulate the output in the existing input_buffer instead of creating a new one. Once all the tasks are processed, the worker thread exits the loop and complete. All this process is summarized in the animation in Figure 5. We see how a thread peeks at the tasks in the ready queue and decrements the next nodes’ dependencies, unlocking them for execution.

Figure 5: Animation of the execution of the computational graph

Flow with Reentrant Backward

As we saw above, the reentrant backward problem is when the currently executed function does a nested call to backward. When this happens, the thread running this function goes all the way down to execute_with_graph_task as in the non-reentrant case, but here is when things are different.

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {

  initialize_device_threads_pool();
  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  if (worker_device == NO_DEVICE) {
    //Regular case
  } else {
    // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
    //    backward call from that device.
    graph_task->owner_ = worker_device;

    // Now that all the non-thread safe fields of the graph_task have been populated,
    // we can enqueue it.
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));

    if (current_depth >= max_recursion_depth_) {
      // If reached the max depth, switch to a different thread
      add_thread_pool_task(graph_task);
    } else {
      ++total_depth;
      ++current_depth;
      lock.unlock();
      thread_main(graph_task);
      --current_depth;
      --total_depth;
    }
  }
  return graph_task->future_result_;
}

Here, execute_with_graph_task detects this as a reentrant call and then looks for the current number of nested calls. If it exceeds the limit, we create a new thread to take care of the execution of this graph, and if not, we execute this reentrant call regularly. The limit of nested calls was originally set to avoid stack overflow due to reentrant calls creating very large call stacks. However, the number was further reduced when sanitizer tests were added because of the maximum amount of locks a thread can hold at a given moment. This can be seen in torch/csrc/autograd/engine.h.

When this maximum depth is exceeded, a new thread is created with the add_thread_pool_task function.

void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
  std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_);
  // if we have pending graph_task objects to be processed, create a worker.
   bool create_thread = (thread_pool_shared_->num_workers_ <= thread_pool_shared_->graphtasks_queue_.size());
  thread_pool_shared_->graphtasks_queue_.push(graph_task);


  lck.unlock();
  if (create_thread) {
    std::thread t(&Engine::reentrant_thread_init, this);
    t.detach();
  }

  thread_pool_shared_->work_.notify_one();
}



Before going in-depth, let’s look at the thread_pool_shared_ object in the Engine which manages all the information related to the threads associated to the reentrant backward calls.

  struct ThreadPoolShared {
    unsigned int num_workers_;
    std::condition_variable work_;
    std::mutex mutex_;
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    ThreadPoolShared() : num_workers_(0) {}
 };



ThreadPoolShared is a simple container holding a queue of GraphTask objects with synchronization mechanisms and the number of current workers.

Now it is easy to understand how add_thread_pool_task creates a thread when there are graph_task objects enqueued and insufficient workers to process them.

add_thread_pool_task initializes a thread by executing reentrant_thread_init

void Engine::reentrant_thread_init() {
  at::init_num_threads();
  auto tp_shared = thread_pool_shared_;
  while(true) {
    std::unique_lock<std::mutex> lk(tp_shared->mutex_);
    ++thread_pool_shared_->num_workers_;
    tp_shared->work_.wait(lk, [&tp_shared]{ return !tp_shared->graphtasks_queue_.empty();});
    --thread_pool_shared_->num_workers_;
    auto task = tp_shared->graphtasks_queue_.front();
    tp_shared->graphtasks_queue_.pop();
    lk.unlock();
    std::shared_ptr<GraphTask> graph_task;
    if (!(graph_task = task.lock())) {
      continue;
    }
    set_device(graph_task->owner_);
    // set the local_ready_queue to the ready queue on the graph_task->owner_ device
    local_ready_queue = ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
    total_depth = graph_task->reentrant_depth_;
    thread_main(graph_task);
  }
}



The code is straightforward. The newly created thread waits on the thread_pool_shared->graphtasks_queue_ for reentrant backward graphs to be available and executes them. Notice that this thread uses the task-ready queue associated with the device of the thread that started this call by accessing the graph_task->owner_ field set in the execute_with_graph_task function.

Error Handling

Whenever an error happens in one of the worker threads. It will be propagated to the backward calling thread.

To achieve this, there is a try/catch block in the thread_main that catches any exception in the Node function call and sets it to the associated GraphTask object.

       try {
          
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            evaluate_function(
               
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }
    }

thread_on_exception and the functions it calls end up setting the exception in the local_graph_task object.

void Engine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  graph_task->set_exception(std::current_exception(), fn);
}

void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
  if (!has_error_.exchange(true)) {
    if (AnomalyMode::is_enabled() && fn) {
      fn->metadata()->print_stack(fn->name());
    }
  }
}

void GraphTask::set_exception(
    std::exception_ptr eptr,
    const std::shared_ptr<Node>& fn) {
  set_exception_without_signal(fn);
  if (!future_completed_.exchange(true)) {
    // NOLINTNEXTLINE(performance-move-const-arg)
    future_result_->setError(std::move(eptr));
  }
}

In set_exception it sets the has_error_ flag to true and it calls the setError function of the future_result_ object. This will make the error to be re-thrown at the caller thread when future_result_->value() is accessed.

 IValue value() {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
    return value_;
  }

Closing Remarks

This has been the last post of this series covering how PyTorch does the auto differentiation. We hope you enjoyed reading it and that now you are familiar enough with PyTorch internals to start contributing in PyTorch development!