# Source code for torch.distributed.pipelining.microbatch
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten
# Public API of this module.
__all__ = ["TensorChunkSpec", "split_args_kwargs_into_chunks", "merge_chunks"]
logger = logging.getLogger(__name__)

# `_debug_mask_minibatches` makes the splitter send masked versions of the
# mini-batch through instead of micro-batch slices: each chunk is a full-size
# tensor of zeros except for the chunk's own slice. This can be used for more
# stable numerical testing (see [A Note About Correctness Testing]).
_debug_mask_minibatches = False
class _CustomReducer:
    """
    Describe how per-microbatch values (e.g. losses) are reduced into one value.

    When merging, the result starts at ``init_value`` and each microbatch's
    value is folded in with ``reduce_fn(accumulated, value)``.

    Example:
        >>> # xdoctest: +SKIP
        >>> sum_reducer = _CustomReducer(
        >>>     torch.tensor(0.0),
        >>>     lambda a, b: a + b
        >>> )
    """

    def __init__(self, init_value, reduce_fn):
        # init_value: starting accumulator; reduce_fn: (accumulated, next) -> accumulated
        self.init_value, self.reduce_fn = init_value, reduce_fn
class _LossReducer(_CustomReducer):
    """
    ``_CustomReducer`` subclass with no extra behavior; its type marks a
    reducer intended for losses.
    """


# Reducer that sums per-microbatch values, starting from a 0.0 tensor.
sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
# Chunking dimension used whenever the caller does not specify one explicitly.
DEFAULT_CHUNK_DIM = 0
class TensorChunkSpec:
    """
    Class used to specify chunking of inputs.

    An instance records ``split_dim``, the dimension along which a tensor is
    split into micro-batches.
    """

    split_dim: int

    def __init__(self, split_dim):
        self.split_dim = split_dim

    def __repr__(self):
        cls = type(self)
        return f"{cls.__module__}.{cls.__name__}({self.split_dim})"

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: Tuple[int, ...],
    ):
        """
        Build a tuple of `TensorChunkSpec` from a tuple of chunk dimensions
        (int's), one spec per entry.

        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        # map_aggregate preserves the container structure and calls the
        # class on every leaf dimension.
        return map_aggregate(chunk_dims, TensorChunkSpec)

    @staticmethod
    def from_dict(
        chunk_dims: Dict[str, int],
    ):
        """
        Build a dictionary of `TensorChunkSpec` from a dictionary of chunk
        dimensions (int's), keeping the same keys.

        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        return map_aggregate(chunk_dims, TensorChunkSpec)
class _Replicate:
    """
    Chunk-spec marker requesting replication of an input.

    The class object itself (not an instance) is compared with ``is`` when
    sharding; a leaf whose spec is ``_Replicate`` is passed unchanged to every
    chunk.
    """
def _debug_mask_chunks(full_value, chunk_tensors, split_dim):
    """
    Expand each chunk back to the shape of ``full_value``: a zero tensor holding
    the chunk's data only in that chunk's slice along ``split_dim``.

    Used when ``_debug_mask_minibatches`` is enabled.
    """
    expanded_chunks = []
    start = 0
    for chunk_tensor in chunk_tensors:
        new_val = torch.zeros_like(full_value)
        end = start + chunk_tensor.size(split_dim)
        slice_indices = [slice(None, None, None)] * new_val.ndim
        slice_indices[split_dim] = slice(start, end)
        # Index with a tuple: indexing with a non-tuple sequence of slices is
        # legacy NumPy-style behavior, and a tuple is the unambiguous form.
        new_val[tuple(slice_indices)] = chunk_tensor
        expanded_chunks.append(new_val)
        start = end
    return expanded_chunks


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args, and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Each arg is flattened with ``tree_flatten`` and every leaf is paired with
    the matching leaf of its chunk spec:

    * spec is ``_Replicate``, or the leaf is not a tensor: the same leaf object
      is placed in every chunk;
    * spec is a ``TensorChunkSpec``: the tensor is split with
      ``torch.tensor_split`` along ``split_dim``.

    If the first tensor that gets split is smaller than the requested number of
    chunks along its chunking dimension, the chunk count is reduced to that size
    (with a warning). A later tensor that is too small raises instead.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs; must have as many
            entries as ``args_dict`` and contain every key of it
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args, one dict per chunk. Its length may be
            smaller than ``num_chunks`` after downsizing.

    Raises:
        AssertionError: if the two dictionaries differ in length.
        ValueError: if an arg and its chunk spec flatten to a different number
            of leaves.
        RuntimeError: if a tensor after the first split one is too small to be
            split into the current number of chunks.
        TypeError: if a tensor leaf's spec is neither ``_Replicate`` nor a
            ``TensorChunkSpec``.
    """
    # Stage 1+2: flatten and shard/replicate
    # args_sharded_replicated : [num args, num flat values, num chunks]
    args_sharded_replicated = {}
    arg_specs = []

    real_num_chunks = num_chunks
    first_tensor = True

    assert len(args_dict) == len(
        args_chunk_spec
    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None  # Should have been set by caller
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                # The list may be longer than the final chunk count if a later
                # tensor downsizes it; only the first `real_num_chunks` entries
                # are read below.
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # `v` is a tensor here: the first branch caught all non-tensors.
                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # We can only adjust the number of chunks at the first
                        # tensor encountered: any earlier split tensor would
                        # already have been cut into `real_num_chunks` pieces.
                        logger.warning(
                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {real_num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )

                if _debug_mask_minibatches:
                    sharded_arg_flat.append(
                        _debug_mask_chunks(v, chunk_tensors, chunk_v.split_dim)
                    )
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # Stage 3: rotate so the chunk index is outermost.
    # chunks_flat : [num chunks, num args, num flat values]
    chunks_flat = [
        {
            key: [v_flat[chunk_idx] for v_flat in arg]
            for key, arg in args_sharded_replicated.items()
        }
        for chunk_idx in range(real_num_chunks)
    ]

    # Stage 4: unflatten every arg of every chunk.
    # args_split : [num chunks, num args]
    args_split = []
    for chunk in chunks_flat:
        assert len(arg_specs) == len(chunk)
        args_split.append(
            {
                key: tree_unflatten(arg, arg_spec)
                for (key, arg), arg_spec in zip(chunk.items(), arg_specs)
            }
        )
    return args_split
def split_args_kwargs_into_chunks(
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
    """
    Split positional and keyword arguments into a number of chunks according
    to their respective chunking specs.

    When no spec is given for ``args`` or ``kwargs``, every value is chunked
    along ``DEFAULT_CHUNK_DIM``. The number of chunks actually produced can be
    smaller than ``chunks`` when a tensor is too small to be split that many
    times; args and kwargs always end up with the same number of chunks.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs (``None`` is treated as empty)
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args, one tuple per chunk
        kwargs_split: List of sharded kwargs, one dict per chunk

    Raises:
        RuntimeError: if args and kwargs cannot be split into the same number
            of chunks.
    """
    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
    # and `kwargs_chunk_spec` specifications. The steps are as follows:
    #
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
    #    To use a running example: suppose our inputs look like
    #
    #       args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
    #       (kwargs not shown but it's a similar process)
    #
    #    Then for this step we would end up with
    #
    #       args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
    #
    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting order such that chunks are the outer dimension
    #
    #       args_chunks = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 4. Unflatten each chunk according to the spec
    #
    #       args_chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]

    # TODO: _debug_mask_minibatches
    kwargs = {} if kwargs is None else kwargs

    # Missing specs default to chunking every value along DEFAULT_CHUNK_DIM.
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    # Positional args are sharded through the same dict-based helper, keyed by index.
    positional_args = dict(enumerate(args))
    positional_specs = dict(enumerate(args_chunk_spec))

    args_split_dict = _shard_dict_of_args(positional_args, positional_specs, chunks)
    chunk_count = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(kwargs, kwargs_chunk_spec, chunk_count)

    if len(kwargs_split) < chunk_count:
        # kwargs were sharded into fewer chunks (e.g. when `args` has no
        # tensor, just values), so re-shard args to match.
        args_split_dict = _shard_dict_of_args(
            positional_args, positional_specs, len(kwargs_split)
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = [
        tuple(chunk_args[i] for i in range(len(chunk_args)))
        for chunk_args in args_split_dict
    ]
    return args_split, kwargs_split
def _debug_unmask_chunks(partial_values, split_dim):
    """
    Recover the real chunk data from full-size debug-masked partial values.

    Used when ``_debug_mask_minibatches`` is enabled. Every partial value must
    have the same shape. Chunk ``i``'s region along ``split_dim`` is inferred
    by re-running ``tensor_split`` on a meta tensor of that shape, and only that
    slice is kept.
    """
    # Infer size of individual chunks by running `tensor_split` again
    overall_shape = partial_values[0].shape
    for val in partial_values[1:]:
        assert val.shape == overall_shape
    meta_chunks = torch.tensor_split(
        torch.empty(overall_shape, device="meta"),
        sections=len(partial_values),
        dim=split_dim,
    )
    assert len(partial_values) == len(meta_chunks)

    values_to_cat = []
    chunk_start_idx = 0
    for partial_value, meta_chunk in zip(partial_values, meta_chunks):
        chunk_end_idx = chunk_start_idx + meta_chunk.size(split_dim)
        slice_indices = [slice(None, None, None)] * partial_value.ndim
        slice_indices[split_dim] = slice(chunk_start_idx, chunk_end_idx)
        # Tuple indexing: a list of slices relies on legacy sequence-as-tuple handling.
        values_to_cat.append(partial_value[tuple(slice_indices)])
        chunk_start_idx = chunk_end_idx
    return values_to_cat


def _replicated_values_match(a, b):
    """
    Equality check for values that should be identical across chunks.

    Tensors are compared with ``torch.equal``. Plain ``==`` on tensors returns
    an element-wise tensor, and using that in a truth test raises for tensors
    with more than one element.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.equal(a, b)
    return a == b


def merge_chunks(
    chunks: List[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    For each flattened leaf of the spec:

    * ``TensorChunkSpec``: the per-chunk tensors are concatenated along
      ``split_dim``;
    * ``_CustomReducer``: the per-chunk values are folded with ``reduce_fn``,
      starting from ``init_value``;
    * anything else: the value is expected to be equal in every chunk (checked
      with ``assert``), and chunk 0's value is kept.

    If ``chunk_spec`` is ``None``, every leaf is concatenated along
    ``DEFAULT_CHUNK_DIM`` and the structure is taken from ``chunks[0]``.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value

    Raises:
        ValueError: if a chunk does not flatten to the same number of leaves
            as the chunk spec.
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #       chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #       chunks_flat = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #       value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #       value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #       value = ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
        # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []
    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        # This leaf's value in every chunk, in chunk order.
        partial_values = [chunk_flat[arg_idx] for chunk_flat in chunks_flattened]

        if isinstance(arg, TensorChunkSpec):
            if _debug_mask_minibatches:
                values_to_cat = _debug_unmask_chunks(partial_values, arg.split_dim)
            else:
                values_to_cat = partial_values
            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value
            for partial_value in partial_values:
                reduced_val = arg.reduce_fn(reduced_val, partial_value)
            args_flattened.append(reduced_val)
        else:
            # Not chunked: every chunk should hold the same value.
            value = partial_values[0]
            for other in partial_values[1:]:
                assert _replicated_values_match(other, value)
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)