
DDP Communication Hooks

A DDP communication hook is a generic interface for controlling how gradients are communicated across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of them to optimize communication. In addition, the hook interface supports user-defined communication strategies for more advanced use cases.

How to Use a Communication Hook?

To use a communication hook, the user just needs to let the DDP model register the hook before the training loop, as shown below.

torch.nn.parallel.DistributedDataParallel.register_comm_hook()
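
For example, a minimal sketch of registering a built-in hook, assuming ddp_model has already been wrapped with DistributedDataParallel and the default process group has been initialized:

>>> from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
>>> # state=None uses the default process group; register once, before the training loop starts.
>>> ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)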

What Does a Communication Hook Operate On?

A communication hook provides a flexible way to allreduce gradients. It mainly operates on the gradients on each replica before allreduce, which are bucketized to increase the overlap between communication and computation. In particular, torch.distributed.GradBucket represents a bucket of gradient tensors to be allreduced.

class torch.distributed.GradBucket

This class mainly passes a flattened gradient tensor (returned by buffer()) to the DDP communication hook. This tensor can be further decomposed into a list of per-parameter tensors within this bucket (returned by gradients()) to apply layer-wise operations.

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) → int

Warning

Since the buckets are rebuilt after the first iteration, the indices should not be relied upon at the beginning of training.

Returns

The index of a bucket that stores gradients of a few contiguous layers. All the gradients are bucketized.

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) → at::Tensor
Returns

A flattened 1D torch.Tensor buffer, which can be further decomposed into a list of per-parameter tensors within this bucket.

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) → List[at::Tensor]
Returns

A list of torch.Tensor. Each tensor in the list corresponds to a gradient.

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) → bool
Returns

Whether this bucket is the last bucket to allreduce in an iteration. This also means that this bucket corresponds to the first few layers in the forward pass.

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: at::Tensor) → None

Replaces the tensor in the bucket with the input tensor buffer.

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) → List[at::Tensor]
Returns

A list of torch.Tensor. Each tensor in the list corresponds to a model parameter.
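
Putting these pieces together, a minimal sketch of a user-defined hook that reproduces the default allreduce-and-average behavior by operating on the bucket's flattened buffer (assuming the default process group; this is a sketch, not the built-in implementation):

>>> import torch
>>> import torch.distributed as dist
>>>
>>> def allreduce_mean_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     buffer = bucket.buffer()  # flattened 1D gradient tensor of this bucket
>>>     fut = dist.all_reduce(buffer, async_op=True).get_future()
>>>     # Divide the summed gradients by the world size once the allreduce finishes.
>>>     return fut.then(lambda f: f.value()[0].div_(dist.get_world_size()))
>>>
>>> ddp_model.register_comm_hook(state=None, hook=allreduce_mean_hook)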

Default Communication Hooks

Default communication hooks are simple stateless hooks, so the input state in register_comm_hook is either a process group or None. The input bucket is a torch.distributed.GradBucket object.

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]

This DDP communication hook simply calls allreduce using the GradBucket tensors. Once gradient tensors are aggregated across all workers, its then callback takes the mean and returns the result. If a user registers this hook, the DDP results are expected to be the same as when no hook is registered. Hence, this hook does not change the behavior of DDP, and users can use it as a reference or modify it to log useful information or for other purposes, without affecting DDP behavior.

Example::
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]

This DDP communication hook implements a simple gradient compression approach that casts the GradBucket tensor to half-precision floating-point format (torch.float16) and then divides it by the process group size. It allreduces those float16 gradient tensors. Once the compressed gradient tensors are allreduced, the chained callback decompress casts them back to the input data type (such as float32).

Example::
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]

Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

This DDP communication hook implements a simple gradient compression approach that casts the GradBucket tensor to half-precision Brain floating point format (torch.bfloat16) and then divides it by the process group size. It allreduces those bfloat16 gradient tensors. Once the compressed gradient tensors are allreduced, the chained callback decompress casts them back to the input data type (such as float32).

Example::
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)

Additionally, a communication hook wrapper is provided to support fp16_compress_hook() or bf16_compress_hook() as a wrapper, which can be combined with other communication hooks.

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]

This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision floating point format (torch.float16), and casts the resulting tensor of the given hook back to the input data type, such as float32.

Therefore, fp16_compress_hook is equivalent to fp16_compress_wrapper(allreduce_hook).

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]

Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision Brain floating point format (torch.bfloat16), and casts the resulting tensor of the given hook back to the input data type, such as float32.

Therefore, bf16_compress_hook is equivalent to bf16_compress_wrapper(allreduce_hook).

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))

PowerSGD Communication Hook

PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm, which can provide very high compression rates and accelerate bandwidth-bound distributed training. This algorithm needs to maintain both some hyperparameters and the internal state. Therefore, PowerSGD communication hook is a stateful hook, and the user needs to provide a state object defined as below.

PowerSGD State

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]

Stores both the algorithm’s hyperparameters and the internal state for all the gradients during the training. Particularly, matrix_approximation_rank and start_powerSGD_iter are the main hyperparameters that should be tuned by the user. For performance, we suggest keeping the binary hyperparameters use_error_feedback and warm_start on.

  1. matrix_approximation_rank controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

    1.1. If matrix_approximation_rank is too low, the full model quality will need more training steps to reach, or will never be reached, resulting in a loss in accuracy.

    1.2. Increasing matrix_approximation_rank can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain matrix_approximation_rank threshold.

To tune matrix_approximation_rank, we suggest starting from 1 and increasing by factors of 2 (like an exponential grid search: 1, 2, 4, …) until a satisfactory accuracy is reached. Typically only a small value (1-4) is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.

  2. start_powerSGD_iter defers PowerSGD compression until step start_powerSGD_iter, and vanilla allreduce runs prior to step start_powerSGD_iter. This hybrid scheme of vanilla allreduce + PowerSGD can effectively improve the accuracy, even if a relatively small matrix_approximation_rank is used. This is because the beginning of the training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can have an irrecoverable impact on the accuracy.

To tune start_powerSGD_iter, we suggest starting with 10% of the total training steps and increasing it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, start_powerSGD_iter typically should be no less than the number of warm-up steps.

  3. min_compression_rate is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols (see the illustrative check after this list). If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.

Compression statistics are logged every compression_stats_logging_frequency iterations once PowerSGD compression starts.

  4. orthogonalization_epsilon can be a very small value (e.g., 1e-8) added to every normalized matrix column in the orthogonalization step, to prevent division-by-zero errors if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.

  5. batch_tensors_with_same_shape controls whether to compress and decompress tensors with the same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., the bucket_cap_mb arg in the DDP constructor) to make more same-shaped tensors appear in the same bucket; however, this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to True if the compression / decompression computation is a bottleneck.
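
As an illustration of the min_compression_rate criterion above, a hypothetical helper (not part of the PyTorch API) that decides whether a 2D gradient tensor is worth compressing:

>>> def worth_compressing(num_rows, num_cols, matrix_approximation_rank, min_compression_rate):
>>>     # Compress only if the low-rank representation saves enough bandwidth.
>>>     uncompressed_size = num_rows * num_cols
>>>     compressed_size = (num_rows + num_cols) * matrix_approximation_rank
>>>     return compressed_size * min_compression_rate < uncompressed_size
>>>
>>> # e.g., a 1024 x 1024 gradient with rank 1 and min_compression_rate 2:
>>> # (1024 + 1024) * 1 * 2 = 4096 < 1024 * 1024, so it would be compressed.
>>> worth_compressing(1024, 1024, 1, 2)
True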

Warning

If error feedback or warm-up is enabled, the minimum value of start_powerSGD_iter allowed in DDP is 2. This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, and this can conflict with any tensor memorized before the rebuild process.

PowerSGD Hooks

Warning

PowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

Warning

PowerSGD hooks may conflict with Apex automatic mixed precision package. Please use PyTorch native automatic mixed precision package instead.

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]

This DDP communication hook implements PowerSGD gradient compression algorithm described in the paper. Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

  1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:

    1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.

    1.2 The rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).

  2. Handles uncompressed tensors:

    2.1. Allocates contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;

    2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.

  3. Handles the tensors that should be compressed by PowerSGD compression:

    3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

    3.2. Computes each P in Ps, which is equal to MQ;

    3.3. Allreduces Ps as a batch;

    3.4. Orthogonalizes each P in Ps;

    3.5. Computes each Q in Qs, which is approximately equal to M^TP;

    3.6. Allreduces Qs as a batch;

    3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
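
The per-tensor compression in step 3 can be illustrated with plain tensor operations; the following is a sketch of the math only (the dimensions and rank are made up for illustration), not the hook implementation:

>>> import torch
>>> torch.manual_seed(0)
>>> M = torch.randn(64, 32)       # one per-parameter gradient tensor, viewed as 2D
>>> r = 4                         # matrix_approximation_rank
>>> Q = torch.randn(32, r)        # initialized from a standard normal distribution
>>> Q, _ = torch.linalg.qr(Q)     # orthogonalized
>>> P = M @ Q                     # P = MQ; the Ps are allreduced as a batch
>>> P, _ = torch.linalg.qr(P)     # orthogonalize P
>>> Q = M.t() @ P                 # Q ≈ M^T P; the Qs are allreduced as a batch
>>> M_approx = P @ Q.t()          # M ≈ P Q^T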

Parameters
  • state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, one mainly needs to tune matrix_approximation_rank, start_powerSGD_iter and min_compression_rate.

  • bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket.

Returns

Future handler of the communication, which updates the gradients in place.

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]

This DDP communication hook implements a simplified PowerSGD gradient compression algorithm described in the paper. This variant does not compress the gradients layer by layer, but instead compresses the flattened input tensor that batches all the gradients. Therefore, it is faster than powerSGD_hook(), but usually results in a much lower accuracy, unless matrix_approximation_rank is 1.

Warning

Increasing matrix_approximation_rank here may not necessarily increase the accuracy, because batching per-parameter tensors without column/row alignment can destroy low-rank structure. Therefore, the user should always consider powerSGD_hook() first, and only consider this variant when a satisfactory accuracy can be achieved when matrix_approximation_rank is 1.

Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

  1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;

  2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

  3. Computes P, which is equal to MQ;

  4. Allreduces P;

  5. Orthogonalizes P;

  6. Computes Q, which is approximately equal to M^TP;

  7. Allreduces Q;

  8. Computes M, which is approximately equal to PQ^T.

  9. Truncates the input tensor to the original length.
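
The reshaping in steps 1 and 9 can be sketched as follows (the buffer length is made up for illustration; this is not the hook implementation):

>>> import math, torch
>>> grad = torch.randn(1000)                    # flattened 1D gradient buffer
>>> n = math.ceil(math.sqrt(grad.numel()))      # side length of the square matrix
>>> M = torch.zeros(n * n)
>>> M[: grad.numel()] = grad                    # 0-padded
>>> M = M.view(n, n)                            # square-shaped tensor M
>>> # ... compress M as P Q^T, allreduce P and Q, then decompress ...
>>> out = M.view(-1)[: grad.numel()]            # truncate back to the original length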

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

Parameters
  • state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, one mainly needs to tune matrix_approximation_rank and start_powerSGD_iter.

  • bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket.

Returns

Future handler of the communication, which updates the gradients in place.

Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

Debugging Communication Hooks

As the name implies, debugging communication hooks are only used for debugging and performance optimization purposes.

Warning

Debugging communication hooks do not necessarily output the correct results.

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]

This DDP communication hook returns a future that wraps the input, so it is a noop that does not incur any communication overheads.

This hook should only be used for headroom analysis of allreduce optimization, instead of normal gradient synchronization. For example, if only less than 10% speedup in training time can be observed after this hook is registered, it usually implies that allreduce is not a performance bottleneck in this case. Such instrumentation can be particularly useful if GPU traces cannot be easily retrieved or the trace analysis is complicated by factors such as the overlap between allreduce and computation or the desynchronization across ranks.

Example::
>>> ddp_model.register_comm_hook(None, noop_hook)

Acknowledgements

Many thanks to PowerSGD paper author Thijs Vogels for the code review on PowerSGD communication hook, as well as the comparison experiments, which show that the performance of PowerSGD communication hook is on par with the implementation in the original paper.
