FullModelTorchTuneCheckpointer

class torchtune.training.FullModelTorchTuneCheckpointer(checkpoint_dir: str, checkpoint_files: List[str], model_type: ModelType, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False)

Checkpointer which reads and writes checkpoints in a format compatible with torchtune. No conversion of weights is required.

Currently this supports reading a single checkpoint file only. This will likely change as we add support for larger models.

Parameters:
  • checkpoint_dir (str) – Directory containing the checkpoint files.

  • checkpoint_files (List[str]) – List of checkpoint files to load. The checkpointer sorts the files by file ID, so the order in this list does not matter.

  • model_type (ModelType) – Model type of the model whose checkpoint is being loaded.

  • output_dir (str) – Directory in which to save the checkpoint files.

  • adapter_checkpoint (Optional[str]) – Path to the adapter weights. Default is None.

  • recipe_checkpoint (Optional[str]) – Path to the recipe state checkpoint file. Default is None.

  • resume_from_checkpoint (bool) – If True, the checkpointer loads the additional checkpoint files needed to resume training from a previous run. Default is False.

Raises:
  • ValueError – If more than one checkpoint file is provided

  • ValueError – If the checkpoint file does not have a .pt extension

  • ValueError – If resume_from_checkpoint is True but recipe_checkpoint is None
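The three constructor checks listed under Raises can be sketched as a standalone function. This is a minimal illustration of the documented rules only, not torchtune's implementation; validate_checkpointer_args is a hypothetical name, and the sketch also rejects an empty file list, which the documentation does not mention.

```python
from pathlib import Path
from typing import List, Optional


def validate_checkpointer_args(
    checkpoint_files: List[str],
    recipe_checkpoint: Optional[str] = None,
    resume_from_checkpoint: bool = False,
) -> Path:
    """Hypothetical mirror of the documented constructor checks.

    Returns the single checkpoint file as a Path if every check passes.
    """
    # Only a single checkpoint file is supported (this sketch also rejects zero files).
    if len(checkpoint_files) != 1:
        raise ValueError(
            f"Expected exactly one checkpoint file, got {len(checkpoint_files)}"
        )
    checkpoint_path = Path(checkpoint_files[0])
    # The file must carry a .pt extension.
    if checkpoint_path.suffix != ".pt":
        raise ValueError(f"{checkpoint_path} must have a .pt extension")
    # Resuming needs the recipe state (optimizer, epoch, ...).
    if resume_from_checkpoint and recipe_checkpoint is None:
        raise ValueError("recipe_checkpoint must be set when resume_from_checkpoint is True")
    return checkpoint_path


print(validate_checkpointer_args(["model.pt"]))  # model.pt
```

Passing ["a.pt", "b.pt"], ["model.bin"], or resume_from_checkpoint=True without a recipe_checkpoint each raises ValueError, matching the three documented cases.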

load_checkpoint(weights_only: bool = True) → Dict[str, Any]

Load a torchtune checkpoint from file. Currently only loading from a single file is supported.

The output state_dict has the following format, with keys other than “model” only present if resume_from_checkpoint is True:

>>> {
>>>     "model": {
>>>         "key_1": weight,
>>>         ...
>>>     },
>>>     "optimizer": {...},
>>>     ...
>>> }
Parameters:

weights_only (bool) – Flag passed down to torch.load. It is exposed because quantized models cannot be loaded with weights_only=True.

Returns:

The state_dict loaded from the input checkpoint.

Return type:

Dict[str, Any]
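The rule that keys other than "model" appear only when resume_from_checkpoint is True can be sketched with plain dicts. This is a minimal illustration, assuming the recipe state is merged in at the top level as the format above shows; assemble_loaded_state_dict is a hypothetical helper, and real checkpoints hold tensors rather than the placeholder values used here.

```python
from typing import Any, Dict, Optional


def assemble_loaded_state_dict(
    model_weights: Dict[str, Any],
    recipe_state: Optional[Dict[str, Any]],
    resume_from_checkpoint: bool,
) -> Dict[str, Any]:
    """Hypothetical sketch of the dict shape returned by load_checkpoint."""
    # Model weights always live under the "model" key.
    state_dict: Dict[str, Any] = {"model": model_weights}
    # Recipe state (optimizer, epoch, ...) is merged in only when resuming.
    if resume_from_checkpoint:
        if recipe_state is None:
            raise ValueError("recipe state is required when resuming")
        state_dict.update(recipe_state)
    return state_dict


weights = {"key_1": "weight"}  # placeholder for real tensors
recipe = {"optimizer": {"lr": 1e-4}, "epoch": 3}
print(sorted(assemble_loaded_state_dict(weights, recipe, resume_from_checkpoint=False)))  # ['model']
print(sorted(assemble_loaded_state_dict(weights, recipe, resume_from_checkpoint=True)))  # ['epoch', 'model', 'optimizer']
```

Keeping the weights under a single "model" key lets a recipe pick out the model state the same way whether or not it is resuming.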

save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) → None

Save a torchtune checkpoint to file. If intermediate_checkpoint is True, an additional checkpoint file, recipe_state.pt, containing the recipe state is created in _output_dir. The output state dicts have the following formats:

>>> # Model
>>> {
>>>     "key_1": weight,
>>>     ...
>>> }
>>>
>>> # Recipe state
>>> {
>>>     "optimizer": ...,
>>>     "epoch": ...,
>>>     ...
>>> }
Parameters:
  • state_dict (Dict[str, Any]) – State dict with the model and (optionally) the recipe state.

  • epoch (int) – Current epoch number. It is added to the checkpoint file name so that intermediate checkpoint files are not overwritten.

  • intermediate_checkpoint (bool) – If True, save an additional checkpoint file with the recipe state. Default is False.

  • adapter_only (bool) – If True, only save the adapter weights. Default is False.

Raises:

ValueError – If adapter_only is True and the adapter weights are not found in state_dict.
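How save_checkpoint routes one state_dict into separate files can be sketched without touching the disk. This minimal sketch returns a mapping from file name to contents instead of calling torch.save. Only recipe_state.pt is a documented name; the torchtune_model_{epoch}.pt and adapter_{epoch}.pt names, the "adapter" key, the choice to save adapter weights whenever they are present, and the helper plan_checkpoint_files are all assumptions for illustration.

```python
from typing import Any, Dict

MODEL_KEY = "model"      # top-level key for model weights, as in load_checkpoint
ADAPTER_KEY = "adapter"  # assumed key for adapter weights (hypothetical)


def plan_checkpoint_files(
    state_dict: Dict[str, Any],
    epoch: int,
    intermediate_checkpoint: bool = False,
    adapter_only: bool = False,
) -> Dict[str, Any]:
    """Hypothetical sketch: map output file names to the data each would hold.

    Nothing is written to disk; file names other than recipe_state.pt are assumed.
    """
    if adapter_only and ADAPTER_KEY not in state_dict:
        raise ValueError("adapter_only=True but no adapter weights in state_dict")
    files: Dict[str, Any] = {}
    # The epoch goes into the file name so earlier checkpoints are not overwritten.
    if not adapter_only:
        files[f"torchtune_model_{epoch}.pt"] = state_dict[MODEL_KEY]
    if ADAPTER_KEY in state_dict:
        files[f"adapter_{epoch}.pt"] = state_dict[ADAPTER_KEY]
    # Everything except the weights is treated as recipe state.
    if intermediate_checkpoint:
        files["recipe_state.pt"] = {
            k: v for k, v in state_dict.items() if k not in (MODEL_KEY, ADAPTER_KEY)
        }
    return files


sd = {"model": {"key_1": 0.5}, "optimizer": {"lr": 0.1}, "epoch": 1}
print(sorted(plan_checkpoint_files(sd, epoch=1, intermediate_checkpoint=True)))
# ['recipe_state.pt', 'torchtune_model_1.pt']
```

Separating the weights from the recipe state means a finished run can ship only the model file, while recipe_state.pt stays available for resuming.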
