ModelType
- class torchtune.utils.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
ModelType is used by the checkpointer to distinguish between different model architectures.
If you are adding a new model whose checkpoint format differs from those already in the repo, you can add a new ModelType and use it to gate weight conversion logic unique to that model.
Example
>>> # Usage in a checkpointer class
>>> def load_checkpoint(self, ...):
>>>     ...
>>>     if self._model_type == MY_NEW_MODEL:
>>>         state_dict = my_custom_state_dict_mapping(state_dict)
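The gating pattern in the example can be sketched as a small runnable script. This is a minimal illustration, not torchtune's implementation: the `ModelType` enum below is a local stand-in (only `MISTRAL_REWARD` comes from this page), and `MY_NEW_MODEL`, `my_custom_state_dict_mapping`, `convert_weights`, and the `"model."` key prefix are all hypothetical.

```python
from enum import Enum
from typing import Any, Dict


class ModelType(Enum):
    """Local stand-in for torchtune.utils.ModelType (assumption, for illustration)."""

    MISTRAL_REWARD = "mistral_reward"  # member documented on this page
    MY_NEW_MODEL = "my_new_model"      # hypothetical new member


def my_custom_state_dict_mapping(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical conversion: strip a leading "model." prefix from each key."""
    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def convert_weights(model_type: ModelType, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Apply model-specific conversion only when the model type calls for it."""
    if model_type == ModelType.MY_NEW_MODEL:
        return my_custom_state_dict_mapping(state_dict)
    # Every other model type passes through unchanged in this sketch.
    return dict(state_dict)


raw = {"model.layers.0.weight": 1.0, "output.weight": 2.0}
converted = convert_weights(ModelType.MY_NEW_MODEL, raw)
untouched = convert_weights(ModelType.MISTRAL_REWARD, raw)
```

Comparing against an enum member rather than a raw string means a misspelled model type fails with an `AttributeError` or `ValueError` instead of silently skipping the conversion.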
- MISTRAL_REWARD = 'mistral_reward'
Mistral model with a classification head. See mistral_classifier().