

This module implements methods for checkpointing and resuming training from a checkpoint.

class torchft.checkpointing.CheckpointTransport[source]

Bases: Generic[T], ABC

disallow_checkpoint() None[source]

Called after send_checkpoint to wait for the checkpoint to be sent.

Once this returns, the state_dict may be mutated, so no further data should be sent.

abstract metadata() str[source]

Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

abstract recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) T[source]

Receives the checkpoint from the given rank.

  • src_rank – the rank to receive the checkpoint from

  • metadata – the metadata returned by the remote CheckpointTransport

  • step – the step number to receive

  • timeout – the timeout to wait for the checkpoint

abstract send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) None[source]

Sends the checkpoint, only called when there is a rank that is behind.

This may be async.

  • dst_ranks – the ranks to send to

  • step – the step number to send

  • state_dict – the state dict to send

  • timeout – the timeout to wait for the checkpoint to be sent

shutdown(wait: bool = True) None[source]

Called to shut down the checkpoint transport.

  • wait – whether to wait for the transport to shut down
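To illustrate how the methods above fit together, here is a minimal sketch of a transport that keeps checkpoints in process memory. It is an assumption-laden illustration, not part of torchft: `CheckpointTransportSketch` is a local stand-in that only mirrors the documented signatures, and `InMemoryTransport` and `_STORE` are hypothetical names.

```python
import copy
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Dict, Generic, List, TypeVar

T = TypeVar("T")


class CheckpointTransportSketch(Generic[T], ABC):
    """Local stand-in mirroring the documented interface (not imported from torchft)."""

    @abstractmethod
    def metadata(self) -> str: ...

    @abstractmethod
    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None: ...

    def disallow_checkpoint(self) -> None:
        pass

    @abstractmethod
    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T: ...

    def shutdown(self, wait: bool = True) -> None:
        pass


# Hypothetical process-local registry standing in for a network.
_STORE: Dict[str, Dict[int, object]] = {}


class InMemoryTransport(CheckpointTransportSketch[T]):
    def __init__(self, name: str) -> None:
        self._name = name
        _STORE[name] = {}

    def metadata(self) -> str:
        # The receiver uses this string to locate the sender's checkpoints.
        return self._name

    def send_checkpoint(self, dst_ranks, step, state_dict, timeout) -> None:
        # Copy so that later mutation of state_dict does not affect receivers.
        _STORE[self._name][step] = copy.deepcopy(state_dict)

    def disallow_checkpoint(self) -> None:
        # After this returns, nothing more is served from this transport.
        _STORE[self._name].clear()

    def recv_checkpoint(self, src_rank, metadata, step, timeout):
        checkpoints = _STORE.get(metadata, {})
        if step not in checkpoints:
            raise LookupError(f"no checkpoint for step {step} from {metadata!r}")
        return checkpoints[step]

    def shutdown(self, wait: bool = True) -> None:
        _STORE.pop(self._name, None)
```

In this sketch the sender calls send_checkpoint, the receiver passes the sender's metadata() string to recv_checkpoint, and the sender then calls disallow_checkpoint before mutating its state again. A real transport would block in recv_checkpoint until the timeout expires rather than failing immediately.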

class torchft.checkpointing.HTTPTransport(timeout: timedelta, num_chunks: int)[source]

Bases: CheckpointTransport[T]

This is an HTTP server that can be used to transfer checkpoints between workers.

This allows for fast recovery of workers by fetching the current weights from an existing worker.

  • timeout – the timeout for HTTP requests

  • num_chunks – the number of chunks to split the checkpoint into (0 for no chunking)
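The reference above does not say how a checkpoint is divided into chunks. As one possible reading of num_chunks, the sketch below distributes the top-level keys of a state dict round-robin across chunks and treats 0 as "no chunking". The function names `split_into_chunks` and `merge_chunks` are hypothetical and the scheme is an assumption, not the library's implementation.

```python
from typing import Dict, List, TypeVar

V = TypeVar("V")


def split_into_chunks(state_dict: Dict[str, V], num_chunks: int) -> List[Dict[str, V]]:
    """Split a state dict by key; num_chunks == 0 returns a single chunk."""
    if num_chunks < 0:
        raise ValueError("num_chunks must be >= 0")
    if num_chunks == 0:
        return [dict(state_dict)]
    chunks: List[Dict[str, V]] = [{} for _ in range(num_chunks)]
    for i, (key, value) in enumerate(state_dict.items()):
        # Round-robin assignment keeps chunk sizes within one key of each other.
        chunks[i % num_chunks][key] = value
    return chunks


def merge_chunks(chunks: List[Dict[str, V]]) -> Dict[str, V]:
    """Reassemble the chunks into one dict (key order may differ from the input)."""
    merged: Dict[str, V] = {}
    for chunk in chunks:
        merged.update(chunk)
    return merged
```

Splitting by key lets a receiver fetch chunks in parallel, which is presumably the motivation for chunking; balancing chunks by tensor size rather than key count would be a natural refinement.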

address() str[source]

Returns the HTTP address to fetch a checkpoint from this server. Step must be appended to the end of the address.

Format: http://host:port/checkpoint/1234


Returns: an HTTP address
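Because the step must be appended to the returned address, a small helper can build the full URL. This is a sketch only: `checkpoint_url` is a hypothetical name, and it tolerates an address with or without a trailing slash since the exact form of the returned string is not stated above.

```python
def checkpoint_url(address: str, step: int) -> str:
    """Append the step number to a checkpoint address of the documented format."""
    if step < 0:
        raise ValueError("step must be non-negative")
    # Normalize the trailing slash so both address forms produce the same URL.
    return f"{address.rstrip('/')}/{step}"
```

For example, `checkpoint_url("http://host:port/checkpoint/", 1234)` yields the documented form `http://host:port/checkpoint/1234`.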

allow_checkpoint(step: int) None[source]

Allows serving the checkpoint with the specified step number.


  • step – the step number to serve

disallow_checkpoint() None[source]

Disallows serving the checkpoint.

All requests will block until allow_checkpoint is called.
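The allow/disallow behavior described above can be pictured as a gate that request handlers wait on. The sketch below uses a `threading.Condition` to show the idea; `StepGate` and its method names are hypothetical, and this is not how HTTPTransport is actually implemented. It also assumes that a request for a step other than the allowed one keeps waiting, which the text above does not specify.

```python
import threading
from datetime import timedelta
from typing import Optional


class StepGate:
    """Blocks waiters until a matching step is allowed (illustrative only)."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._allowed_step: Optional[int] = None

    def allow(self, step: int) -> None:
        with self._cond:
            self._allowed_step = step
            # Wake every blocked request so it can re-check the step.
            self._cond.notify_all()

    def disallow(self) -> None:
        with self._cond:
            self._allowed_step = None

    def wait_until_allowed(self, step: int, timeout: timedelta) -> bool:
        """Return True if the step became allowed before the timeout expired."""
        with self._cond:
            return self._cond.wait_for(
                lambda: self._allowed_step == step,
                timeout=timeout.total_seconds(),
            )
```

A request handler would call `wait_until_allowed` before serving data and report an error if it returns False.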

metadata() str[source]

Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) T[source]

Receives the checkpoint from the given rank.

  • src_rank – the rank to receive the checkpoint from

  • metadata – the metadata returned by the remote CheckpointTransport

  • step – the step number to receive

  • timeout – the timeout to wait for the checkpoint

send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) None[source]

Sends the checkpoint, only called when there is a rank that is behind.

This may be async.

  • dst_ranks – the ranks to send to

  • step – the step number to send

  • state_dict – the state dict to send

  • timeout – the timeout to wait for the checkpoint to be sent

shutdown(wait: bool = True) None[source]

Shut down the server.

