ModelType

class torchtune.utils.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

ModelType is used by the checkpointer to distinguish between different model architectures.
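The call signature shown for ModelType is the standard one that Python `Enum` classes expose, which suggests ModelType is an `Enum` subclass. Assuming that, the minimal stand-in below shows the two usual ways to get a member: calling the class looks a member up by its value, and subscripting looks it up by its name. It copies two of the documented members and is a hypothetical illustration, not the torchtune class itself.

```python
from enum import Enum


# Hypothetical stand-in mirroring two documented members; not torchtune's class.
class ModelType(Enum):
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"


# Calling the class looks a member up by its value...
by_value = ModelType("llama3")
# ...while subscripting looks a member up by its name.
by_name = ModelType["LLAMA2"]

print(by_value, by_value.value)     # ModelType.LLAMA3 llama3
print(by_name is ModelType.LLAMA2)  # True
```

An unknown value or name raises `ValueError` or `KeyError` respectively, which gives a checkpointer an early, explicit failure for unsupported model types.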

If you are adding a new model whose checkpoint format differs from those already in the repo, you can add a new ModelType member and use it to gate the weight-conversion logic unique to that model.

Example

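>>> # Usage in a checkpointer class
>>> def load_checkpoint(self, *args, **kwargs):
>>>     ...  # load the raw state_dict as usual
>>>     # MY_NEW_MODEL stands for the new member you would add to ModelType
>>>     if self._model_type == ModelType.MY_NEW_MODEL:
>>>         state_dict = my_custom_state_dict_mapping(state_dict)

GEMMA = 'gemma'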

Gemma family of models. See gemma()

LLAMA2 = 'llama2'

Llama2 family of models. See llama2()

LLAMA3 = 'llama3'

Llama3 family of models. See llama3()

MISTRAL = 'mistral'

Mistral family of models. See mistral()

MISTRAL_REWARD = 'mistral_reward'

Mistral model with a classification head. See mistral_classifier()

PHI3_MINI = 'phi3_mini'

Phi-3 family of models. See phi3()
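The gating pattern from the example can also be written as a dispatch table keyed by model type. The sketch below is self-contained and hypothetical: it redefines an enum with the member names and values listed above plus an illustrative MY_NEW_MODEL, and `my_custom_state_dict_mapping`, `CONVERTERS` and `convert` are made-up names, not torchtune APIs. The key rename is an arbitrary stand-in for real conversion logic.

```python
from enum import Enum
from typing import Callable, Dict


# Hypothetical stand-in mirroring the documented members, plus an illustrative
# MY_NEW_MODEL member; this is not the torchtune class itself.
class ModelType(Enum):
    GEMMA = "gemma"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    MISTRAL = "mistral"
    MISTRAL_REWARD = "mistral_reward"
    PHI3_MINI = "phi3_mini"
    MY_NEW_MODEL = "my_new_model"  # hypothetical addition


StateDict = Dict[str, object]


def my_custom_state_dict_mapping(state_dict: StateDict) -> StateDict:
    # Hypothetical conversion: rename an external key prefix to the one
    # the model expects.
    return {k.replace("transformer.", "model."): v for k, v in state_dict.items()}


# Model-specific converters keyed by model type; types without an entry
# are passed through unchanged.
CONVERTERS: Dict[ModelType, Callable[[StateDict], StateDict]] = {
    ModelType.MY_NEW_MODEL: my_custom_state_dict_mapping,
}


def convert(state_dict: StateDict, model_type: ModelType) -> StateDict:
    converter = CONVERTERS.get(model_type)
    return converter(state_dict) if converter is not None else state_dict


sd = {"transformer.layers.0.weight": 1}
print(convert(sd, ModelType.MY_NEW_MODEL))  # {'model.layers.0.weight': 1}
print(convert(sd, ModelType.LLAMA3))        # {'transformer.layers.0.weight': 1}
```

A table keeps each model's conversion in one place, so supporting a new format means adding one member and one entry; an `if` on `self._model_type`, as in the example, is just as reasonable when only a few models need special handling.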
