validate_missing_and_unexpected_for_lora¶
- torchtune.modules.peft.validate_missing_and_unexpected_for_lora(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool, apply_lora_to_output: bool, base_missing: Optional[List[str]] = None, base_unexpected: Optional[List[str]] = None, lora_missing: Optional[List[str]] = None, lora_unexpected: Optional[List[str]] = None) → None [source]¶
A more memory-efficient way to validate that LoRA state dict loading was done properly.
Similar to validate_state_dict_for_lora(), this function uses a model's LoRA config to check that LoRA and/or base model weights are loaded into the full model correctly. Unlike that function, this method relies only on the values of missing and unexpected as returned by the load_state_dict API with strict=False. This allows us to do the validation without any additional calls to .state_dict(), which would use additional memory. This API should only be used for single-device recipes, or on multi-device after https://github.com/pytorch/pytorch/pull/120600.
- Parameters:
lora_attn_modules (List[LORA_ATTN_MODULES]) – list of which linear layers LoRA should be applied to in each self-attention block. Options are {"q_proj", "k_proj", "v_proj", "output_proj"}.
apply_lora_to_mlp (bool) – whether LoRA is applied to each MLP linear.
apply_lora_to_output (bool) – whether LoRA is applied to the final output projection.
base_missing (Optional[List[str]]) – List of missing keys when loading base model weights. Default: None
base_unexpected (Optional[List[str]]) – List of unexpected keys when loading base model weights. Default: None
lora_missing (Optional[List[str]]) – List of missing keys when loading LoRA weights. Default: None
lora_unexpected (Optional[List[str]]) – List of unexpected keys when loading LoRA weights. Default: None
- Returns:
None
- Raises:
AssertionError – if base_missing contains any base model keys.
AssertionError – if base_unexpected is nonempty.
AssertionError – if lora_missing contains any LoRA keys.
AssertionError – if lora_unexpected is nonempty.
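For illustration, here is a minimal sketch of the intended call pattern, assuming the torchtune lora_llama2_7b builder and the get_adapter_params helper; splitting the model's own state dict into base and LoRA pieces stands in for the checkpoint files a real recipe would load from disk:

```python
from torchtune.models.llama2 import lora_llama2_7b
from torchtune.modules.peft import (
    get_adapter_params,
    validate_missing_and_unexpected_for_lora,
)

lora_attn_modules = ["q_proj", "v_proj"]
model = lora_llama2_7b(
    lora_attn_modules=lora_attn_modules,
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
)

# Stand-in checkpoints: split the model's own state dict into LoRA (adapter)
# weights and base model weights. A real recipe would load these from disk.
adapter_keys = set(get_adapter_params(model).keys())
full_sd = model.state_dict()
lora_sd = {k: v for k, v in full_sd.items() if k in adapter_keys}
base_sd = {k: v for k, v in full_sd.items() if k not in adapter_keys}

# With strict=False, load_state_dict returns (missing_keys, unexpected_keys)
# instead of raising, so no extra .state_dict() call is needed to validate.
base_missing, base_unexpected = model.load_state_dict(base_sd, strict=False)
lora_missing, lora_unexpected = model.load_state_dict(lora_sd, strict=False)

# Raises AssertionError if any base model or LoRA key failed to load, or if
# either load produced unexpected keys; returns None when everything lines up.
validate_missing_and_unexpected_for_lora(
    lora_attn_modules=lora_attn_modules,
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    base_missing=base_missing,
    base_unexpected=base_unexpected,
    lora_missing=lora_missing,
    lora_unexpected=lora_unexpected,
)
```

In this sketch base_missing contains only LoRA keys (expected, since the base checkpoint has none) and lora_missing contains only base model keys, so the validation passes without materializing a second copy of the weights.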