Source code for torchtune.training._distributed
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from itertools import chain
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.checkpoint.state_dict import (
_init_optim_state,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune.modules import TransformerDecoder
from torchtune.modules.attention import MultiHeadAttention
from torchtune.modules.model_fusion import DeepFusionModel, EarlyFusionModel
from torchtune.modules.peft import get_adapter_state_dict
from torchtune.utils import get_device, get_logger
from torchtune.utils._logging import deprecated
_log: logging.Logger = get_logger()
torch_version = torch.__version__
# TODO: Fix issues with DSD before uncommenting. See #2313 and #2277.
# _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = (
# "dev" not in torch_version and torch_version_ge("2.6.0")
# ) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")
_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = False
def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
"""Helper function to convert sharding strategy strings to ShardingStrategy enum."""
return getattr(ShardingStrategy, strategy)
def is_distributed() -> bool:
"""Check if all environment variables required to initialize torch.distributed are set
and distributed is properly installed. This indicates a distributed run.
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
Checks the following conditions:
* torch.distributed is available
* master port and master address environment variables are set
* world size is set and >= 1
* rank environment variable is set
Returns:
bool: True if all of the above conditions hold, False otherwise.
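Example:
>>> # Illustrative values, as a launcher such as torchrun would set them
>>> os.environ["MASTER_ADDR"] = "localhost"
>>> os.environ["MASTER_PORT"] = "29500"
>>> os.environ["WORLD_SIZE"] = "2"
>>> os.environ["RANK"] = "0"
>>> is_distributed()  # True, provided torch.distributed is available
True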
"""
port = os.environ.get("MASTER_PORT", "")
addr = os.environ.get("MASTER_ADDR", "")
size = int(os.environ.get("WORLD_SIZE", 1))
rank = int(os.environ.get("RANK", -1))
avlb = dist.is_available()
return bool(port and addr and size >= 1 and rank >= 0 and avlb)
def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcasts a tensor from a source to all other processes.
Args:
tensor (torch.Tensor): torch.Tensor to broadcast.
src (int, optional): Source rank. Defaults to 0.
Returns:
torch.Tensor: Broadcasted tensor.
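Example:
>>> # Minimal sketch (assumes an initialized process group): rank 0's values
>>> # end up on every rank after the broadcast.
>>> t = torch.arange(2.0) if dist.get_rank() == 0 else torch.zeros(2)
>>> t = _broadcast_tensor(t, src=0)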
"""
if dist.is_available() and dist.is_initialized():
device = tensor.device
if dist.get_backend() == "nccl":
tensor = tensor.to(get_device("cuda"))
dist.broadcast(tensor, src=src, group=None)
return tensor.to(device)
else:
return tensor
def get_distributed_backend(device_type: str, offload_ops_to_cpu: bool = False) -> str:
"""Gets the PyTorch Distributed backend based on device type.
Args:
device_type (str): Device type to get backend for.
offload_ops_to_cpu (bool, optional): Flag to check if any operations should be offloaded to CPU.
Examples of these kinds of operations are CPU offload for FSDP and asynchronous save for distributed
checkpointing. Defaults to False.
Example:
>>> get_distributed_backend("cuda")
'nccl'
>>> get_distributed_backend("cpu")
'gloo'
>>> get_distributed_backend("cuda", offload_ops_to_cpu=True)
'cuda:nccl,cpu:gloo'
Returns:
str: Distributed backend for use in ``torch.distributed.init_process_group``.
"""
default_device_backend_map = dist.Backend.default_device_backend_map
backend = "nccl"
if device_type in default_device_backend_map:
backend = default_device_backend_map[device_type]
if offload_ops_to_cpu:
backend = f"{device_type}:{backend},cpu:gloo"
return backend
@deprecated(
msg="The functionality of `init_distributed` is covered by `torch.distributed.init_process_group`. "
)
def init_distributed(**kwargs: Dict[str, Any]) -> bool:
"""Initialize process group required for ``torch.distributed``.
Args:
**kwargs (Dict[str, Any]): Additional arguments to pass to torch.distributed.init_process_group.
Returns:
bool: True if torch.distributed is initialized.
Raises:
RuntimeError: If torch.distributed is already initialized.
"""
if is_distributed():
if dist.is_initialized():
raise RuntimeError("torch.distributed already initialized.")
dist.init_process_group(**kwargs)
return True
else:
return False
def set_torch_num_threads() -> None:
"""
Sets the number of threads used by torch to utilize all physical CPU
cores for intra-op parallelism. Currently, this function sets num_threads
to be the number of physical CPU cores divided by the number of GPUs as we
use one process per GPU, and this avoids CPU oversubscription. Note that this is
currently a rough approximation that does not account for environments where
settings such as CPU affinity are in effect.
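Example:
>>> # Hypothetical host with 64 CPU cores (as reported by os.cpu_count()) and 8 GPUs:
>>> # this calls torch.set_num_threads(64 // 8), i.e. 8 threads per process.
>>> set_torch_num_threads()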
"""
num_threads = os.cpu_count() // (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
torch.set_num_threads(num_threads)
_log.info(f"Set intra op parallelism no. of threads to {num_threads}")
@deprecated(
msg="`get_world_size_and_rank` will move to `torchtune.utils._device` in future releases. "
"Please use `torchtune.utils.get_world_size_and_rank` instead."
)
def get_world_size_and_rank() -> Tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current process in the default process group.
Returns:
Tuple[int, int]: world size, rank
"""
if dist.is_available() and dist.is_initialized():
return torch.distributed.get_world_size(), torch.distributed.get_rank()
else:
return 1, 0
def validate_no_params_on_meta_device(model: nn.Module) -> None:
"""
Utility to validate that model has no params or buffers on meta device.
If a meta param or buffer is found, an error indicating the param name will
be raised.
Args:
model (nn.Module): model to check for meta params
Raises:
RuntimeError: If meta params or buffers exist in model
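Example:
>>> with torch.device("meta"):
...     model = nn.Linear(4, 4)
>>> validate_no_params_on_meta_device(model)  # raises RuntimeError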
"""
for n, p in chain(model.named_parameters(), model.named_buffers()):
if p.is_meta:
raise RuntimeError(f"Unexpected param or buffer {n} on meta device.")
def load_from_full_model_state_dict(
model: "FSDPModule", # noqa
full_sd: Dict[str, Any],
device: torch.device,
strict: bool = False,
cpu_offload: bool = False,
) -> _IncompatibleKeys:
"""
Converts a full state dict into a sharded state dict
and loads it into an FSDP-sharded model.
Args:
model (FSDPModule): FSDP-sharded model to load the full state dict into.
full_sd (Dict[str, Any]): a full state dict to load into the model
device (torch.device): device used to move full state dict tensors
strict (bool): whether to load the state dict in strict mode
cpu_offload (bool): whether to offload the loaded state dict tensors to CPU
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Raises:
NotImplementedError: If the FSDP sharding mesh has more than one dimension.
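Example:
>>> # A minimal sketch; assumes ``model`` has already been sharded (e.g. via ``shard_model``)
>>> # and ``full_sd`` holds a full (unsharded) checkpoint state dict.
>>> missing, unexpected = load_from_full_model_state_dict(
...     model, full_sd, device=torch.device("cuda"), strict=True
... )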
"""
# PyTorch nightly versions from December 20, 2024, support the following features:
# - `set_model_state_dict` with the `cpu_offload` option
# - Multiple devices in local state dict
# - Related optimizations for improved memory performance
# Please keep the version check `_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE` until these changes are
# released in the PyTorch stable version.
has_nf4 = any(
hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
for param in model.parameters()
)
meta_sharded_sd = model.state_dict()
# NF4Tensor is not yet supported by `set_model_state_dict`, so we fall back to the
# manual sharding logic below; once support lands, this fallback can be removed.
if _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE and not has_nf4:
for param_name in full_sd.keys():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_sd[param_name] = full_sd[param_name].to(sharded_meta_param.dtype)
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
strict=strict,
cpu_offload=cpu_offload,
)
return set_model_state_dict(
model=model, model_state_dict=full_sd, options=options
)
else:
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
full_tensor = to_nf4(
full_tensor,
block_size=block_size,
scaler_block_size=scaler_block_size,
)
# replicating logic from `_fsdp_param.py` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_`
# long-term solution is `swap_tensor`
mesh = sharded_meta_param.device_mesh
if mesh.ndim > 1:
raise NotImplementedError(
f"only support 1D FSDP but got {mesh.ndim=}"
)
shard_mesh_dim = 0
shard_world_size = mesh.size(shard_mesh_dim)
shard_rank = cast(
torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
).rank()
chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
shard_rank
]
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)
# TODO: change to from_local API (need to add view support for NF4)
sharded_tensor = DTensor(
local_tensor=sharded_param,
spec=DTensorSpec(
mesh=sharded_meta_param.device_mesh,
placements=sharded_meta_param.placements,
tensor_meta=TensorMeta(
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
stride=sharded_meta_param.stride(),
),
),
requires_grad=sharded_meta_param.requires_grad,
)
elif not hasattr(sharded_meta_param, "device_mesh"):
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
"""
Manually gather NF4Tensor parameter since it does not support all_gather
"""
mesh = sharded_param.device_mesh
nf4_tensor = sharded_param._local_tensor
quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
full_quant_params = []
for quant_param in quant_params:
d0, *dn = quant_param.shape
shape = (d0 * mesh.get_group().size(), *dn)
full_quant_param = torch.empty(
shape, device=quant_param.device, dtype=quant_param.dtype
)
dist.all_gather_into_tensor(
full_quant_param, quant_param, mesh.get_group(), async_op=False
)
full_quant_params.append(full_quant_param)
full_param, _ = nf4_tensor.fsdp_post_all_gather(
full_quant_params, metadata, nf4_tensor.dtype
)
return full_param
def gather_cpu_state_dict(
model: "FSDPModule", # noqa
is_rank_zero: bool,
device: Optional[torch.device] = None,
adapter_weights_only: bool = False,
) -> Dict[str, Any]:
"""
Converts a sharded state dict into a full state dict on CPU.
Returns a non-empty result only on rank 0 to avoid peak CPU memory usage across ranks.
Currently we can use the distributed state dict API to process models without NF4Tensor. Otherwise, we need to
manually gather any NF4 tensors until all-gather is supported in the NF4Tensor subclass.
TODO: add support for NF4Tensor in the distributed state dict API.
Args:
model (FSDPModule): Model to generate fully qualified names for cpu_state_dict
is_rank_zero (bool): flag to check if the process is on rank 0
device (Optional[torch.device]): device to use for sharded tensors. Default: None
adapter_weights_only (bool): flag to check if only trainable parameters should be returned. Default: False
Returns:
Dict[str, Any]: State dict on CPU
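Example:
>>> # Sketch of typical checkpointing use; the output path is illustrative.
>>> is_rank_zero = dist.get_rank() == 0
>>> cpu_sd = gather_cpu_state_dict(model, is_rank_zero, device=torch.device("cuda"))
>>> if is_rank_zero:
...     torch.save(cpu_sd, "model.pt")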
"""
# TODO: Disabling DSD as it has issues. Add back changes in #2138 once DSD issue is fixed.
cpu_state_dict = {}
sharded_sd = model.state_dict()
for param_name, param in sharded_sd.items():
if param.is_cpu:
# Move back to device if offloaded to CPU
param = param.to(device)
if hasattr(param, "_local_tensor"):
if isinstance(param._local_tensor, NF4Tensor):
param = _gather_nf4_tensor(param)
else:
# Gather DTensor
param = param.full_tensor()
if isinstance(param, NF4Tensor):
# upcasting NF4 to original dtype
param = param.to(param.dtype)
if is_rank_zero:
cpu_state_dict[param_name] = param.cpu()
torch.distributed.barrier()
if adapter_weights_only:
cpu_state_dict = get_adapter_state_dict(cpu_state_dict, device=None)
return cpu_state_dict
def get_full_optimizer_state_dict(
model: "FSDPModule", # noqa
opt: Optimizer,
is_rank_zero: bool,
device: Optional[torch.device] = None,
) -> Dict[str, Any]:
"""
Converts optimizer state from sharded to full.
For example, "exp_avg" in AdamW is a `DTensor`;
"exp_avg.full_tensor()" converts it to a plain tensor on rank 0.
Returns a non-empty CPU state dict only on rank 0.
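Example:
>>> # Sketch: gather the full optimizer state on rank 0 for checkpointing
>>> opt_sd = get_full_optimizer_state_dict(model, opt, is_rank_zero=dist.get_rank() == 0)
>>> # opt_sd is {} on all ranks except rank 0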
"""
options = StateDictOptions(
full_state_dict=True, broadcast_from_rank0=True, cpu_offload=True
)
full_state_dict = get_optimizer_state_dict(
model=model, optimizers=opt, options=options
)
if is_rank_zero:
return full_state_dict
else:
return {}
def load_from_full_optimizer_state_dict(
model: "FSDPModule", # noqa
opt: Optimizer,
full_sd: Dict[str, Any],
device: torch.device,
) -> None:
"""
Converts a full optimizer state dict into a sharded state dict
and loads it into the optimizer.
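Example:
>>> # Sketch: ``full_opt_sd`` is assumed to be a full optimizer state dict from a checkpoint
>>> load_from_full_optimizer_state_dict(
...     model, opt, full_sd=full_opt_sd, device=torch.device("cuda")
... )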
"""
if _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE:
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
cpu_offload=device == torch.device("cpu"),
)
set_optimizer_state_dict(
model=model, optimizers=opt, optim_state_dict=full_sd, options=options
)
else:
PARAMS = "params" # noqa: N806
_init_optim_state(opt)
param_groups = opt.state_dict()["param_groups"]
state = opt.state_dict()["state"]
full_param_groups = full_sd["param_groups"]
full_state = full_sd["state"]
for param_group, full_param_group in zip(param_groups, full_param_groups):
for key, value in full_param_group.items():
if key == PARAMS:
continue
param_group[key] = value
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
if pid not in state:
continue
param_state = state[pid]
full_param_state = full_state[full_pid]
for attr, full_tensor in full_param_state.items():
sharded_tensor = param_state[attr]
if isinstance(sharded_tensor, DTensor):
# exp_avg is DTensor
param_state[attr] = distribute_tensor(
full_tensor,
sharded_tensor.device_mesh,
sharded_tensor.placements,
)
else:
# step is plain tensor
param_state[attr] = full_tensor
opt.load_state_dict(
{
"param_groups": param_groups,
"state": state,
}
)
def get_shard_conditions(
name: str,
module: nn.Module,
names_to_match: Optional[List[str]] = None,
*args,
**kwargs,
) -> bool:
"""
Returns True for layers named ``{}.layers.i`` or layers whose name exactly matches an entry in ``names_to_match``; otherwise
returns False. This is a helper function for sharding a model with FSDP.
In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
and apply fully_shard using this condition.
As part of our sharding strategy, we want each layer to be sharded separately, as this is
generally efficient. We may also want to shard certain modules that are not layers, such as
the embedding module.
#TODO: a more robust way would be to shard on the module type, not the name.
Args:
name (str): Name of the module.
module (nn.Module): Module to be sharded.
names_to_match (Optional[List[str]]): List of names to match, if any.
*args: Additional positional arguments; accepted for compatibility but unused.
**kwargs: Additional keyword arguments; accepted for compatibility but unused.
Returns:
bool: True if the module name matches the condition, False otherwise.
Examples:
>>> names_to_match = ["embedding"]
>>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
"my_wrapper.layer.1.something", "embedding"]
>>> matches = []
>>> for name in layer_names:
>>> if get_shard_conditions(name, None, names_to_match=names_to_match): matches.append(name)
>>> print(matches)
["layers.0", "decoder.layers.1", "embedding"]
"""
if names_to_match and name in names_to_match:
return True
name_list = name.split(".")
if len(name_list) >= 2:
return name_list[-2] == "layers" and str.isdigit(name_list[-1])
return False
def shard_model(
model: TransformerDecoder,
shard_conditions: List[Callable[[str, nn.Module], bool]],
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
dp_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
This method iterates over the model's named modules from the bottom up and shards each module
that meets any of the criteria from shard_conditions.
Args:
model (TransformerDecoder): Model to shard with FSDP.
shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine
which modules to shard with FSDP. Each function should take module name (relative to root)
and the module itself, returning True if FSDP should shard the module and False otherwise.
If any of shard_conditions return True for a given module, it will be sharded by FSDP.
cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
states to CPU.
reshard_after_forward (bool): Whether to reshard parameters and buffers after
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding when combining multiple parallelisms.
Defaults to None.
Raises:
ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
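Example:
>>> # Sketch: shard every transformer layer plus the token embedding
>>> # (assumes the embedding module is named "tok_embeddings").
>>> from functools import partial
>>> shard_model(
...     model,
...     shard_conditions=[partial(get_shard_conditions, names_to_match=["tok_embeddings"])],
...     cpu_offload=False,
... )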
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward, "mesh": dp_mesh}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# Shard the model with FSDP, iterating in reverse to start with
# lowest-level modules first
num_layers_sharded = 0
for n, m in reversed(list(model.named_modules())):
if any([shard_condition(n, m) for shard_condition in shard_conditions]):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
if num_layers_sharded == 0:
raise ValueError(
"No layer modules were sharded. Please check if shard conditions are working as expected."
)
# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)
def prepare_mha_for_tp(
model: nn.Module,
tp_mesh: DeviceMesh,
) -> nn.Module:
"""
Utility to scale MultiHeadAttention parameters (num_heads, num_kv_heads, embed_dim) across
tensor parallel devices. Each device will handle a portion of the attention computations.
Args:
model (nn.Module): Model whose attention parameters will be scaled by TP size.
tp_mesh (DeviceMesh): Tensor parallel device mesh.
Returns:
nn.Module: The model with scaled MultiHeadAttention parameters.
Raises:
ValueError: If attention heads, kv heads, or embed dimension is not divisible by TP size.
Examples:
>>> from torchtune.modules import TransformerDecoder
>>> from torch.distributed.device_mesh import DeviceMesh
>>> model = TransformerDecoder(
num_heads=32,
num_kv_heads=32,
embed_dim=4096,
)
>>> tp_mesh = DeviceMesh("cuda", torch.arange(2)) # 2 GPUs
>>> model = prepare_mha_for_tp(model, tp_mesh)
>>> # Now each GPU has:
>>> # num_heads = 16 (32/2)
>>> # num_kv_heads = 16 (32/2)
>>> # embed_dim = 2048 (4096/2)
"""
# Handle fusion models by extracting decoder
is_fusion_model = isinstance(model, (DeepFusionModel, EarlyFusionModel))
decoder = model.decoder if is_fusion_model else model
tp_size = tp_mesh.size()
for m in list(decoder.modules()):
if isinstance(m, MultiHeadAttention):
# Adjust attention module to use the local number of heads
if m.num_heads % tp_size != 0:
raise ValueError(
f"Number of attention heads ({m.num_heads}) must be divisible by "
f"tensor parallel size ({tp_size})."
)
if m.num_kv_heads % tp_size != 0:
raise ValueError(
f"Number of KV heads ({m.num_kv_heads}) must be divisible by "
f"tensor parallel size ({tp_size})."
)
if m.embed_dim % tp_size != 0:
raise ValueError(
f"Embedding dimension ({m.embed_dim}) must be divisible by "
f"tensor parallel size ({tp_size})."
)
m.num_heads = m.num_heads // tp_size
m.num_kv_heads = m.num_kv_heads // tp_size
m.embed_dim = m.embed_dim // tp_size
if is_fusion_model:
model.decoder = decoder
return model