ModelType
- class torchtune.utils.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
ModelType is used by the checkpointer to distinguish between different model architectures.
If you are adding a new model whose checkpoint format differs from those already in the repo, you can add a new ModelType member to gate the weight-conversion logic that is unique to that model.
Example
>>> # Usage in a checkpointer class
>>> def load_checkpoint(self, ...):
>>>     ...
>>>     if self._model_type == ModelType.MY_NEW_MODEL:
>>>         state_dict = my_custom_state_dict_mapping(state_dict)
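The example above elides the surrounding class. A self-contained sketch of the same gating pattern follows. It assumes a local stand-in enum rather than importing torchtune, and `MY_NEW_MODEL`, `my_custom_state_dict_mapping`, and `ToyCheckpointer` are all hypothetical names for illustration; the "model." to "decoder." key rename is an invented conversion, not one torchtune performs.

```python
from enum import Enum
from typing import Any, Dict


# Hypothetical stand-in for torchtune.utils.ModelType: it copies two real member
# values and adds MY_NEW_MODEL, which does NOT exist in torchtune.
class ModelType(Enum):
    LLAMA2 = "llama2"
    MISTRAL = "mistral"
    MY_NEW_MODEL = "my_new_model"  # hypothetical new architecture


def my_custom_state_dict_mapping(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical conversion: rename keys with a "model." prefix to "decoder."."""
    return {
        ("decoder." + key[len("model."):] if key.startswith("model.") else key): value
        for key, value in state_dict.items()
    }


class ToyCheckpointer:
    """Hypothetical checkpointer that gates conversion logic on its model type."""

    def __init__(self, model_type: ModelType) -> None:
        self._model_type = model_type

    def load_checkpoint(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Only the new architecture needs the custom key mapping; every other
        # model type is returned unchanged in this sketch.
        if self._model_type == ModelType.MY_NEW_MODEL:
            state_dict = my_custom_state_dict_mapping(state_dict)
        return state_dict


raw = {"model.layers.0.weight": 1, "lm_head.weight": 2}
print(ToyCheckpointer(ModelType.MY_NEW_MODEL).load_checkpoint(raw))
# {'decoder.layers.0.weight': 1, 'lm_head.weight': 2}
```

Comparing against an enum member rather than a raw string means a misspelled model type fails loudly at attribute lookup instead of silently skipping the conversion branch.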
- GEMMA = 'gemma'
Gemma family of models. See gemma().
- LLAMA2 = 'llama2'
Llama2 family of models. See llama2().
- LLAMA3 = 'llama3'
Llama3 family of models. See llama3().
- MISTRAL = 'mistral'
Mistral family of models. See mistral().
- MISTRAL_REWARD = 'mistral_reward'
Mistral model with a classification head. See mistral_classifier().
- PHI3_MINI = 'phi3_mini'
Phi-3 family of models. See phi3().
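Because ModelType is a standard Python `Enum` (the signature above is the generic `Enum` constructor), the usual enum lookups apply: calling the class with a string looks a member up by value, and subscripting looks it up by name. The sketch below assumes a local mirror of the members listed above so it runs without torchtune installed, and `parse_model_type` is a hypothetical helper, not a torchtune API.

```python
from enum import Enum


# Local mirror of the members listed above (values copied); real code would
# import ModelType from torchtune.utils instead.
class ModelType(Enum):
    GEMMA = "gemma"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    MISTRAL = "mistral"
    MISTRAL_REWARD = "mistral_reward"
    PHI3_MINI = "phi3_mini"


def parse_model_type(name: str) -> ModelType:
    """Hypothetical helper: accept either a member value ("llama3") or name ("LLAMA3")."""
    try:
        # Call syntax looks a member up by its *value*.
        return ModelType(name)
    except ValueError:
        pass
    try:
        # Subscript syntax looks a member up by its *name*.
        return ModelType[name]
    except KeyError:
        valid = ", ".join(member.value for member in ModelType)
        raise ValueError(f"Unknown model type {name!r}; expected one of: {valid}") from None


print(parse_model_type("llama3"))           # ModelType.LLAMA3
print(parse_model_type("PHI3_MINI").value)  # phi3_mini
```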