FullModelHFCheckpointer
- class torchtune.training.FullModelHFCheckpointer(checkpoint_dir: str, checkpoint_files: Union[List[str], Dict[str, str]], model_type: ModelType, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, safe_serialization: bool = False)
Checkpointer which reads and writes checkpoints in HF’s format. For LoRA models, this includes saving checkpoints in a format that can be loaded into PEFT, e.g. via from_pretrained. Examples include the Llama-2-7b-hf model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b-hf).
Note
HF checkpoint names are usually ordered by ID (e.g. 0001_of_0003, 0002_of_0003, etc.). To ensure we read the files in the right order, we sort the checkpoint file names before reading.
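A minimal sketch of the sorting step, assuming shard names carry zero-padded IDs in the common HF "NNNNN-of-MMMMM" style. The helper sort_checkpoint_files and the file names are hypothetical, chosen for illustration; they are not torchtune's API.

```python
from typing import List


def sort_checkpoint_files(checkpoint_files: List[str]) -> List[str]:
    # Hypothetical helper. With zero-padded shard IDs, lexicographic order
    # equals numeric order, so a plain sort recovers the shard sequence
    # regardless of the order the caller listed the files in.
    return sorted(checkpoint_files)


# Hypothetical shard names, deliberately listed out of order
files = [
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00003-of-00003.bin",
    "pytorch_model-00001-of-00003.bin",
]
print(sort_checkpoint_files(files))
# ['pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin']
```

This relies on the padding: unpadded names such as shard_10 and shard_2 would sort lexicographically as shard_10 before shard_2, and would need a numeric sort key instead.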
Note
Checkpoint conversion to and from HF’s format requires access to model params, which are read directly from the config.json file. This helps ensure we either load the weights correctly or error out in case of a discrepancy between the HF checkpoint file and torchtune’s model implementations.
- Parameters:
checkpoint_dir (str) – Directory containing the checkpoint files
checkpoint_files (Union[List[str], Dict[str, str]]) – List of checkpoint files to load. Since the checkpointer takes care of sorting by file ID, the order in this list does not matter.
model_type (ModelType) – Model type of the model for which the checkpointer is being loaded
output_dir (str) – Directory to save the checkpoint files
adapter_checkpoint (Optional[str]) – Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]) – Path to the recipe state checkpoint file. Default is None
resume_from_checkpoint (bool) – If True, the checkpointer will load the additional checkpoint files to resume training from a previous run. Default is False
safe_serialization (bool) – If True, the checkpointer will save the checkpoint file using safetensors. Default is False
- Raises:
ValueError – If resume_from_checkpoint is True but recipe_checkpoint is None
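The config.json note and the constructor's ValueError can be sketched with the standard library alone. This is a minimal illustration, not torchtune's code: the helper names read_hf_config, conversion_params and check_resume_args are hypothetical, and the chosen config fields (num_attention_heads, num_key_value_heads, hidden_size, head_dim) are an assumption based on Llama-style HF configs.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


def read_hf_config(checkpoint_dir: str) -> Dict[str, Any]:
    # HF checkpoints ship a config.json next to the weight files
    return json.loads((Path(checkpoint_dir) / "config.json").read_text())


def conversion_params(config: Dict[str, Any]) -> Dict[str, int]:
    # Hypothetical helper: pull the fields a weight converter would need.
    # A missing required key raises KeyError instead of silently
    # loading mismatched weights.
    num_heads = config["num_attention_heads"]
    dim = config["hidden_size"]
    return {
        "num_heads": num_heads,
        # Configs without grouped-query attention may omit this field
        "num_kv_heads": config.get("num_key_value_heads", num_heads),
        "dim": dim,
        "head_dim": config.get("head_dim", dim // num_heads),
    }


def check_resume_args(resume_from_checkpoint: bool, recipe_checkpoint: Optional[str]) -> None:
    # Mirrors the documented ValueError condition
    if resume_from_checkpoint and recipe_checkpoint is None:
        raise ValueError("recipe_checkpoint is required when resume_from_checkpoint is True")


with tempfile.TemporaryDirectory() as tmp:
    # Values taken from the Llama-2-7b-hf config
    config = {"num_attention_heads": 32, "num_key_value_heads": 32, "hidden_size": 4096}
    (Path(tmp) / "config.json").write_text(json.dumps(config))
    params = conversion_params(read_hf_config(tmp))
print(params)  # {'num_heads': 32, 'num_kv_heads': 32, 'dim': 4096, 'head_dim': 128}
```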
- load_checkpoint() → Dict[str, Any]
Load HF checkpoint from file.
The keys and weights from across all checkpoint files are merged into a single state_dict. We preserve the “state_dict key” <-> “checkpoint file” mapping in weight_map so we can write the state dict correctly in save_checkpoint.
Before returning, the model state dict is converted to a torchtune-compatible format using the appropriate convert_weights function (depending on self._model_type).
- Returns:
torchtune checkpoint state dict
- Return type:
Dict[str, Any]
- Raises:
ValueError – If the values in the input state_dict are not Tensors
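The merge-and-convert flow above can be sketched as follows. This is a minimal, stdlib-only illustration under loud assumptions: plain lists stand in for tensors (so the "not Tensors" check becomes an isinstance(value, list) check), the helpers merge_shards and convert_key are hypothetical, and _FROM_HF is a hypothetical subset of a Llama-style key mapping. torchtune's real convert_weights functions may do more than rename keys.

```python
import re
from typing import Any, Dict, List, Tuple

# Hypothetical subset of an HF -> torchtune key mapping; "{}" is the layer index
_FROM_HF = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}


def merge_shards(
    shards: List[Tuple[str, Dict[str, Any]]]
) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """Merge (shard_id, state_dict) pairs and record which shard each key came from."""
    merged: Dict[str, Any] = {}
    weight_map: Dict[str, str] = {}
    for shard_id, shard in shards:
        for key, value in shard.items():
            # Stand-in for the "values must be Tensors" check: weights are lists here
            if not isinstance(value, list):
                raise ValueError(f"Value for {key!r} is not a weight: {type(value).__name__}")
            weight_map[key] = shard_id
            merged[key] = value
    return merged, weight_map


def convert_key(hf_key: str) -> str:
    """Rename one HF key to its torchtune counterpart using _FROM_HF."""
    match = re.search(r"\.(\d+)\.", hf_key)
    if match is None:
        return _FROM_HF[hf_key]
    layer = match.group(1)
    # Swap the layer number for "{}" to find the template, then put it back
    template = hf_key.replace(f".{layer}.", ".{}.", 1)
    return _FROM_HF[template].format(layer)


shards = [
    ("0001", {"model.embed_tokens.weight": [0.1, 0.2],
              "model.layers.0.self_attn.q_proj.weight": [0.3]}),
    ("0002", {"model.norm.weight": [1.0], "lm_head.weight": [0.4]}),
]
hf_state_dict, weight_map = merge_shards(shards)
tune_state_dict = {convert_key(k): v for k, v in hf_state_dict.items()}
print(sorted(tune_state_dict))
# ['layers.0.attn.q_proj.weight', 'norm.scale', 'output.weight', 'tok_embeddings.weight']
print(weight_map["lm_head.weight"])  # 0002
```

Keeping weight_map keyed by the original HF names is what lets the save path, which converts back to HF format first, route each weight to its original shard.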
- save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) → None
Save HF checkpoint to file. If intermediate_checkpoint is True, an additional checkpoint file, recipe_state.pt, containing the recipe state is created in _output_dir.
The state_dict is first converted back to the HF format and then partitioned based on the _weight_map into separate checkpoint files.
- Parameters:
state_dict (Dict[str, Any]) – Checkpoint state dict to be written out to file
epoch (int) – Epoch number. Used to create the checkpoint file name
intermediate_checkpoint (bool) – If True, additional checkpoint files for recipe state and (if applicable) adapter weights are created. Default is False
adapter_only (bool) – If True, only save the adapter weights. Default is False
- Raises:
ValueError – If adapter_only is True and the adapter checkpoint is not found in state_dict.
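The partitioning performed by save_checkpoint can be sketched as a planning function that maps output file names to their contents without writing anything. Everything here is an assumption for illustration: the state_dict keys "model" and "adapter", the helper plan_checkpoint_files, and the file name patterns hf_model_{shard}_{epoch}.pt and adapter_{epoch}.pt are hypothetical. Model weights are assumed to be already converted back to HF key names, and lists stand in for tensors.

```python
from typing import Any, Dict

MODEL_KEY = "model"      # assumed key holding model weights in state_dict
ADAPTER_KEY = "adapter"  # assumed key holding adapter weights in state_dict


def plan_checkpoint_files(
    state_dict: Dict[str, Any],
    weight_map: Dict[str, str],
    epoch: int,
    intermediate_checkpoint: bool = False,
    adapter_only: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Map each (hypothetical) output file name to the entries it would contain."""
    if adapter_only and ADAPTER_KEY not in state_dict:
        raise ValueError("adapter_only is True but no adapter weights are in state_dict")

    files: Dict[str, Dict[str, Any]] = {}
    if not adapter_only:
        # Route each HF-format weight back to the shard it was loaded from
        for key, weight in state_dict[MODEL_KEY].items():
            shard_file = f"hf_model_{weight_map[key]}_{epoch}.pt"
            files.setdefault(shard_file, {})[key] = weight
    if ADAPTER_KEY in state_dict:
        files[f"adapter_{epoch}.pt"] = state_dict[ADAPTER_KEY]
    if intermediate_checkpoint:
        # Everything besides model and adapter weights is treated as recipe state
        files["recipe_state.pt"] = {
            k: v for k, v in state_dict.items() if k not in (MODEL_KEY, ADAPTER_KEY)
        }
    return files


weight_map = {"model.embed_tokens.weight": "0001", "lm_head.weight": "0002"}
state_dict = {
    MODEL_KEY: {"model.embed_tokens.weight": [0.1], "lm_head.weight": [0.2]},
    "seed": 42,
    "epochs_run": 1,
}
files = plan_checkpoint_files(state_dict, weight_map, epoch=0, intermediate_checkpoint=True)
print(sorted(files))  # ['hf_model_0001_0.pt', 'hf_model_0002_0.pt', 'recipe_state.pt']
```

Separating the plan from the actual write keeps the partitioning logic easy to check; a real implementation would then serialize each entry, with torch.save or safetensors depending on safe_serialization.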