
Checkpointing

This module implements methods for checkpointing training and resuming from a checkpoint.

class torchft.checkpointing.CheckpointServer(state_dict: Callable[[], T])[source]

Bases: Generic[T]

This is an HTTP server that can be used to transfer checkpoints between workers.

This allows for fast recovery of workers by fetching the current weights from an existing worker.

Parameters:

state_dict – a callable that returns the state dict to be transferred

address() → str [source]

Returns the HTTP address to fetch a checkpoint from this server at the current step.

Format: http://host:port/checkpoint/1234

Returns:

an HTTP address

allow_checkpoint(step: int) → None [source]

Allows serving the checkpoint with the specified step number.

Parameters:

step – the step number to serve

disallow_checkpoint() → None [source]

Disallows serving the checkpoint.

All requests will block until allow_checkpoint is called.
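The allow/disallow gating described above can be sketched with a small stand-in built on a `threading.Event`: requests block until a step is allowed, and disallowing clears the gate again. The `CheckpointGate` class and its method bodies below are hypothetical illustrations of this pattern, not torchft's implementation.

```python
import threading


class CheckpointGate:
    """Hypothetical sketch of CheckpointServer's allow/disallow gate."""

    def __init__(self) -> None:
        self._allowed = threading.Event()
        self._step: int | None = None

    def allow_checkpoint(self, step: int) -> None:
        # Publish the step and unblock any waiting request handlers.
        self._step = step
        self._allowed.set()

    def disallow_checkpoint(self) -> None:
        # Subsequent requests block again until allow_checkpoint is called.
        self._allowed.clear()

    def wait_for_step(self) -> int:
        # A request handler parks here while checkpoints are disallowed.
        self._allowed.wait()
        return self._step
```

A handler thread calling `wait_for_step()` before `allow_checkpoint(step)` has run simply blocks, which matches the documented behavior that all requests wait until serving is allowed.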

classmethod load_from_address(address: str) → T [source]

Loads a checkpoint from the given address.

Parameters:

address – the HTTP address to load the checkpoint from

shutdown() → None [source]

Shut down the server.
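The overall recovery flow — a healthy worker serving its state dict over HTTP, a recovering worker fetching it by address — can be sketched using only the standard library. The `serve_checkpoint` and `load_from_address` helpers below are hypothetical stand-ins that mimic the documented API shape; torchft's actual server and wire format may differ.

```python
import pickle
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.request import urlopen


class _CheckpointHandler(BaseHTTPRequestHandler):
    # Set by serve_checkpoint(); a zero-arg callable returning the state dict,
    # mirroring the state_dict parameter of CheckpointServer.
    state_dict_fn = None

    def do_GET(self):
        payload = pickle.dumps(type(self).state_dict_fn())
        self.send_response(200)
        self.send_header("Content-Type", "application/octet-stream")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, format, *args):
        pass  # keep the example quiet


def serve_checkpoint(state_dict_fn):
    """Start serving on an ephemeral port; return (address, shutdown callable)."""
    _CheckpointHandler.state_dict_fn = staticmethod(state_dict_fn)
    server = ThreadingHTTPServer(("127.0.0.1", 0), _CheckpointHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    host, port = server.server_address
    # Address shape follows the documented format: http://host:port/checkpoint/<step>
    return f"http://{host}:{port}/checkpoint/0", server.shutdown


def load_from_address(address):
    """Fetch and deserialize the state dict from a serving worker."""
    with urlopen(address) as resp:
        return pickle.loads(resp.read())
```

Usage mirrors the documented pattern: the existing worker starts serving, the recovering worker calls `load_from_address(address)` with the address published by the healthy one, then the server is shut down.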
