ModelType
- class torchtune.training.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
ModelType is used by the checkpointer to distinguish between different model architectures.
If you are adding a new model whose checkpoint format differs from those already in the repo, you can add a new ModelType member and gate the weight conversion logic unique to that model on it.
- Variables:
LLAMA3_2 (str) – Llama3.2 family of models. See llama3_2().
LLAMA3_VISION (str) – Llama3 vision family of models. See llama3_2_vision_decoder().
REWARD (str) – A Llama2, Llama3, or Mistral model with a classification head projecting to a single class for reward modelling. See mistral_reward_7b() or llama2_reward_7b().
Example
>>> # Usage in a checkpointer class
>>> def load_checkpoint(self, ...):
>>>     ...
>>>     if self._model_type == MY_NEW_MODEL:
>>>         state_dict = my_custom_state_dict_mapping(state_dict)
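The gating pattern in the example can be sketched as a self-contained script. This is a minimal illustration, not torchtune's implementation: `SketchModelType` and `SketchCheckpointer` are hypothetical stand-ins for the real enum and a checkpointer class, the `MY_NEW_MODEL` member is assumed to have been added by you, and the body of `my_custom_state_dict_mapping` (stripping a `backbone.` key prefix) is invented purely for demonstration.

```python
from enum import Enum
from typing import Dict


class SketchModelType(Enum):
    """Hypothetical stand-in for torchtune.training.ModelType."""

    LLAMA3_2 = "llama3_2"
    REWARD = "reward"
    MY_NEW_MODEL = "my_new_model"  # assumed new member, not part of torchtune


def my_custom_state_dict_mapping(state_dict: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical conversion: drop a 'backbone.' prefix from each key."""
    prefix = "backbone."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


class SketchCheckpointer:
    """Hypothetical checkpointer that gates conversion on the model type."""

    def __init__(self, model_type: SketchModelType) -> None:
        self._model_type = model_type

    def load_checkpoint(self, state_dict: Dict[str, float]) -> Dict[str, float]:
        # Only the new model type takes the custom conversion path;
        # every other type gets the state dict back unchanged.
        if self._model_type == SketchModelType.MY_NEW_MODEL:
            state_dict = my_custom_state_dict_mapping(state_dict)
        return state_dict


# Plain floats stand in for weight tensors to keep the sketch dependency-free.
raw = {"backbone.layers.0.weight": 1.0, "output.weight": 2.0}
converted = SketchCheckpointer(SketchModelType.MY_NEW_MODEL).load_checkpoint(raw)
untouched = SketchCheckpointer(SketchModelType.LLAMA3_2).load_checkpoint(raw)
```

Comparing against an enum member rather than a raw string keeps the check explicit: a misspelled member name raises an `AttributeError` at the comparison site instead of silently skipping the conversion.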