Checkpointing
This module implements methods for checkpointing and resuming training from a checkpoint.
- class torchft.checkpointing.CheckpointServer(state_dict: Callable[[], T])
Bases: Generic[T]
This is an HTTP server that can be used to transfer checkpoints between workers.
This allows for fast recovery of workers by fetching the current weights from an existing worker.
- Parameters:
state_dict – a callable that returns the state dict to be transferred
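The reason for passing a callable rather than a fixed state dict can be illustrated with a toy server built only on the Python standard library. This is a minimal sketch and not torchft's implementation: `make_toy_server` and `fetch` are hypothetical helpers, and the use of `pickle` as the wire format is an assumption made purely for illustration (unpickling data from untrusted peers is unsafe).

```python
import http.client
import pickle
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Callable, Dict


def make_toy_server(state_dict: Callable[[], Dict[str, float]]) -> ThreadingHTTPServer:
    """Hypothetical stand-in: answers every GET with pickle.dumps(state_dict())."""

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:
            # The callable runs on each request, so the response reflects the
            # state at fetch time, not at server construction time.
            body = pickle.dumps(state_dict())
            self.send_response(200)
            self.send_header("Content-Type", "application/octet-stream")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, format: str, *args: object) -> None:
            pass  # silence per-request logging in this example

    # Port 0 asks the OS for any free port.
    server = ThreadingHTTPServer(("127.0.0.1", 0), Handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def fetch(server: ThreadingHTTPServer) -> Dict[str, float]:
    """Fetch and unpickle the state served by the toy server."""
    host, port = server.server_address[:2]
    conn = http.client.HTTPConnection(host, port, timeout=5)
    try:
        conn.request("GET", "/checkpoint/0")
        return pickle.loads(conn.getresponse().read())
    finally:
        conn.close()


weights = {"w": 1.0}
server = make_toy_server(lambda: weights)
first = fetch(server)
weights["w"] = 2.0  # simulate a training step updating the weights
second = fetch(server)
server.shutdown()
server.server_close()
```

Because the callable is evaluated per request, the second fetch sees the updated weights without the server being rebuilt.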
- address() -> str
Returns the HTTP address to fetch a checkpoint from this server at the current step.
Format: http://host:port/checkpoint/1234
- Returns:
an HTTP address
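A worker that receives such an address may want to recover the host, port, and step from it. The sketch below parses the documented format with the standard library; `parse_checkpoint_address` is a hypothetical helper, not part of torchft, and the host and port in the usage line are made-up values.

```python
import re
from typing import Tuple
from urllib.parse import urlsplit


def parse_checkpoint_address(address: str) -> Tuple[str, int, int]:
    """Split http://host:port/checkpoint/<step> into (host, port, step)."""
    parts = urlsplit(address)
    match = re.fullmatch(r"/checkpoint/([0-9]+)", parts.path)
    if parts.scheme != "http" or match is None:
        raise ValueError(f"unexpected checkpoint address: {address!r}")
    if parts.hostname is None or parts.port is None:
        raise ValueError(f"address must include host and port: {address!r}")
    return parts.hostname, parts.port, int(match.group(1))


# Illustrative values only; a real address comes from CheckpointServer.address().
host, port, step = parse_checkpoint_address("http://localhost:8080/checkpoint/1234")
```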
- allow_checkpoint(step: int) -> None
Allows serving the checkpoint with the specified step number.
- Parameters:
step – the step number to serve
- disallow_checkpoint() -> None
Disallows serving the checkpoint.
All requests will block until allow_checkpoint is called.
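The blocking behavior described for allow_checkpoint and disallow_checkpoint can be modeled with a condition variable. This is a sketch of the semantics only, under the assumption that a request simply waits until some step is allowed; `CheckpointGate` and `wait_until_allowed` are hypothetical names, and how the real server treats a request for a step other than the allowed one is not covered here.

```python
import threading
from typing import List, Optional


class CheckpointGate:
    """Hypothetical helper mimicking the allow/disallow semantics."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._allowed_step: Optional[int] = None  # None means serving is disallowed

    def allow_checkpoint(self, step: int) -> None:
        with self._cond:
            self._allowed_step = step
            self._cond.notify_all()  # wake every blocked request

    def disallow_checkpoint(self) -> None:
        with self._cond:
            self._allowed_step = None

    def wait_until_allowed(self, timeout: Optional[float] = None) -> Optional[int]:
        """Block until a step is allowed; return it, or None on timeout."""
        with self._cond:
            self._cond.wait_for(lambda: self._allowed_step is not None, timeout)
            return self._allowed_step


gate = CheckpointGate()
blocked = gate.wait_until_allowed(timeout=0.05)  # nothing allowed yet, times out

results: List[Optional[int]] = []
reader = threading.Thread(
    target=lambda: results.append(gate.wait_until_allowed(timeout=5))
)
reader.start()
gate.allow_checkpoint(7)  # releases the waiting reader
reader.join()

gate.disallow_checkpoint()
after = gate.wait_until_allowed(timeout=0.05)  # blocked again, times out
```

One plausible reason for such a gate is to keep peers from reading weights while the optimizer is updating them, with serving allowed only between steps.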