update_state_dict_for_classifier
- torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)
Validates the state dict for checkpoint loading for a classifier model. To be used prior to a call to model.load_state_dict(state_dict). This function will overwrite the output.weight in the state dict to be loaded with the output.weight in the model if the shapes for output.weight do not match. You may also wish to force this replacement, for example, when num_classes for your checkpoint and model are the same, by passing force_override.

Concretely, when fine-tuning a classifier model from the checkpoint of a base language model which has output.weight of shape [vocab_dim, embed_dim], we overwrite the output.weight in the state dict to be loaded with the randomly initialized [num_classes, embed_dim] weight in the model. This is done in-place. A usage sketch follows the sections below.

- Parameters:
  - state_dict (Dict[str, torch.Tensor]) – state dict to be loaded into the classifier model.
  - model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – model named parameters from model.named_parameters().
  - force_override (bool) – Whether to replace output.weight in state_dict with the model's output.weight, even if the shapes match.
Notes

- output.bias will be ignored if present in state_dict.
- This function will always replace the output.weight in state_dict, if output.weight != model.output.weight.
- Raises:
  - AssertionError – if state_dict does not contain output.weight.
  - AssertionError – if model_named_parameters does not contain output.weight.
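Example

A minimal usage sketch. The module layout, dimensions, and tensor names below are illustrative assumptions rather than a torchtune model builder: a toy base-LM checkpoint whose output.weight has vocab_dim rows is loaded into a classifier whose head has num_classes rows, so the function swaps in the model's randomly initialized head weight before load_state_dict is called.

```python
import torch
from torch import nn
from torchtune.training import update_state_dict_for_classifier

embed_dim = 16
vocab_dim = 128
num_classes = 2

# Toy classifier model: its output head is [num_classes, embed_dim].
classifier = nn.ModuleDict({
    "tok_embeddings": nn.Embedding(vocab_dim, embed_dim),
    "output": nn.Linear(embed_dim, num_classes, bias=False),
})

# Toy base-LM checkpoint: its output head is [vocab_dim, embed_dim].
base_lm_state_dict = {
    "tok_embeddings.weight": torch.randn(vocab_dim, embed_dim),
    "output.weight": torch.randn(vocab_dim, embed_dim),
}

# Shapes differ, so output.weight in the state dict is replaced in-place
# with the classifier's randomly initialized [num_classes, embed_dim] weight.
update_state_dict_for_classifier(
    base_lm_state_dict, classifier.named_parameters()
)

classifier.load_state_dict(base_lm_state_dict)
assert classifier["output"].weight.shape == (num_classes, embed_dim)
```

Passing force_override=True would perform the same replacement even when the checkpoint's output.weight already has shape [num_classes, embed_dim].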