
Checkpointing in torchtune

This deep-dive will walk you through the design and behavior of the checkpointer and associated utilities.

What this deep-dive will cover:
  • Checkpointer design for torchtune

  • Checkpoint formats and how we handle them

  • Checkpointing scenarios: Intermediate vs Final and LoRA vs Full-finetune

Overview

torchtune checkpointers are designed to be composable components that can be plugged into any recipe: training, evaluation or generation. Each checkpointer supports a specific set of models and scenarios, which keeps it easy to understand, debug and extend.

Before we dive into the checkpointer in torchtune, let’s define some concepts.


Checkpoint Format

In this deep-dive, we’ll talk about different checkpoint formats and how torchtune handles them. Let’s take a close look at these different formats.

Very simply put, the format of a checkpoint is dictated by the state_dict and how it is stored in files on disk. Each weight is associated with a string key that identifies it in the state dict. If the keys in the stored checkpoint don’t match exactly with those in the model definition, you’ll either run into explicit errors (loading the state dict will raise an exception) or worse - silent errors (loading will succeed but training or inference will not work as expected). In addition to the keys lining up, the shapes of the weights (the values in the state_dict) must match exactly those expected by the model definition.
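To make the key and shape requirements concrete, here is a minimal sketch of the kind of check that catches both problems before loading. The helper `compare_state_dicts` is hypothetical (it is not part of torchtune), and plain tuples stand in for tensor shapes so the example runs without PyTorch. The key `norm.weight` is an assumed example.

```python
def compare_state_dicts(checkpoint_shapes, model_shapes):
    """Compare two {key: shape} maps.

    Returns three sorted lists: keys the model expects but the checkpoint
    lacks, keys the checkpoint has but the model does not expect, and keys
    present in both whose shapes differ.
    """
    missing = sorted(k for k in model_shapes if k not in checkpoint_shapes)
    unexpected = sorted(k for k in checkpoint_shapes if k not in model_shapes)
    mismatched = sorted(
        k
        for k in model_shapes
        if k in checkpoint_shapes
        and tuple(checkpoint_shapes[k]) != tuple(model_shapes[k])
    )
    return missing, unexpected, mismatched


# Shapes the (hypothetical) model definition expects.
model_shapes = {
    "tok_embeddings.weight": (32000, 4096),
    "norm.weight": (4096,),  # assumed example key
}

# Shapes found in a checkpoint that uses a different naming scheme.
checkpoint_shapes = {
    "model.embed_tokens.weight": (32000, 4096),
    "norm.weight": (4096,),
}

missing, unexpected, mismatched = compare_state_dicts(checkpoint_shapes, model_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

With real tensors, the shape maps could be built with something like `{k: tuple(v.shape) for k, v in state_dict.items()}`.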

Let’s look at the two popular formats for Llama2.

Meta Format

This is the format supported by the official Llama2 implementation. When you download the Llama2 7B model from the meta-llama website, you’ll get access to a single .pth checkpoint file. You can inspect the contents of this checkpoint easily with torch.load:

>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
...     print(f'{key}: {value.shape}')

tok_embeddings.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
292

The state_dict contains 292 keys, including an input embedding table called tok_embeddings. The model definition for this state_dict expects an embedding layer with 32000 tokens, each having an embedding of dimension 4096.

HF Format

This is the most popular format within the Hugging Face Model Hub and is the default format in every torchtune config. This is also the format you get when you download the Llama2 model from the Llama-2-7b-hf repo.

The first big difference is that the state_dict is split across two .bin files. To correctly load the checkpoint, you’ll need to piece these files together. Let’s inspect one of the files.

>>> import torch
>>> state_dict = torch.load('pytorch_model-00001-of-00002.bin', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
...     print(f'{key}: {value.shape}')

model.embed_tokens.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
241

Not only does the state_dict contain fewer keys (expected, since this is one of two files), but the embedding table is called model.embed_tokens instead of tok_embeddings. This mismatch in names will cause an exception when you try to load the state_dict into a model definition that expects the other naming scheme. The size of this layer is the same in both formats, as expected.
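The name mismatch described above can be bridged with a key mapping. The following is only a sketch of the idea, not torchtune's converter: `rename_keys` and the `HF_TO_META` table are hypothetical, the first pair comes from the two listings above, and the second pair is an assumed example. Strings stand in for tensors so the snippet runs on its own.

```python
# Illustrative name pairs. The first is taken from the listings above;
# the second is an assumed example of another key.
HF_TO_META = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
}


def rename_keys(state_dict, mapping):
    """Return a new dict with every key renamed according to `mapping`.

    Raises KeyError for an unmapped key so that a mismatch is an explicit
    error rather than a silent one.
    """
    renamed = {}
    for key, value in state_dict.items():
        if key not in mapping:
            raise KeyError(f"no mapping for checkpoint key: {key}")
        renamed[mapping[key]] = value
    return renamed


# Placeholder values stand in for the real tensors.
hf_sd = {
    "model.embed_tokens.weight": "embedding-tensor",
    "model.norm.weight": "norm-tensor",
}

meta_sd = rename_keys(hf_sd, HF_TO_META)

# Inverting the mapping converts back to the source naming scheme.
META_TO_HF = {new: old for old, new in HF_TO_META.items()}
round_trip = rename_keys(meta_sd, META_TO_HF)

print(sorted(meta_sd))
print(round_trip == hf_sd)
```

Renaming alone may not be sufficient for every weight; a real converter may also need to transform some tensors, which this sketch does not attempt.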


As you can see, if you’re not careful you’ll likely end up making a number of errors just during checkpoint load and save. The torchtune checkpointer makes this less error-prone by managing state dicts for you. torchtune is designed to be “state-dict invariant”.

  • When loading, torchtune accepts checkpoints from multiple sources in multiple formats. You don’t have to worry about explicitly converting checkpoints every time you run a recipe.

  • When saving, torchtune produces checkpoints in the same format as the source. This includes converting the state_dict back into the original form and splitting the keys and weights across the same number of files.

One big advantage of being “state-dict invariant” is that you should be able to use fine-tuned checkpoints from torchtune with any post-training tool (quantization, eval, inference) which supports the source format, without any code changes OR conversion scripts. This is one of the ways in which torchtune interoperates with the surrounding ecosystem.

To be “state-dict invariant”, the load_checkpoint and save_checkpoint methods make use of torchtune’s weight converters.
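One way to picture the load and save behavior described above is to remember which file each key came from while merging, then use that record to split the weights again on save. This is a rough sketch under that assumption; `merge_shards` and `split_like_source` are hypothetical helpers, and integers stand in for tensors.

```python
def merge_shards(shards):
    """Merge {file_name: state_dict} into one dict.

    Also returns a {key: file_name} record so the original partitioning
    can be reproduced later.
    """
    merged, key_to_file = {}, {}
    for file_name in sorted(shards):
        for key, value in shards[file_name].items():
            if key in merged:
                raise ValueError(f"duplicate key across shards: {key}")
            merged[key] = value
            key_to_file[key] = file_name
    return merged, key_to_file


def split_like_source(merged, key_to_file):
    """Partition `merged` across the same files the keys were loaded from."""
    out = {}
    for key, value in merged.items():
        out.setdefault(key_to_file[key], {})[key] = value
    return out


# Two source files with placeholder integer "weights".
source = {
    "pytorch_model-00001-of-00002.bin": {"a": 1, "b": 2},
    "pytorch_model-00002-of-00002.bin": {"c": 3},
}

merged, key_to_file = merge_shards(source)

# Stand-in for fine-tuning: every weight changes, the keys do not.
finetuned = {key: value + 10 for key, value in merged.items()}

resharded = split_like_source(finetuned, key_to_file)
for file_name, shard in resharded.items():
    print(file_name, sorted(shard))
```

The output has the same number of files and the same keys per file as the source, which is the property that lets downstream tools consume the fine-tuned checkpoint unchanged.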


Handling different Checkpoint Formats

torchtune supports three different checkpointers, each of which supports a different checkpoint format.

HFCheckpointer

This checkpointer reads and writes checkpoints in a format which is compatible with the transformers framework from Hugging Face. As mentioned above, this is the most popular format within the Hugging Face Model Hub and is the default format in every torchtune config.

For this checkpointer to work correctly, we assume that checkpoint_dir contains the necessary checkpoint and json files. The easiest way to make sure everything works correctly is to use the following flow:

  • Download the model from the HF repo using tune download. By default, this will ignore the “safetensors” files.


    tune download meta-llama/Llama-2-7b-hf \
    --output-dir <checkpoint_dir> \
    --hf-token <hf-token>
    
  • Use output_dir specified here as the checkpoint_dir argument for the checkpointer.


The following snippet explains how the HFCheckpointer is set up in torchtune config files.

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama2-7b-hf model we have
    # 2 .bin files. The checkpointer takes care of sorting
    # by id and so the order here does not matter
    checkpoint_files: [
        pytorch_model-00001-of-00002.bin,
        pytorch_model-00002-of-00002.bin,
    ]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state. More on this in the
    # next section
    recipe_checkpoint: null

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: False

Note

Checkpoint conversion to and from HF’s format requires access to model params which are read directly from the config.json file. This helps ensure we either load the weights correctly or error out in case of discrepancy between the HF checkpoint file and torchtune’s model implementations. This json file is downloaded from the hub along with the model checkpoints. More details on how these are used during conversion can be found here.
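As a rough illustration of reading model params from config.json, the sketch below parses a small JSON document and fails loudly if a required field is absent. Which fields torchtune actually reads is an assumption here; the field names follow the Hugging Face config convention, the values are typed inline for illustration rather than read from a download, and `read_model_params` is a hypothetical helper.

```python
import json

# Sample config contents, typed inline for illustration.
config_text = """
{
  "hidden_size": 4096,
  "num_attention_heads": 32,
  "num_key_value_heads": 32
}
"""

# Assumed set of fields needed for conversion.
REQUIRED = ("hidden_size", "num_attention_heads", "num_key_value_heads")


def read_model_params(text):
    """Parse config text and return the required params plus head_dim."""
    config = json.loads(text)
    missing = [name for name in REQUIRED if name not in config]
    if missing:
        raise KeyError(f"config.json is missing: {missing}")
    params = {name: config[name] for name in REQUIRED}
    if params["hidden_size"] % params["num_attention_heads"] != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    # Per-head dimension derived from the two fields above.
    params["head_dim"] = params["hidden_size"] // params["num_attention_heads"]
    return params


params = read_model_params(config_text)
print(params)
```

Erroring out on a missing or inconsistent field mirrors the goal stated in the note: either the weights load correctly or the discrepancy is reported.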


MetaCheckpointer

This checkpointer reads and writes checkpoints in a format which is compatible with the original meta-llama github repository.

For this checkpointer to work correctly, we assume that checkpoint_dir contains the necessary checkpoint and json files. The easiest way to make sure everything works correctly is to use the following flow:

  • Download the model from the HF repo using tune download. By default, this will ignore the “safetensors” files.


    tune download meta-llama/Llama-2-7b \
    --output-dir <checkpoint_dir> \
    --hf-token <hf-token>
    
  • Use output_dir above as the checkpoint_dir for the checkpointer.


The following snippet explains how the MetaCheckpointer is set up in torchtune config files.

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelMetaCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama2-7b model we have
    # a single .pth file
    checkpoint_files: [consolidated.00.pth]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state. More on this in the
    # next section
    recipe_checkpoint: null

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: False

TorchTuneCheckpointer

This checkpointer reads and writes checkpoints in a format that is compatible with torchtune’s model definition. This does not perform any state_dict conversions and is currently used either for testing or for loading quantized models for generation.


Intermediate vs Final Checkpoints

torchtune Checkpointers support two checkpointing scenarios:

End-of-training Checkpointing

The model weights at the end of a completed training run are written out to file. The checkpointer ensures that the output checkpoint files have the same keys as the input checkpoint file used to begin training. The checkpointer also ensures that the keys are partitioned across the same number of files as the original checkpoint. The output state dict has the following standard format:

{
    "key_1": weight_1,
    "key_2": weight_2,
    ...
}

Mid-training Checkpointing

When checkpointing in the middle of training, the output checkpoint needs to store additional information so that subsequent training runs can be restarted correctly. In addition to the model checkpoint files, we output a recipe_state.pt file for intermediate checkpoints. These are currently output at the end of each epoch and contain information such as the optimizer state and the number of epochs completed.

To prevent us from flooding output_dir with checkpoint files, the recipe state is overwritten at the end of each epoch.

The output state dicts have the following formats:

Model:
    {
        "key_1": weight_1,
        "key_2": weight_2,
        ...
    }

Recipe State:
    {
        "optimizer": ...,
        "epoch": ...,
        ...
    }
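The overwrite behavior described above can be sketched as follows. This is an assumption-laden illustration, not torchtune's implementation: JSON is used in place of a PyTorch-serialized file so the snippet needs only the standard library, the file is therefore named recipe_state.json, and `save_recipe_state` and the optimizer fields are hypothetical.

```python
import json
import os
import tempfile


def save_recipe_state(output_dir, optimizer_state, epoch):
    """Write the recipe state, replacing any file from a previous epoch."""
    path = os.path.join(output_dir, "recipe_state.json")
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"optimizer": optimizer_state, "epoch": epoch}, f)
    # Replace the old file in one step so a crash mid-write does not
    # leave a truncated recipe state behind.
    os.replace(tmp_path, path)
    return path


with tempfile.TemporaryDirectory() as output_dir:
    for epochs_completed in (1, 2, 3):
        # Placeholder optimizer state; a real one holds per-parameter tensors.
        optimizer_state = {"lr": 2e-5, "step": 100 * epochs_completed}
        save_recipe_state(output_dir, optimizer_state, epochs_completed)

    files_on_disk = sorted(os.listdir(output_dir))
    with open(os.path.join(output_dir, "recipe_state.json")) as f:
        restored = json.load(f)

print(files_on_disk)
print(restored["epoch"])
```

After three epochs there is still a single recipe state file, and it holds the state from the most recent epoch.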

To restart from a previous checkpoint file, you’ll need to make the following changes to the config file:

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    checkpoint_files: [
        hf_model_0001_0.pt,
        hf_model_0002_0.pt,
    ]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state
    recipe_checkpoint: recipe_state.pt

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True

Checkpointing for LoRA

In torchtune, we output both the adapter weights and the full model “merged” weights for LoRA. The “merged” checkpoint can be used just like you would use the source checkpoint with any post-training tools. For more details, take a look at our LoRA Finetuning Tutorial.
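To show what “merged” weights mean, here is a small sketch of the standard LoRA merge, W' = W + (alpha / rank) * B A, written with nested lists so it runs without PyTorch. Whether torchtune performs the merge in exactly this way is an assumption; `matmul` and `merge_lora` are hypothetical helpers, and the matrices are toy values.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Fold the adapter into the frozen weight: W + (alpha / rank) * B @ A.

    weight is (out, in), lora_b is (out, rank), lora_a is (rank, in).
    """
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy 2x2 frozen weight with a rank-1 adapter.
weight = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 2.0]]    # (rank=1, in=2)
lora_b = [[0.5], [0.0]]  # (out=2, rank=1)

merged = merge_lora(weight, lora_a, lora_b, alpha=2.0, rank=1)
print(merged)
```

The merged matrix has the same shape as the frozen weight, which is why a merged checkpoint can be used wherever the source checkpoint is accepted.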

The primary difference between LoRA and full fine-tuning shows up when you want to resume training from a checkpoint. In this case, the checkpointer needs access to both the initial frozen base model weights and the learnt adapter weights. The config for this scenario looks something like this:

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. This is the ORIGINAL frozen checkpoint
    # and NOT the merged checkpoint output during training
    checkpoint_files: [
        pytorch_model-00001-of-00002.bin,
        pytorch_model-00002-of-00002.bin,
    ]

    # this refers to the adapter weights learnt during training
    adapter_checkpoint: adapter_0.pt

    # the file with the checkpoint state
    recipe_checkpoint: recipe_state.pt

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True
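The resume requirements from the two configs above can be summarized as a small pre-flight check. This is a hypothetical helper written for illustration, not a torchtune API; it takes the config as a plain dict with the same field names as the YAML above.

```python
def check_resume_config(config, lora=False):
    """Return a list of problems that would prevent resuming a run."""
    problems = []
    checkpointer = config.get("checkpointer", {})
    if config.get("resume_from_checkpoint"):
        # Resuming needs the recipe state written during training.
        if not checkpointer.get("recipe_checkpoint"):
            problems.append("recipe_checkpoint must be set when resuming")
        # A LoRA run additionally needs the learnt adapter weights.
        if lora and not checkpointer.get("adapter_checkpoint"):
            problems.append("adapter_checkpoint must be set when resuming a LoRA run")
    return problems


full_finetune = {
    "checkpointer": {"recipe_checkpoint": "recipe_state.pt"},
    "resume_from_checkpoint": True,
}
lora_missing_adapter = {
    "checkpointer": {"recipe_checkpoint": "recipe_state.pt"},
    "resume_from_checkpoint": True,
}

print(check_resume_config(full_finetune))
print(check_resume_config(lora_missing_adapter, lora=True))
```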

Putting this all together

Let’s now put all of this knowledge together! We’ll load some checkpoints, create a model and run a simple forward pass.

For this section we’ll use the Llama2 13B model in HF format.

import torch
from torchtune.utils import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b

# Set the right directory and files
checkpoint_dir = 'Llama-2-13b-hf/'
pytorch_files = [
    'pytorch_model-00001-of-00003.bin',
    'pytorch_model-00002-of-00003.bin',
    'pytorch_model-00003-of-00003.bin'
]

# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir=checkpoint_dir,
    model_type=ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()

# Set up the model and the input
model = llama2_13b()

# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
# Output: <All keys matched successfully>

# We have 32000 vocab tokens; let's generate an input with 70 tokens
x = torch.randint(0, 32000, (1, 70))

with torch.no_grad():
    model(x)

tensor([[[ -6.3989,  -9.0531,   3.2375,  ...,  -5.2822,  -4.4872,  -5.7469],
    [ -8.6737, -11.0023,   6.8235,  ...,  -2.6819,  -4.2424,  -4.0109],
    [ -4.6915,  -7.3618,   4.1628,  ...,  -2.8594,  -2.5857,  -3.1151],
    ...,
    [ -7.7808,  -8.2322,   2.8850,  ...,  -1.9604,  -4.7624,  -1.6040],
    [ -7.3159,  -8.5849,   1.8039,  ...,  -0.9322,  -5.2010,  -1.6824],
    [ -7.8929,  -8.8465,   3.3794,  ...,  -1.3500,  -4.6145,  -2.5931]]])

You can do this with any model supported by torchtune. You can find a full list of models and model builders here.

We hope this deep-dive provided a deeper insight into the checkpointer and associated utilities in torchtune. Happy tuning!
