FullModelTorchTuneCheckpointer
- class torchtune.training.FullModelTorchTuneCheckpointer(checkpoint_dir: str, checkpoint_files: List[str], model_type: ModelType, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False)
Checkpointer which reads and writes checkpoints in a format compatible with torchtune. No conversion of weights is required.
Currently this supports reading a single checkpoint file only. This will likely change as we add support for larger models.
- Parameters:
checkpoint_dir (str) – Directory containing the checkpoint files
checkpoint_files (List[str]) – List of checkpoint files to load. Since the checkpointer takes care of sorting by file ID, the order in this list does not matter
model_type (ModelType) – Type of the model whose checkpoint is being loaded
output_dir (str) – Directory to save the checkpoint files
adapter_checkpoint (Optional[str]) – Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]) – Path to the recipe state checkpoint file. Default is None
resume_from_checkpoint (bool) – If True, the checkpointer will load the additional checkpoint files to resume training from a previous run. Default is False
- Raises:
ValueError – If more than one checkpoint file is provided
ValueError – If the checkpoint file does not have a .pt extension
ValueError – If resume_from_checkpoint is True but recipe_checkpoint is None
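The three error conditions above can be sketched in plain Python. This is a minimal illustration, not torchtune's implementation: the helper name validate_checkpointer_args is hypothetical, and it only mirrors the documented checks on checkpoint_files, recipe_checkpoint and resume_from_checkpoint.

```python
from pathlib import Path
from typing import List, Optional


def validate_checkpointer_args(
    checkpoint_files: List[str],
    recipe_checkpoint: Optional[str] = None,
    resume_from_checkpoint: bool = False,
) -> None:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # Only a single checkpoint file is supported.
    if len(checkpoint_files) > 1:
        raise ValueError(
            f"Expected a single checkpoint file, got {len(checkpoint_files)}"
        )

    # The checkpoint file must have a .pt extension.
    for name in checkpoint_files:
        if Path(name).suffix != ".pt":
            raise ValueError(f"Checkpoint file {name!r} must end in .pt")

    # Resuming training requires a recipe state checkpoint.
    if resume_from_checkpoint and recipe_checkpoint is None:
        raise ValueError(
            "resume_from_checkpoint is True but recipe_checkpoint is None"
        )
```

Running a check like this before constructing the checkpointer would surface configuration mistakes early, but the class performs its own validation regardless.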
- load_checkpoint(weights_only: bool = True) → Dict[str, Any]
Load torchtune checkpoint from file. Currently only loading from a single file is supported.
The output state_dict has the following format, with keys other than “model” only present if resume_from_checkpoint is True:
>>> {
>>>     "model": {
>>>         "key_1": weight
>>>         ...
>>>     },
>>>     "optimizer": {...},
>>>     ...
>>> }
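The shape of the returned state_dict can be illustrated with plain dictionaries. The sketch below is an assumption-laden stand-in, not the library's code: assemble_loaded_state is a hypothetical function, and floats stand in for weight tensors.

```python
from typing import Any, Dict, Optional


def assemble_loaded_state(
    model_weights: Dict[str, Any],
    recipe_state: Optional[Dict[str, Any]] = None,
    resume_from_checkpoint: bool = False,
) -> Dict[str, Any]:
    """Hypothetical sketch of the documented load_checkpoint output format."""
    # The "model" key is always present.
    state: Dict[str, Any] = {"model": model_weights}

    # Other keys (optimizer, epoch, ...) appear only when resuming.
    if resume_from_checkpoint:
        if recipe_state is None:
            raise ValueError("Resuming requires a recipe state")
        state.update({k: v for k, v in recipe_state.items() if k != "model"})

    return state


# Floats stand in for weight tensors in this illustration.
fresh = assemble_loaded_state({"key_1": 0.5})
resumed = assemble_loaded_state(
    {"key_1": 0.5},
    recipe_state={"optimizer": {"lr": 1e-4}, "epoch": 2},
    resume_from_checkpoint=True,
)
```

Here fresh contains only the "model" key, while resumed also carries "optimizer" and "epoch".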
- save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) → None
Save torchtune checkpoint to file. If intermediate_checkpoint is True, an additional checkpoint file recipe_state.pt is created in _output_dir which contains the recipe state. The output state dicts have the following formats:
>>> # Model
>>> {
>>>     "key_1": weight
>>>     ...
>>> }
>>>
>>> # Recipe state
>>> {
>>>     "optimizer": ...,
>>>     "epoch": ...,
>>>     ...
>>> }
- Parameters:
state_dict (Dict[str, Any]) – State dict with model and (optionally) recipe state
epoch (int) – Current epoch number. This is added to the checkpoint file name to ensure we’re not overwriting intermediate checkpoint files
intermediate_checkpoint (bool) – If True, save an additional checkpoint file with the recipe state
adapter_only (bool) – If True, only save the adapter weights. Default is False
- Raises:
ValueError – If adapter_only is True and the adapter checkpoint is not found in state_dict.
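How save_checkpoint divides a combined state dict into separate files can be sketched without touching disk. Everything below is illustrative rather than the library's code: plan_checkpoint_files is a hypothetical function, the "adapter" key and the model_{epoch}.pt and adapter_{epoch}.pt file names are assumptions, and only recipe_state.pt is a name taken from the description above.

```python
from typing import Any, Dict

MODEL_KEY = "model"
ADAPTER_KEY = "adapter"  # assumed key name for adapter weights


def plan_checkpoint_files(
    state_dict: Dict[str, Any],
    epoch: int,
    intermediate_checkpoint: bool = False,
    adapter_only: bool = False,
) -> Dict[str, Any]:
    """Hypothetical sketch: map output file names to the payload each would hold."""
    if adapter_only and ADAPTER_KEY not in state_dict:
        raise ValueError("adapter_only is True but no adapter weights were found")

    files: Dict[str, Any] = {}

    # Model weights go in an epoch-stamped file (name is an assumption),
    # so earlier intermediate checkpoints are not overwritten.
    if not adapter_only:
        files[f"model_{epoch}.pt"] = state_dict[MODEL_KEY]

    # Adapter weights, when present, get their own file (name is an assumption).
    if ADAPTER_KEY in state_dict:
        files[f"adapter_{epoch}.pt"] = state_dict[ADAPTER_KEY]

    # Everything else (optimizer, epoch, ...) is the recipe state.
    if intermediate_checkpoint:
        files["recipe_state.pt"] = {
            k: v
            for k, v in state_dict.items()
            if k not in (MODEL_KEY, ADAPTER_KEY)
        }

    return files


plan = plan_checkpoint_files(
    {"model": {"key_1": 0.5}, "optimizer": {"lr": 1e-4}, "epoch": 1},
    epoch=1,
    intermediate_checkpoint=True,
)
```

In this sketch plan has two entries: the model weights under model_1.pt and the remaining keys under recipe_state.pt.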