
Reducing Storage Footprint and Bandwidth Usage for Distributed Checkpoints with PyTorch DCP

Summary

PyTorch Distributed Checkpointing (DCP) is a versatile and powerful tool for managing model checkpoints in distributed training environments. Its modular design empowers developers to tailor its components to their specific requirements, making it an ideal solution for a wide range of use cases.

In this blog post, we’ll showcase how we leveraged PyTorch DCP’s modularity to integrate compression and achieve a 22% reduction in checkpoint size. We’ll also provide a deep dive into the implementation details of our customization, offering practical insights and guidance on how you can apply similar techniques to optimize your own checkpointing workflows and improve overall efficiency.

Motivation

Large Distributed Checkpoints

As models increase in complexity and size, distributed checkpointing becomes a critical component of the training process. However, these checkpoints often result in substantial storage demands and elevated bandwidth costs due to their large sizes.

Compression

To address this challenge, compression emerges as a natural solution. Given that checkpoints primarily consist of binary data (tensors), we aimed for an optimal compression ratio with minimal compression overhead. We chose the zstd compression algorithm for its efficiency and effectiveness.
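
Before diving into the DCP integration, here is a minimal standalone sketch of the idea: serialize some tensors, compress the resulting bytes with zstd, and verify the round trip. It assumes the third-party zstandard Python package, and the ratio you observe will depend heavily on the actual checkpoint contents.

# Minimal sketch (separate from the DCP integration described below) of zstd
# compression over serialized tensor bytes. Assumes the `zstandard` package.
import io

import torch
import torch.nn as nn
import zstandard

buffer = io.BytesIO()
torch.save(nn.Linear(1024, 1024).state_dict(), buffer)  # checkpoint-like bytes
raw = buffer.getvalue()

compressed = zstandard.ZstdCompressor(level=3).compress(raw)
restored = zstandard.ZstdDecompressor().decompress(compressed)

assert restored == raw  # lossless round trip
print(f"compression ratio: {len(raw) / len(compressed):.2f}x")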

DCP

The modular design of DCP, featuring well-defined and easily extensible components, made it an ideal choice as our checkpointing solution.

Details

Customizing StorageWriter

PyTorch DCP’s StorageWriter component is responsible for writing checkpoint data to storage. We customized this component by modifying _FileSystemWriter, which extends the base StorageWriter class. The _FileSystemWriter class now takes an additional parameter, _extensions, an optional sequence of StreamTransformExtension instances.

def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    # We pass our extended _FileSystemWriter as the storage_writer component
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    no_dist: bool = False,
) -> Metadata:

class _FileSystemWriter(StorageWriter):

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
        # We customized _FileSystemWriter to accept a sequence of extensions
        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
        *args: Any,
        **kwargs: Any,
    ) -> None:

StreamTransformExtension is an abstract class that defines two methods: transform_to(), which is called on an output stream, and transform_from(), which is called on an input stream. These enable us to perform custom transformations on the stream data.

class StreamTransformExtension(Extension):

    @abc.abstractmethod
    def transform_to(self, output: IO[bytes]) -> IO[bytes]:

    @abc.abstractmethod
    def transform_from(self, input: IO[bytes]) -> IO[bytes]:

Implementing ZStandard Compression

We implemented a concrete subclass of StreamTransformExtension called ZStandard, which provides compression functionality using the zstd compression algorithm. Our ZStandard class implements transform_to() to compress outgoing stream data and transform_from() to decompress incoming stream data.

class ZStandard(StreamTransformExtension):

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        # Our compression implementation
        ...

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        # Our decompression implementation
        ...
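
For illustration, here is a rough sketch of what such an implementation can look like, built on the zstandard package’s streaming APIs. The class name and compression level are placeholders, the import path for StreamTransformExtension may differ across PyTorch versions, and the descriptor/registration plumbing required by the Extension base class is omitted; this is not the upstream implementation.

# Illustrative sketch only. Assumes the `zstandard` package; the
# StreamTransformExtension import path may vary by PyTorch version, and the
# Extension registration methods are omitted.
from typing import IO

import zstandard

from torch.distributed.checkpoint._extension import StreamTransformExtension


class ZStandardSketch(StreamTransformExtension):
    def __init__(self, level: int = 3) -> None:
        self.level = level

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        # Wrap the output stream so bytes written are zstd-compressed before
        # reaching the underlying checkpoint file. The caller must flush/close
        # the returned writer so the zstd frame is finalized.
        return zstandard.ZstdCompressor(level=self.level).stream_writer(output)

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        # Wrap the input stream so reads transparently decompress the data.
        return zstandard.ZstdDecompressor().stream_reader(input)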

Combining Customizations

Finally, we combined our custom _FileSystemWriter class with the ZStandard compression extension when saving the checkpoint. We wrote a sample test to demonstrate how everything comes together:

fs_writer = FileSystemWriter(
    path=path,
    thread_count=thread_count,
    _extensions=[ZStandard()],
)

save(
    state_dict=state_dict_to_save,
    storage_writer=fs_writer,
)
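
Reading the checkpoint back goes through the usual DCP load path. The sketch below assumes that the information about which extension was applied is recoverable on the read side (for example from the checkpoint metadata); depending on your PyTorch version the reader may need additional configuration, so consult the DCP documentation. Here nn.Linear is just a stand-in for the real model, and path is the directory used when saving above.

# Sketch of loading the (compressed) checkpoint back in place.
import torch.nn as nn
from torch.distributed.checkpoint import FileSystemReader, load

model = nn.Linear(4, 4)  # stand-in for the real model
state_dict_to_load = {"model": model.state_dict()}  # must mirror what was saved

load(
    state_dict=state_dict_to_load,
    storage_reader=FileSystemReader(path),
)
model.load_state_dict(state_dict_to_load["model"])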

Evaluation

Results

In collaboration with IBM, we conducted an evaluation of our proposed solution on one of their internal training clusters. The results showed a significant 22% reduction in checkpoint sizes, albeit at the cost of increased compression time. However, with multi-threading, we were able to mitigate this trade-off and limit the increase in checkpointing time to just 9%. This demonstrates the potential of our solution to strike a balance between checkpoint size reduction and performance.

| Model | Threads per Rank | Checkpoint Size, Baseline (GB) | Checkpoint Size, ZStd (GB) | Size Δ | Checkpointing Time, Baseline (s) | Checkpointing Time, ZStd (s) | Time Δ |
|---|---|---|---|---|---|---|---|
| granite-3b-code-instruct | 8 | 6.72 | 5.26 | -21.8% | 1.96 | 2.15 | +9.7% |
| granite-3b-code-instruct | 4 | 6.72 | 5.26 | -21.8% | 1.98 | 2.38 | +20.2% |
| granite-3b-code-instruct | 1 | 6.72 | 5.26 | -21.8% | 2.34 | 3.86 | +64.9% |
| granite-3.2-8b-instruct | 8 | 15.6 | 12.08 | -22.5% | 3.37 | 3.65 | +8.3% |
| granite-3.2-8b-instruct | 4 | 15.6 | 12.08 | -22.5% | 3.72 | 4.37 | +17.5% |
| granite-3.2-8b-instruct | 1 | 15.6 | 12.08 | -22.5% | 5.37 | 8.45 | +57.4% |

Setup

We chose two of IBM’s open-source models (Granite-3B-Code-Instruct-128K and Granite-3.2-8B-Instruct). For evaluation, we performed full-parameter FSDP fine-tuning of these models on the Alpaca dataset using IBM’s Vela AI supercomputer, which is hosted in IBM Cloud. Each Vela node has eight 80GB A100 GPUs connected to each other by NVLink and NVSwitch. In addition, each node has two 2nd Generation Intel Xeon Scalable processors (Cascade Lake) and 1.5TB of DRAM. We provisioned one Vela node with the following resources:

Testbed

  • OpenShift 4.14 cluster
  • Pod: 64 Intel Cascade Lake CPU cores, 800GB host memory, 8 x A100-80GB GPUs
  • Storage options exposed as persistent volumes:
    • 1TB local GPFS
    • S3 bucket

Workload

  • Full-parameter FSDP finetuning with checkpointing every epoch

Checkpointing configuration

  • save_state_dict() to storage (a minimal sketch of this setup follows the list below)
  • 1 to 8 threads per rank
  • 1 file per rank
  • 8 ranks
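
Putting the workload and checkpointing configuration together, a minimal sketch of the per-epoch checkpointing loop might look like the following. The FSDP wrapping and the actual fine-tuning step are elided, nn.Linear stands in for the Granite model, the checkpoint path is a placeholder, and ZStandard is the compression extension described earlier (its import path depends on your PyTorch version).

# Sketch of the per-epoch checkpointing configuration above; training is elided.
import torch.nn as nn
from torch.distributed.checkpoint import FileSystemWriter, save

model = nn.Linear(8, 8)  # stand-in for the FSDP-wrapped model
num_epochs = 3           # placeholder

for epoch in range(num_epochs):
    # ... one epoch of full-parameter FSDP fine-tuning ...

    fs_writer = FileSystemWriter(
        path=f"/checkpoints/epoch-{epoch}",  # placeholder checkpoint directory
        single_file_per_rank=True,           # 1 file per rank
        thread_count=8,                      # we swept 1 to 8 threads per rank
        _extensions=[ZStandard()],           # zstd compression extension
    )
    save(
        state_dict={"model": model.state_dict()},
        storage_writer=fs_writer,
    )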

Conclusion

PyTorch DCP’s modular design empowers developers to tailor its components to specific use cases, unlocking new levels of customization and extensibility. By customizing the StorageWriter component and implementing a compression extension, we achieved significant checkpoint size reductions, leading to lower storage requirements and reduced bandwidth costs.

We invite you to explore the vast possibilities of PyTorch DCP customization by diving into our documentation and experimenting with various extensions and modifications. Join the conversation on PyTorch GitHub and connect with the PyTorch Checkpointing team (open a GitHub issue with the label “oncall: distributed checkpointing”) to share your experiences, ask questions, and stay up to date on the latest developments!