FullModelTorchTuneCheckpointer
- class torchtune.training.FullModelTorchTuneCheckpointer(checkpoint_dir: str, checkpoint_files: List[str], model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, should_load_recipe_state: bool = False)
Checkpointer which reads and writes checkpoints in a format compatible with torchtune. No conversion of weights is required.
Currently this supports reading a single checkpoint file only. This will likely change as we add support for larger models.
- Parameters:
checkpoint_dir (str) – Directory containing the checkpoint files
checkpoint_files (List[str]) – List of checkpoint files to load. Since the checkpointer takes care of sorting by file ID, the order in this list does not matter
model_type (str) – Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str) – Directory to save the checkpoint files
adapter_checkpoint (Optional[str]) – Path to the adapter weights. If None, and should_load_recipe_state=True, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None.
recipe_checkpoint (Optional[str]) – Path to the recipe state checkpoint file. If None, and should_load_recipe_state=True, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. Default is None.
resume_from_checkpoint (bool) – If True, the checkpointer will load the additional checkpoint files corresponding to the recipe state from a previous run. Default is False. This flag is deprecated. Please use the should_load_recipe_state flag instead.
should_load_recipe_state (bool) – If True, the checkpointer will load the additional checkpoint files corresponding to the recipe state from a previous run. Default is False.
- Raises:
ValueError – If more than one checkpoint file is provided
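Putting the constructor arguments together, a minimal instantiation might look like the sketch below. The directory paths and the checkpoint file name are hypothetical placeholders, not files shipped with torchtune:
>>> from torchtune.training import FullModelTorchTuneCheckpointer
>>> # Hypothetical paths -- point these at your own torchtune-format checkpoint
>>> checkpointer = FullModelTorchTuneCheckpointer(
...     checkpoint_dir="/tmp/llama3_model",
...     checkpoint_files=["meta_model_0.pt"],  # a single file, per the limitation above
...     model_type="LLAMA3",
...     output_dir="/tmp/llama3_finetune",
...     should_load_recipe_state=False,
... )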
- load_checkpoint(weights_only: bool = True) → Dict[str, Any]
Load torchtune checkpoint from file. Currently only loading from a single file is supported.
The output state_dict has the following format, with keys other than “model” only present if should_load_recipe_state is True:
>>> {
>>>     "model": {
>>>         "key_1": weight
>>>         ...
>>>     },
>>>     "optimizer": {...},
>>>     ...
>>> }
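Continuing the instantiation sketch above, a load might look like the following; the "optimizer" entry appears only when should_load_recipe_state=True:
>>> state_dict = checkpointer.load_checkpoint()
>>> model_weights = state_dict["model"]  # mapping of parameter names to tensors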
- save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) → None
Save torchtune checkpoint to file. If intermediate_checkpoint is True, an additional checkpoint file recipe_state.pt is created in _output_dir/RECIPE_STATE_DIRNAME which contains the recipe state. The output state dicts have the following formats:
>>> # Model
>>> {
>>>     "key_1": weight
>>>     ...
>>> }
>>>
>>> # Recipe state
>>> {
>>>     "optimizer": ...,
>>>     "epoch": ...,
>>>     ...
>>> }
- Parameters:
state_dict (Dict[str, Any]) – State dict with model and (optionally) recipe state
epoch (int) – Current epoch number. This is added to the checkpoint file name to ensure we’re not overwriting intermediate checkpoint files
intermediate_checkpoint (bool) – If True, save an additional checkpoint file with the recipe state
adapter_only (bool) – If True, only save the adapter weights. Default is False
- Raises:
ValueError – If adapter_only is True and adapter checkpoint not found in state_dict.
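As a closing sketch, the weights loaded in the earlier example could be written back out by nesting them under the "model" key, matching the model format shown above. The epoch value and variable names continue the earlier hypothetical examples rather than a real torchtune recipe:
>>> # Nest weights under "model"; pass intermediate_checkpoint=True to also
>>> # write recipe_state.pt containing the recipe state
>>> checkpointer.save_checkpoint(
...     state_dict={"model": model_weights},
...     epoch=0,
...     intermediate_checkpoint=False,
... )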