ModelType

class torchtune.training.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

ModelType is used by the checkpointer to distinguish between different model architectures.

If you are adding a new model whose checkpoint format differs from those already in the repo, you can add a new ModelType and gate the weight conversion logic unique to that model on it.

Variables:
  • GEMMA (str) – Gemma family of models. See gemma()

  • GEMMA2 (str) – Gemma 2 family of models. See gemma2()

  • LLAMA2 (str) – Llama2 family of models. See llama2()

  • LLAMA3 (str) – Llama3 family of models. See llama3()

  • LLAMA3_2 (str) – Llama3.2 family of models. See llama3_2()

  • LLAMA3_VISION (str) – Llama3 vision family of models. See llama3_2_vision_decoder()

  • MISTRAL (str) – Mistral family of models. See mistral()

  • PHI3_MINI (str) – Phi-3 family of models. See phi3()

  • REWARD (str) – A Llama2, Llama3, or Mistral model with a classification head projecting to a single class for reward modelling. See mistral_reward_7b() or llama2_reward_7b()

  • QWEN2 (str) – Qwen2 family of models. See qwen2()

  • CLIP_TEXT (str) – CLIP text encoder. See clip_text_encoder_large()
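
Because each member above is a named, string-valued variable, a model type written as text (for example in a config file) can be resolved to a member by name. The sketch below illustrates that lookup with a stand-in enum rather than torchtune's real class; StandInModelType, its string values, and the resolve_model_type helper are assumptions made for illustration only.

```python
from enum import Enum


class StandInModelType(Enum):
    # Stand-in for torchtune.training.ModelType; these string values are assumed.
    LLAMA3 = "llama3"
    MISTRAL = "mistral"
    QWEN2 = "qwen2"


def resolve_model_type(name: str) -> StandInModelType:
    """Map a member name such as "LLAMA3" to its enum member (hypothetical helper)."""
    try:
        # Enum classes support lookup by member name via subscription.
        return StandInModelType[name]
    except KeyError:
        valid = ", ".join(member.name for member in StandInModelType)
        raise ValueError(
            f"Unknown model type {name!r}; expected one of: {valid}"
        ) from None
```

Failing early with the list of valid names tends to make a misspelled model type easier to diagnose than a bare KeyError.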

Example

>>> # Usage in a checkpointer class
>>> def load_checkpoint(self, ...):
>>>     ...
>>>     if self._model_type == ModelType.MY_NEW_MODEL:
>>>         state_dict = my_custom_state_dict_mapping(state_dict)
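
The gating pattern in the example can be sketched as a small self-contained program. This is not torchtune's checkpointer: StandInModelType, ToyCheckpointer, and the key-renaming rule inside my_custom_state_dict_mapping are hypothetical, chosen only to show how a model type can select a conversion step.

```python
from enum import Enum
from typing import Dict


class StandInModelType(Enum):
    # Stand-in for torchtune.training.ModelType; values are assumed.
    LLAMA3 = "llama3"
    MY_NEW_MODEL = "my_new_model"  # hypothetical member for the new model


def my_custom_state_dict_mapping(state_dict: Dict[str, float]) -> Dict[str, float]:
    # Hypothetical conversion: rename a leading "transformer." prefix to "model.".
    old, new = "transformer.", "model."
    return {
        (new + key[len(old):]) if key.startswith(old) else key: value
        for key, value in state_dict.items()
    }


class ToyCheckpointer:
    """Toy checkpointer that holds weights in memory instead of reading files."""

    def __init__(self, model_type: StandInModelType, raw_state_dict: Dict[str, float]):
        self._model_type = model_type
        self._raw_state_dict = raw_state_dict

    def load_checkpoint(self) -> Dict[str, float]:
        state_dict = dict(self._raw_state_dict)
        # Gate the model-specific conversion on the model type.
        if self._model_type == StandInModelType.MY_NEW_MODEL:
            state_dict = my_custom_state_dict_mapping(state_dict)
        return state_dict
```

Only checkpoints tagged with the new member go through the custom mapping; every other model type returns its weights unchanged.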
