FullModelHFCheckpointer

class torchtune.training.FullModelHFCheckpointer(checkpoint_dir: str, checkpoint_files: Union[List[str], Dict[str, str]], model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, safe_serialization: bool = True, should_load_recipe_state: bool = False)[source]

Checkpointer which reads and writes checkpoints in HF’s format. For LoRA models this includes saving checkpoints in a format that can be loaded into PEFT via e.g. from_pretrained. Examples include the Llama-2-7b-hf model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b-hf).

Note

HF checkpoint names are usually ordered by ID (e.g., 0001_of_0003, 0002_of_0003, etc.). To ensure we read the files in the right order, we sort the checkpoint file names before reading.

Note

Checkpoint conversion to and from HF’s format requires access to model params which are read directly from the config.json file. This helps ensure we either load the weights correctly or error out in case of discrepancy between the HF checkpoint file and torchtune’s model implementations.

Parameters:
  • checkpoint_dir (str) – Directory containing the checkpoint files

  • checkpoint_files (Union[List[str], Dict[str, str]]) – List of checkpoint files to load or a dictionary containing the keys [“filename_format”, “max_filename”]. Since the checkpointer takes care of sorting by file ID, the order in this list does not matter.

  • model_type (str) – Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.

  • output_dir (str) – Directory to save the checkpoint files

  • adapter_checkpoint (Optional[str]) – Path to the adapter weights. If None, and should_load_recipe_state=True, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None.

  • recipe_checkpoint (Optional[str]) – Path to the recipe state checkpoint file. If None, and should_load_recipe_state=True, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. Default is None.

  • resume_from_checkpoint (bool) – If True, the checkpointer will load the additional checkpoint files corresponding to the recipe state from a previous run. Default is False. This flag is deprecated. Please use the should_load_recipe_state flag instead.

  • safe_serialization (bool) – If True, the checkpointer will save the checkpoint file using safetensors. Default is True.

  • should_load_recipe_state (bool) – If True, the checkpointer will load the additional checkpoint files corresponding to the recipe state from a previous run. Default is False.
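
A minimal construction sketch follows. The directory and shard file names are hypothetical placeholders for a sharded Llama-3 checkpoint in HF format; substitute your own paths.

from torchtune.training import FullModelHFCheckpointer

# Hypothetical paths and file names for illustration only.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3-8B",
    checkpoint_files=[
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors",
    ],
    model_type="LLAMA3",
    output_dir="/tmp/llama3-finetune",
    safe_serialization=True,
)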

load_checkpoint() → Dict[str, Any][source]

Load HF checkpoint from file.

The keys and weights from across all checkpoint files are merged into a single state_dict. We preserve the “state_dict key” <-> “checkpoint file” mapping in weight_map so we can write the state dict correctly in save_checkpoint.

Before returning, the model state dict is converted to a torchtune-compatible format using the appropriate convert_weights function (depending on self._model_type).

Returns:

torchtune checkpoint state dict

Return type:

state_dict (Dict[str, Any])

Raises:

ValueError – If the values in the input state_dict are not Tensors
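
A sketch of loading the merged state dict into a torchtune model. The “model” key in the returned dict and the llama3_8b builder are assumptions based on common torchtune conventions, not guarantees of this API:

from torchtune.models.llama3 import llama3_8b  # assumed builder for illustration

# Merge all HF shards into one torchtune-format state dict.
checkpoint_dict = checkpointer.load_checkpoint()

# Model weights are assumed to live under the "model" key.
model = llama3_8b()
model.load_state_dict(checkpoint_dict["model"])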

save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) → None[source]

Save HF checkpoint to file. If intermediate_checkpoint is True, an additional checkpoint file recipe_state.pt is created in _output_dir/RECIPE_STATE_DIRNAME which contains the recipe state.

The state_dict is first converted back to the HF format and then partitioned based on the _weight_map into separate checkpoint files.

Parameters:
  • state_dict (Dict[str, Any]) – Checkpoint state dict to be written out to file

  • epoch (int) – Epoch number. Used to create the checkpoint file name

  • intermediate_checkpoint (bool) – If True, additional checkpoint files for the recipe state and (if applicable) adapter weights are created. Default is False

  • adapter_only (bool) – If True, only save the adapter weights. Default is False

Raises:

ValueError – If adapter_only is True and the adapter checkpoint is not found in state_dict.
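
A sketch of writing a checkpoint after an epoch. Keying the weights under "model" is an assumption based on common torchtune conventions; a resumable intermediate checkpoint would normally also include recipe-state entries (optimizer state, seed, etc.) in state_dict:

# Convert back to HF format and partition into shard files under output_dir.
# intermediate_checkpoint=True additionally writes recipe_state.pt so training
# can be resumed later.
checkpointer.save_checkpoint(
    state_dict={"model": model.state_dict()},  # "model" key assumed for illustration
    epoch=0,
    intermediate_checkpoint=True,
)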
