Distributed Optimizer with TorchScript support


Distributed Optimizer with TorchScript support was introduced in PyTorch 1.8 as a beta feature. This API is subject to change.

In this recipe, you will learn:

  • The high-level idea behind the distributed optimizer with TorchScript support and what the feature provides
  • How to write a custom distributed optimizer that enables TorchScript support

What is Distributed Optimizer?

DistributedOptimizer takes a list of remote parameters (RRefs) and runs the optimizer locally on the workers where those parameters live. It is commonly used together with Distributed RPC/Autograd for model parallel training. It can use any local optimizer algorithm (either one of the predefined algorithms in torch.optim or a custom one) to apply the gradients on each worker.

What is Distributed Optimizer with TorchScript support?

Distributed Optimizer is widely used in distributed model parallel training. In some common use cases, training needs to run in a multithreaded rather than multiprocess manner because of performance concerns and resource utilization (or at least partially multithreaded, e.g. a Parameter Server that hosts part of the model and its parameters and uses a new thread to update the parameters for each request). PyTorch itself does not support multithreaded training natively because it is constrained by Python’s Global Interpreter Lock (GIL), but it can leverage TorchScript to avoid the GIL and run the model in a multithreaded way.

For critical model training workloads, improving training performance is an important topic. Researchers often want to implement different optimization strategies on the graph representation (e.g. operator fusion) or implement custom operator kernels to speed up training.

Distributed Optimizer with TorchScript support helps avoid the GIL and thus improves PyTorch’s training performance in a multithreaded environment. It also unlocks the potential for further performance gains from the advanced compiler technologies that TorchScript offers (e.g. CPU/GPU fusion).

How to write a customized distributed optimizer with TorchScript support?

The code below shows how to write a customized distributed optimizer given an existing local optimizer implementation, which unlocks the TorchScript benefits including GIL removal and performance improvement opportunities.

Suppose that you already have a local optimizer that is currently used during training. Here we use quasi-hyperbolic momentum (QHM) as an example to show how to enable TorchScript support. The same approach applies to any custom optimizer that inherits from torch.optim.Optimizer.
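
For readers unfamiliar with QHM, the update it performs can be restated on plain Python floats. This is only an illustrative sketch: `qhm_step_scalar` is a hypothetical helper that is not part of PyTorch, weight decay is left out, and the hyperparameter values are arbitrary.

```python
from typing import Tuple


def qhm_step_scalar(p: float, grad: float, buf: float,
                    lr: float, nu: float, momentum: float) -> Tuple[float, float]:
    """One QHM step on a single scalar parameter (hypothetical helper, no weight decay)."""
    # Exponential moving average of the gradients.
    buf = momentum * buf + (1.0 - momentum) * grad
    # Step along a nu-weighted mix of the momentum buffer and the raw gradient.
    p = p - lr * (nu * buf + (1.0 - nu) * grad)
    return p, buf


new_p, new_buf = qhm_step_scalar(p=1.0, grad=0.5, buf=0.0, lr=0.1, nu=0.7, momentum=0.9)
print(new_p, new_buf)
```

In this formulation, nu = 0.0 reduces the step to plain SGD, while nu = 1.0 steps along the momentum buffer only.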

First, we need to separate the computation from the state management in the optimizer implementation, so that we can extract the computation into a free function, which is TorchScript friendly. This has two benefits:

  1. The computation logic becomes easier to inspect. It lets us quickly turn the parameter update/computation part into TorchScript and use TorchScript IR for further optimizations (operator fusion, etc.).
  2. Distributed Optimizer uses a different mechanism under the hood to get gradients and update parameters (it stores gradients separately instead of directly populating the param.grad field during backward). Separating the computation lets the distributed optimizer run updates in multithreaded mode, because it eliminates the possible race condition on param.grad.

import torch
from torch import Tensor
from typing import List

def qhm_update(params: List[Tensor],
            dp_list: List[Tensor],
            momentum_buffer_list: List[Tensor],
            lr: float,
            nu: float,
            weight_decay: float,
            weight_decay_type: str,
            momentum: float):

    for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
        if weight_decay != 0:
            if weight_decay_type == "grad":
                d_p.add_(p, alpha=weight_decay)
            elif weight_decay_type == "direct":
                p.mul_(1.0 - lr * weight_decay)
            else:
                raise ValueError("Invalid weight decay type provided")

        # Update the momentum buffer, then step along a nu-weighted mix of
        # the momentum buffer and the raw gradient.
        momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum)
        p.add_(momentum_buffer, alpha=-lr * nu)
        p.add_(d_p, alpha=-lr * (1.0 - nu))
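
As a small illustration of why a free function is convenient, the sketch below compiles one with `torch.jit.script` and calls it like a normal Python function. `scale_update` is a hypothetical stand-in used only for this example, not the QHM update from this recipe; an update function written in the same style can be compiled the same way.

```python
from typing import List

import torch
from torch import Tensor


def scale_update(params: List[Tensor], grads: List[Tensor], lr: float):
    # Hypothetical stand-in for an optimizer update: p <- p - lr * grad, in place.
    for p, g in zip(params, grads):
        p.add_(g, alpha=-lr)


# Compile the free function; the result is called like the original.
scripted_update = torch.jit.script(scale_update)

params = [torch.ones(2), torch.zeros(3)]
grads = [torch.full((2,), 2.0), torch.full((3,), 4.0)]
scripted_update(params, grads, 0.5)

# The TorchScript graph (IR) is available for inspection.
print(scripted_update.graph)
```

Because the function only takes tensors and scalars and keeps no state of its own, TorchScript can compile it without knowing anything about the optimizer class that calls it.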

Next, we define a distributed functional optimizer with TorchScript compatibility that manages the optimizer states and calls into the TorchScript-compatible update function defined above. Note that a few conventions differ from normal custom optimizers:

  1. We don’t inherit from torch.optim.Optimizer, as TorchScript does not support polymorphism.
  2. step takes a list of gradients instead of the loss closure.

import torch
from torch import Tensor
from typing import List, Optional, Dict

# define this as a TorchScript class
@torch.jit.script
class FunctionalQHM(object):
    def __init__(self,
                params: List[Tensor],
                lr: float,
                momentum: float,
                nu: float,
                weight_decay: float = 0.0,
                weight_decay_type: str = "grad"):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if weight_decay_type not in ("grad", "direct"):
            raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "nu": nu,
            "weight_decay": weight_decay,
        }
        self.weight_decay_type = weight_decay_type

        # NOTE: we only have one param_group here and don't allow user to add additional
        # param group as it's not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group['params']
        params_with_grad = []
        grads = []
        momentum_buffer_list: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the number of gradients passed in does not equal the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

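        for param, gradient in zip(self.param_group['params'], gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazily create the momentum buffer the first time a parameter is seen.
                if param not in self.state:
                    self.state[param] = {
                        'momentum_buffer': torch.zeros_like(param, memory_format=torch.preserve_format)
                    }
                momentum_buffer_list.append(self.state[param]['momentum_buffer'])

        # calls into the update function we just defined
        with torch.no_grad():
            qhm_update(params_with_grad,
                       grads,
                       momentum_buffer_list,
                       self.defaults['lr'],
                       self.defaults['nu'],
                       self.defaults['weight_decay'],
                       self.weight_decay_type,
                       self.defaults['momentum'])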
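
The two conventions are easier to see on a smaller optimizer. The sketch below is a hypothetical plain-SGD example (`FunctionalSGDSketch` is an illustrative name, not a PyTorch class). It is not scripted and keeps no per-parameter state; it only shows a class that does not inherit from torch.optim.Optimizer and whose step takes an explicit list of gradients.

```python
from typing import List, Optional

import torch
from torch import Tensor


class FunctionalSGDSketch:
    """Hypothetical minimal functional optimizer (plain SGD, illustration only)."""

    def __init__(self, params: List[Tensor], lr: float):
        self.params = params
        self.lr = lr

    def step(self, gradients: List[Optional[Tensor]]) -> None:
        # Gradients arrive as an explicit list; param.grad is never read.
        if len(self.params) != len(gradients):
            raise ValueError("expected one gradient entry per parameter")
        with torch.no_grad():
            for param, grad in zip(self.params, gradients):
                if grad is not None:  # None means "no gradient for this parameter"
                    param.add_(grad, alpha=-self.lr)


weights = torch.tensor([1.0, 2.0])
bias = torch.tensor([0.5])
optimizer = FunctionalSGDSketch([weights, bias], lr=0.5)

# Only `weights` receives a gradient in this step; `bias` is skipped.
optimizer.step([torch.tensor([2.0, 4.0]), None])
print(weights, bias)
```

Passing gradients explicitly is what lets the caller decide where gradients are stored, instead of relying on the param.grad field.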

Finally, we register our newly defined distributed functional optimizer in the functional_optim_map, so that DistributedOptimizer picks up our custom implementation instead of the predefined default ones.

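from torch.distributed.optim import DistributedOptimizer

# QHM is the existing local optimizer class (a torch.optim.Optimizer subclass),
# assumed to be imported or defined elsewhere.
DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM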

Now you can use the QHM optimizer as usual in distributed training by passing it to DistributedOptimizer:

remote_params_list = [...]
dist_optim = DistributedOptimizer(
    QHM, remote_params_list, *args, **kwargs
)

DistributedOptimizer automatically transforms the QHM optimizer into FunctionalQHM under the hood and enables TorchScript support. This unlocks the performance boost from multithreaded training and leaves room for further improvements (e.g. TorchScript fusion).

Note that the majority of PyTorch built-in optimizers already use this methodology to speed up distributed training. If you see a warning that an optimizer has not been converted yet, you can write your own conversion by following this recipe.

