Checkpointing
This module implements methods for checkpointing and resuming training from a checkpoint.
- class torchft.checkpointing.CheckpointTransport
Bases: Generic[T], ABC
- disallow_checkpoint() -> None
Called after send_checkpoint to wait until the checkpoint has been sent.
Once this returns, the caller may mutate the state_dict, so the transport must not send any further data.
- abstract metadata() -> str
Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
- abstract recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) -> T
Receives the checkpoint from the given rank.
- Parameters:
src_rank – the rank to receive the checkpoint from
metadata – the metadata returned by the remote CheckpointTransport
step – the step number to receive
timeout – the timeout to wait for the checkpoint
- abstract send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) -> None
Sends the checkpoint. This is only called when at least one rank is behind.
The send may be asynchronous.
- Parameters:
dst_ranks – the ranks to send to
step – the step number to send
state_dict – the state dict to send
timeout – the timeout to wait for the checkpoint to be sent
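As a rough illustration of how the four methods above fit together, here is a minimal in-memory sketch. It does not import torchft: `CheckpointTransportSketch` is a local stand-in that only mirrors the documented signatures, and `InMemoryTransport` and the module-level `_STORE` dictionary are hypothetical names invented for this example. A real transport would move data over a network rather than through a shared dictionary.

```python
import copy
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Dict, Generic, List, TypeVar

T = TypeVar("T")


class CheckpointTransportSketch(Generic[T], ABC):
    """Local stand-in mirroring the documented interface (not torchft itself)."""

    @abstractmethod
    def metadata(self) -> str: ...

    @abstractmethod
    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None: ...

    def disallow_checkpoint(self) -> None:
        """Non-abstract hook; the default does nothing."""

    @abstractmethod
    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T: ...


# Hypothetical shared store standing in for the network: name -> {step: state}.
_STORE: Dict[str, Dict[int, object]] = {}


class InMemoryTransport(CheckpointTransportSketch[T]):
    def __init__(self, name: str) -> None:
        self._name = name
        _STORE.setdefault(name, {})

    def metadata(self) -> str:
        # The receiver uses this string to locate the sender's checkpoints.
        return self._name

    def send_checkpoint(self, dst_ranks, step, state_dict, timeout) -> None:
        # Copy so that later mutation of state_dict cannot leak to receivers.
        _STORE[self._name][step] = copy.deepcopy(state_dict)

    def disallow_checkpoint(self) -> None:
        # Stop serving: after this, nothing more is handed out by this sender.
        _STORE[self._name].clear()

    def recv_checkpoint(self, src_rank, metadata, step, timeout):
        # Raises KeyError if the sender no longer serves this step.
        return copy.deepcopy(_STORE[metadata][step])


sender = InMemoryTransport("rank0")
receiver = InMemoryTransport("rank1")
state = {"weights": [1.0, 2.0]}

sender.send_checkpoint([1], step=5, state_dict=state, timeout=timedelta(seconds=10))
recovered = receiver.recv_checkpoint(
    0, sender.metadata(), step=5, timeout=timedelta(seconds=10)
)
```

The receiver never talks to the sender object directly; it only needs the sender's `metadata()` string and the step number, which is the contract the abstract class describes.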
- class torchft.checkpointing.HTTPTransport(timeout: timedelta, num_chunks: int)
Bases: CheckpointTransport[T]
This is an HTTP server that can be used to transfer checkpoints between workers.
This allows for fast recovery of workers by fetching the current weights from an existing worker.
- Parameters:
timeout – the timeout for HTTP requests
num_chunks – the number of chunks to split the checkpoint into (0 for no chunking)
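To make the `num_chunks` parameter more concrete, the sketch below shows one plausible way to split a flat state dict into chunks and merge it back, with 0 meaning "no chunking". This is an assumption made for illustration only: the page does not say how HTTPTransport actually partitions a checkpoint, and `split_into_chunks` and `merge_chunks` are hypothetical helpers, not part of torchft.

```python
from typing import Dict, List, TypeVar

V = TypeVar("V")


def split_into_chunks(state_dict: Dict[str, V], num_chunks: int) -> List[Dict[str, V]]:
    """Distribute keys round-robin over num_chunks dicts; 0 returns one unsplit copy."""
    if num_chunks < 0:
        raise ValueError("num_chunks must be >= 0")
    if num_chunks == 0:
        return [dict(state_dict)]
    chunks: List[Dict[str, V]] = [{} for _ in range(num_chunks)]
    for i, (key, value) in enumerate(state_dict.items()):
        chunks[i % num_chunks][key] = value
    return chunks


def merge_chunks(chunks: List[Dict[str, V]]) -> Dict[str, V]:
    """Reassemble the chunks into a single dict on the receiving side."""
    merged: Dict[str, V] = {}
    for chunk in chunks:
        merged.update(chunk)
    return merged


state = {"a": 1, "b": 2, "c": 3}
parts = split_into_chunks(state, num_chunks=2)
restored = merge_chunks(parts)
```

Splitting lets each chunk be fetched in a separate request, which is the usual motivation for chunked transfer; whether that is how this class uses `num_chunks` should be checked against the source.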
- address() -> str
Returns the HTTP address to fetch a checkpoint from this server. The step number must be appended to the end of the address.
Format: http://host:port/checkpoint/1234, where 1234 is the step number.
- Returns:
an HTTP address
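The rule about appending the step can be sketched as a small helper. It assumes the value returned by `address()` has the form `http://host:port/checkpoint/`, inferred from the format shown above; `checkpoint_url` is a hypothetical name, not a torchft function.

```python
def checkpoint_url(base_address: str, step: int) -> str:
    """Append the step number to a base address, tolerating a trailing slash."""
    if step < 0:
        raise ValueError("step must be non-negative")
    return f"{base_address.rstrip('/')}/{step}"


url = checkpoint_url("http://host:8080/checkpoint/", 1234)
```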
- allow_checkpoint(step: int) -> None
Allows serving the checkpoint with the specified step number.
- Parameters:
step – the step number to serve
- disallow_checkpoint() -> None
Disallows serving the checkpoint.
All requests will block until allow_checkpoint is called.
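The allow/disallow behaviour described here, where requests block until a step is allowed, can be illustrated with a condition variable. This is a sketch of the general pattern only, not HTTPTransport's actual implementation; `CheckpointGate` and its method names are hypothetical.

```python
import threading
from typing import List, Optional


class CheckpointGate:
    """Blocks readers until a step has been allowed."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._allowed_step: Optional[int] = None

    def allow(self, step: int) -> None:
        with self._cond:
            self._allowed_step = step
            self._cond.notify_all()  # wake every blocked request

    def disallow(self) -> None:
        with self._cond:
            self._allowed_step = None  # later requests block again

    def wait_allowed(self, timeout: float) -> Optional[int]:
        """Return the allowed step, or None if the timeout expires first."""
        with self._cond:
            ok = self._cond.wait_for(
                lambda: self._allowed_step is not None, timeout=timeout
            )
            return self._allowed_step if ok else None


gate = CheckpointGate()
results: List[Optional[int]] = []

# A simulated request that blocks until the checkpoint is allowed.
reader = threading.Thread(target=lambda: results.append(gate.wait_allowed(timeout=5.0)))
reader.start()
gate.allow(step=7)
reader.join()

gate.disallow()
timed_out = gate.wait_allowed(timeout=0.01)
```

Blocking instead of rejecting means a recovering worker that asks slightly too early simply waits for the next allowed step, bounded by its own timeout.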
- metadata() -> str
Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
- recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) -> T
Receives the checkpoint from the given rank.
- Parameters:
src_rank – the rank to receive the checkpoint from
metadata – the metadata returned by the remote CheckpointTransport
step – the step number to receive
timeout – the timeout to wait for the checkpoint
- send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) -> None
Sends the checkpoint. This is only called when at least one rank is behind.
The send may be asynchronous.
- Parameters:
dst_ranks – the ranks to send to
step – the step number to send
state_dict – the state dict to send
timeout – the timeout to wait for the checkpoint to be sent