
Checkpointing

This module implements methods for checkpointing and resuming training from a checkpoint.

class torchft.checkpointing.CheckpointTransport

Bases: Generic[T], ABC

disallow_checkpoint() → None

Called after send_checkpoint to wait for the checkpoint to be sent.

Once this returns, the state_dict may be mutated, so no further data should be sent.
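The contract between send_checkpoint and disallow_checkpoint can be sketched with a background thread. This is a minimal illustration, not torchft code: `AsyncSender`, `delivered`, and `last_step` are hypothetical names, the timeout parameter is omitted, and the sleep only stands in for network latency. send_checkpoint returns immediately while the copy happens on another thread; disallow_checkpoint joins that thread, so once it returns the caller can safely mutate its state dict.

```python
import threading
import time
from typing import Dict, List, Optional

StateDict = Dict[str, List[float]]


class AsyncSender:
    """Hypothetical sender mirroring the send_checkpoint / disallow_checkpoint contract."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self.last_step: Optional[int] = None
        self.delivered: Dict[int, StateDict] = {}  # dst rank -> received copy

    def send_checkpoint(self, dst_ranks: List[int], step: int, state_dict: StateDict) -> None:
        self.last_step = step

        def _send() -> None:
            time.sleep(0.05)  # stand-in for network latency
            for rank in dst_ranks:
                # Reads state_dict while the caller's code may still be running.
                self.delivered[rank] = {k: list(v) for k, v in state_dict.items()}

        self._thread = threading.Thread(target=_send)
        self._thread.start()  # return immediately: the send is asynchronous

    def disallow_checkpoint(self) -> None:
        # Block until the in-flight send finishes; afterwards state_dict may be mutated.
        if self._thread is not None:
            self._thread.join()
            self._thread = None


sender = AsyncSender()
weights = {"w": [1.0, 2.0]}
sender.send_checkpoint([1], step=10, state_dict=weights)
sender.disallow_checkpoint()
weights["w"][0] = 99.0  # safe: the send has completed
print(sender.delivered[1])  # {'w': [1.0, 2.0]}
```

Mutating `weights` before disallow_checkpoint returns would race with the background copy, which is exactly what the documented ordering rules out.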

abstract metadata() → str

Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

abstract recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) → T

Receives the checkpoint from the given rank.

Parameters:
  • src_rank – the rank to receive the checkpoint from

  • metadata – the metadata returned by the remote CheckpointTransport

  • step – the step number to receive

  • timeout – the timeout to wait for the checkpoint

abstract send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) → None

Sends the checkpoint. This is only called when there is a rank that is behind.

This may be asynchronous.

Parameters:
  • dst_ranks – the ranks to send to

  • step – the step number to send

  • state_dict – the state dict to send

  • timeout – the timeout to wait for the checkpoint to be sent

shutdown(wait: bool = True) → None

Called to shut down the checkpoint transport.

Parameters:
  • wait – whether to wait for the transport to shut down
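To show how the three abstract methods fit together, here is a minimal sketch of a hypothetical in-process transport. So that it runs without torchft installed, it defines a local stand-in `CheckpointTransport` that only mirrors the signatures documented above; with torchft available you would presumably subclass torchft.checkpointing.CheckpointTransport instead. `InMemoryTransport` and `_MAILBOX` are illustrative names, and the metadata string is simply the sender's name.

```python
import copy
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Dict, Generic, List, Tuple, TypeVar

T = TypeVar("T")


class CheckpointTransport(Generic[T], ABC):
    """Local stand-in mirroring the documented interface (not imported from torchft)."""

    @abstractmethod
    def metadata(self) -> str: ...

    @abstractmethod
    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None: ...

    @abstractmethod
    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T: ...

    def disallow_checkpoint(self) -> None:
        pass

    def shutdown(self, wait: bool = True) -> None:
        pass


# Hypothetical process-wide mailbox: (sender metadata, step) -> published state dict.
_MAILBOX: Dict[Tuple[str, int], Any] = {}


class InMemoryTransport(CheckpointTransport[T]):
    """Hypothetical transport for replicas that live in the same Python process."""

    def __init__(self, name: str) -> None:
        self._name = name

    def metadata(self) -> str:
        # Peers pass this string back into recv_checkpoint to find our checkpoints.
        return self._name

    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        # Publish a reference; it is only read until disallow_checkpoint() returns.
        # dst_ranks and timeout are unused because "delivery" is a dict write.
        _MAILBOX[(self._name, step)] = state_dict

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        # src_rank is unused: the metadata string already identifies the sender.
        key = (metadata, step)
        if key not in _MAILBOX:
            # A real transport would wait up to `timeout`; this sketch fails immediately.
            raise TimeoutError(f"no checkpoint for step {step} from {metadata!r}")
        # Deep-copy so later mutations by the sender cannot leak into the receiver.
        return copy.deepcopy(_MAILBOX[key])

    def disallow_checkpoint(self) -> None:
        # Withdraw everything we published; the caller may now mutate its state dict.
        for key in [k for k in _MAILBOX if k[0] == self._name]:
            del _MAILBOX[key]


sender = InMemoryTransport("replica0")
receiver = InMemoryTransport("replica1")
state = {"step": 5, "weights": [0.1, 0.2]}
sender.send_checkpoint([1], step=5, state_dict=state, timeout=timedelta(seconds=10))
restored = receiver.recv_checkpoint(0, sender.metadata(), step=5, timeout=timedelta(seconds=10))
sender.disallow_checkpoint()
print(restored)  # {'step': 5, 'weights': [0.1, 0.2]}
```

Deep-copying on receive and withdrawing the published entries in disallow_checkpoint keep the sender's state dict from being read after disallow_checkpoint returns, which is the guarantee the base class describes.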

class torchft.checkpointing.HTTPTransport(timeout: timedelta, num_chunks: int)

Bases: CheckpointTransport[T]

This is an HTTP server that can be used to transfer checkpoints between workers.

This allows for fast recovery of workers by fetching the current weights from an existing worker.

Parameters:
  • timeout – the timeout for HTTP requests

  • num_chunks – the number of chunks to split the checkpoint into (0 for no chunking)
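The documentation does not say how HTTPTransport divides a checkpoint into num_chunks pieces. As one possible scheme, assumed here purely for illustration, the sketch below pickles the state dict and splits the resulting bytes into at most num_chunks roughly equal pieces, with 0 meaning a single unchunked payload. `split_into_chunks` and `join_chunks` are hypothetical helpers.

```python
import pickle
from typing import Any, List


def split_into_chunks(state_dict: Any, num_chunks: int) -> List[bytes]:
    """Hypothetical: serialize state_dict and split the bytes into <= num_chunks pieces.

    num_chunks == 0 means no chunking, i.e. a single payload.
    """
    if num_chunks < 0:
        raise ValueError("num_chunks must be >= 0")
    payload = pickle.dumps(state_dict)
    if num_chunks == 0:
        return [payload]
    size = -(-len(payload) // num_chunks)  # ceiling division, always >= 1
    return [payload[i:i + size] for i in range(0, len(payload), size)]


def join_chunks(chunks: List[bytes]) -> Any:
    """Reassemble the chunks in order and deserialize."""
    return pickle.loads(b"".join(chunks))


original = {"a": list(range(100))}
chunks = split_into_chunks(original, num_chunks=4)
print(join_chunks(chunks) == original)  # True
```

A real transport might instead split along tensor boundaries so that each chunk holds whole tensors.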

address() → str

Returns the HTTP address for fetching a checkpoint from this server. The step number must be appended to the end of the address.

Format: http://host:port/checkpoint/1234

Returns:

an HTTP address
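A small sketch of building the fetch URL from the value returned by address() and parsing the step back out on the serving side. The helper names are hypothetical, and the trailing-slash handling is an assumption: the documentation only says the step must be appended to the end of the address.

```python
from urllib.parse import urlparse


def checkpoint_url(address: str, step: int) -> str:
    """Append `step` to an address returned by address().

    Assumption: the address ends with the '/checkpoint/' path prefix; if the
    trailing slash is missing, one is added so the result matches the
    documented format http://host:port/checkpoint/1234.
    """
    if not address.endswith("/"):
        address += "/"
    return f"{address}{step}"


def step_from_url(url: str) -> int:
    """Hypothetical inverse helper: recover the step from the URL path."""
    return int(urlparse(url).path.rsplit("/", 1)[-1])


url = checkpoint_url("http://host:8000/checkpoint/", 1234)
print(url)  # http://host:8000/checkpoint/1234
print(step_from_url(url))  # 1234
```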

allow_checkpoint(step: int) → None

Allows serving the checkpoint with the specified step number.

Parameters:
  • step – the step number to serve

disallow_checkpoint() → None

Disallows serving the checkpoint.

All requests will block until allow_checkpoint is called.
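The blocking behavior of allow_checkpoint and disallow_checkpoint can be modeled with a threading.Condition. This is a minimal sketch, not HTTPTransport's implementation: `StepGate` is a hypothetical class, and it additionally assumes that a request only proceeds when the allowed step equals the step it asked for, giving up after a timeout instead of blocking forever.

```python
import threading
from typing import Optional


class StepGate:
    """Hypothetical gate: requests for a step block until allow(step) is called."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._step: Optional[int] = None  # None means serving is disallowed

    def allow(self, step: int) -> None:
        with self._cond:
            self._step = step
            self._cond.notify_all()  # wake every request waiting on the gate

    def disallow(self) -> None:
        with self._cond:
            self._step = None

    def wait_for(self, step: int, timeout: float) -> bool:
        """Block until `step` is allowed; return False if `timeout` seconds pass first."""
        with self._cond:
            return self._cond.wait_for(lambda: self._step == step, timeout=timeout)


gate = StepGate()
threading.Timer(0.05, gate.allow, args=(7,)).start()  # allow step 7 shortly
print(gate.wait_for(7, timeout=2.0))  # True
gate.disallow()
print(gate.wait_for(7, timeout=0.05))  # False
```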

metadata() → str

Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) → T

Receives the checkpoint from the given rank.

Parameters:
  • src_rank – the rank to receive the checkpoint from

  • metadata – the metadata returned by the remote CheckpointTransport

  • step – the step number to receive

  • timeout – the timeout to wait for the checkpoint

send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) → None

Sends the checkpoint. This is only called when there is a rank that is behind.

This may be asynchronous.

Parameters:
  • dst_ranks – the ranks to send to

  • step – the step number to send

  • state_dict – the state dict to send

  • timeout – the timeout to wait for the checkpoint to be sent

shutdown(wait: bool = True) → None

Shuts down the server.
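Putting the pieces together, the sketch below serves pickled state dicts at /checkpoint/<step> using only the standard library (http.server and urllib.request). It illustrates the pattern HTTPTransport describes rather than its actual implementation: `CheckpointServer`, `publish`, and `fetch_checkpoint` are hypothetical names, and chunking and the allow/disallow gate are omitted for brevity.

```python
import pickle
import threading
import urllib.error
import urllib.request
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any, Dict


class CheckpointServer:
    """Hypothetical server: GET /checkpoint/<step> returns that step's pickled state dict."""

    PREFIX = "/checkpoint/"

    def __init__(self) -> None:
        self._checkpoints: Dict[int, bytes] = {}
        self._lock = threading.Lock()
        outer = self

        class Handler(BaseHTTPRequestHandler):
            def do_GET(self) -> None:
                path = self.path
                step_str = path[len(outer.PREFIX):] if path.startswith(outer.PREFIX) else ""
                with outer._lock:
                    body = outer._checkpoints.get(int(step_str)) if step_str.isdigit() else None
                if body is None:
                    self.send_error(404, "checkpoint not available")
                    return
                self.send_response(200)
                self.send_header("Content-Type", "application/octet-stream")
                self.send_header("Content-Length", str(len(body)))
                self.end_headers()
                self.wfile.write(body)

            def log_message(self, format: str, *args: Any) -> None:
                pass  # silence the default per-request logging to stderr

        # Port 0 lets the OS pick a free port; a daemon thread runs the accept loop.
        self._server = ThreadingHTTPServer(("127.0.0.1", 0), Handler)
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
        self._thread.start()

    def address(self) -> str:
        host, port = self._server.server_address[:2]
        return f"http://{host}:{port}{self.PREFIX}"

    def publish(self, step: int, state_dict: Any) -> None:
        # Serialize once so every request for this step reuses the same bytes.
        with self._lock:
            self._checkpoints[step] = pickle.dumps(state_dict)

    def shutdown(self) -> None:
        self._server.shutdown()  # stop serve_forever
        self._server.server_close()  # release the listening socket


def fetch_checkpoint(address: str, step: int, timeout_s: float) -> Any:
    url = f"{address}{step}"
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as resp:
            return pickle.loads(resp.read())
    except urllib.error.HTTPError as err:
        if err.code == 404:
            raise KeyError(f"step {step} is not being served") from err
        raise


server = CheckpointServer()
server.publish(3, {"weights": [1.0, 2.0]})
print(fetch_checkpoint(server.address(), 3, timeout_s=5.0))  # {'weights': [1.0, 2.0]}
server.shutdown()
```

Because pickle.loads can execute arbitrary code, a sketch like this should only fetch checkpoints from trusted peers.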
