# Source code for torch.distributed.checkpoint.staging
from typing import Optional, runtime_checkable
from typing_extensions import Protocol
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
__all__ = ["AsyncStager", "BlockingAsyncStager"]
@runtime_checkable
class AsyncStager(Protocol):
    """
    This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users
    to customize how data is staged previous to executing the usual dcp.save path in parallel.
    The expected order of operations (concretely defined in `torch.distributed.state_dict_saver.async_save`)
    is the following:

    1. AsyncStager.stage(state_dict):
        This call gives the AsyncStager the opportunity to 'stage'
        the state_dict. The expectation and purpose of staging in this context is to create a "training-safe"
        representation of the state dict, meaning that any updates to module data after staging is complete
        should not be reflected in the state dict returned from this method. For example, in the default
        case a copy of the entire state dict is created on CPU RAM and returned here, allowing users
        to continue training without risking changes to data which is being serialized.

    2. dcp.save is called on the state_dict returned from stage in parallel. This call is responsible
        for serializing the state_dict and writing it to storage.

    3. If AsyncStager.should_synchronize_after_execute is True, this method will be called immediately after
        the serialization thread starts and before returning from dcp.async_save. If this is set to False,
        the assumption is the user has defined a custom synchronization point for the purpose of further
        optimizing save latency in the training loop (for example, by overlapping staging with the
        forward/backward pass), and it is the responsibility of the user to call `AsyncStager.synchronize_staging`
        at the appropriate time.
    """

    # default to True since the common case is to stage synchronously
    _synchronize_after_execute: bool = True

    @property
    def should_synchronize_after_execute(self) -> bool:
        """
        Whether to synchronize after executing the stage.
        """
        return self._synchronize_after_execute

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        """
        Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is
        inoculated from any updates incurred after the stage call is complete.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement stage method"
        )

    def synchronize_staging(self) -> None:
        """
        In the case `stage` is async in some way, this method should be called to ensure staging
        is complete and it is safe to begin modifying the original `state_dict`
        """
class BlockingAsyncStager(AsyncStager):
    """
    An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete.
    This implementation also provides an option to optimize stage latency using pinned memory.

    N.B. synchronize_staging is a no-op in this case.
    """

    # Default to False: staging here blocks until the CPU copy is complete, so
    # there is nothing left to synchronize once dcp.async_save returns.
    # (The protocol-level default is True; this subclass overrides it.)
    _synchronize_after_execute: bool = False

    def __init__(
        self,
        cache_staged_state_dict: bool = False,
        type_check: bool = False,
    ):
        """
        Initializes the BlockingAsyncStager.

        Args:
            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
                at the cost of increased memory usage. Additionally, if this parameter is set to True, it's the expectation
                that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False.
            type_check: Whether to perform a type check during cpu_offload. Defaults to False.
        """
        self.cache_staged_state_dict = cache_staged_state_dict
        self.type_check = type_check
        # Lazily allocated pinned-memory copy of the state_dict; populated on
        # the first stage() call when cache_staged_state_dict is True.
        self.state_dict_cache: Optional[STATE_DICT_TYPE] = None

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        """
        Returns a copy of `state_dict` on the CPU.
        """
        if not self.cache_staged_state_dict:
            staged_state_dict = _create_cpu_state_dict(state_dict)
            _copy_state_dict(state_dict, staged_state_dict, type_check=self.type_check)
            return staged_state_dict

        if self.state_dict_cache is None:
            self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True)
        # Fix: honor the configured type_check on the cached path as well, for
        # consistency with the uncached path above (previously the library
        # default was always used here, silently ignoring self.type_check).
        return _copy_state_dict(
            state_dict, self.state_dict_cache, type_check=self.type_check
        )