validate_missing_and_unexpected_for_lora
- torchtune.modules.peft.validate_missing_and_unexpected_for_lora(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool, apply_lora_to_output: bool, base_missing: Optional[List[str]] = None, base_unexpected: Optional[List[str]] = None, lora_missing: Optional[List[str]] = None, lora_unexpected: Optional[List[str]] = None) -> None
A memory-efficient way to validate that LoRA state dict loading was done correctly.
This function uses the model's LoRA configuration to check that LoRA weights, base model weights, or both were loaded into the full model correctly. It relies only on the lists of missing and unexpected keys returned by load_state_dict with strict=False, so the validation needs no additional calls to .state_dict(), which would use additional memory.
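To see why no extra .state_dict() call is needed, the pure-Python sketch below models what a non-strict load reports: keys the model expects but the checkpoint lacks are "missing", and keys the checkpoint supplies but the model lacks are "unexpected". The helper simulate_non_strict_load and all key names are hypothetical, chosen only for illustration; this is a simplified model, not PyTorch's implementation.

```python
from typing import Dict, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    """Mirrors the two fields reported by load_state_dict(strict=False)."""
    missing_keys: List[str]
    unexpected_keys: List[str]


def simulate_non_strict_load(
    model_keys: List[str], state_dict: Dict[str, object]
) -> IncompatibleKeys:
    """Hypothetical, simplified model of non-strict loading.

    Only the key mismatches are reported; no tensors are copied.
    """
    model_set = set(model_keys)
    loaded_set = set(state_dict)
    missing = [k for k in model_keys if k not in loaded_set]
    unexpected = [k for k in state_dict if k not in model_set]
    return IncompatibleKeys(missing, unexpected)


# Hypothetical keys for a one-layer model with LoRA applied to q_proj.
model_keys = [
    "layers.0.attn.q_proj.weight",
    "layers.0.attn.q_proj.lora_a.weight",
    "layers.0.attn.q_proj.lora_b.weight",
]
base_weights = {"layers.0.attn.q_proj.weight": None}
adapter_weights = {
    "layers.0.attn.q_proj.lora_a.weight": None,
    "layers.0.attn.q_proj.lora_b.weight": None,
}

# The base checkpoint leaves the LoRA keys missing, and vice versa.
base_missing, base_unexpected = simulate_non_strict_load(model_keys, base_weights)
lora_missing, lora_unexpected = simulate_non_strict_load(model_keys, adapter_weights)
```

With a real model, nn.Module.load_state_dict(state_dict, strict=False) returns a named tuple with missing_keys and unexpected_keys fields; those lists are what would be passed as base_missing/base_unexpected or lora_missing/lora_unexpected.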
- Parameters:
lora_attn_modules (List[LORA_ATTN_MODULES]) – list of the linear layers in each self-attention block to which LoRA is applied. Options are {"q_proj", "k_proj", "v_proj", "output_proj"}.
apply_lora_to_mlp (bool) – whether LoRA is applied to each linear layer in the MLP.
apply_lora_to_output (bool) – whether LoRA is applied to the final output projection.
base_missing (Optional[List[str]]) – List of missing keys when loading base model weights. Default: None
base_unexpected (Optional[List[str]]) – List of unexpected keys when loading base model weights. Default: None
lora_missing (Optional[List[str]]) – List of missing keys when loading LoRA weights. Default: None
lora_unexpected (Optional[List[str]]) – List of unexpected keys when loading LoRA weights. Default: None
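The three configuration parameters together determine which state dict keys count as LoRA keys. A minimal sketch, assuming the MLP linears are named w1, w2 and w3, the final projection is named output, and LoRA parameters sit under a "<module>.lora" prefix; these naming conventions are assumptions, and both helpers are hypothetical.

```python
from typing import List


def lora_module_names(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> List[str]:
    """Hypothetical helper: collect the names of modules that carry LoRA.

    The MLP names ("w1", "w2", "w3") and the output name ("output") are
    assumptions about the model's naming scheme.
    """
    names = list(lora_attn_modules)
    if apply_lora_to_mlp:
        names.extend(["w1", "w2", "w3"])
    if apply_lora_to_output:
        names.append("output")
    return names


def is_lora_key(key: str, module_names: List[str]) -> bool:
    """Hypothetical helper: True if `key` lies under "<module>.lora"."""
    return any(f"{name}.lora" in key for name in module_names)


names = lora_module_names(["q_proj", "v_proj"], apply_lora_to_mlp=True,
                          apply_lora_to_output=False)
print(is_lora_key("layers.0.mlp.w1.lora_a.weight", names))  # True
print(is_lora_key("layers.0.attn.k_proj.lora_a.weight", names))  # False
```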
- Returns:
None
- Raises:
AssertionError – if base_missing contains any base model keys.
AssertionError – if base_unexpected is nonempty.
AssertionError – if lora_missing contains any LoRA keys.
AssertionError – if lora_unexpected is nonempty.
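The four conditions listed under Raises can be sketched as follows. check_lora_load is a hypothetical helper that takes an already-resolved list of LoRA module names instead of the three configuration parameters, and it assumes LoRA keys contain "<module>.lora"; it illustrates the documented behavior and is not torchtune's source.

```python
from typing import List, Optional


def check_lora_load(
    lora_module_names: List[str],
    base_missing: Optional[List[str]] = None,
    base_unexpected: Optional[List[str]] = None,
    lora_missing: Optional[List[str]] = None,
    lora_unexpected: Optional[List[str]] = None,
) -> None:
    """Hypothetical sketch of the four documented AssertionError checks."""

    def is_lora_key(key: str) -> bool:
        # Assumed naming: LoRA params live under "<module>.lora...".
        return any(f"{name}.lora" in key for name in lora_module_names)

    # The base checkpoint may omit LoRA keys, but never base model keys.
    for key in base_missing or []:
        if not is_lora_key(key):
            raise AssertionError(f"Missing base model key {key}")
    if base_unexpected:
        raise AssertionError(f"Unexpected base keys: {base_unexpected}")

    # The adapter checkpoint may omit base keys, but never LoRA keys.
    for key in lora_missing or []:
        if is_lora_key(key):
            raise AssertionError(f"Missing LoRA key {key}")
    if lora_unexpected:
        raise AssertionError(f"Unexpected LoRA keys: {lora_unexpected}")
```

The asymmetry is the point of the design: each checkpoint may be missing exactly the keys the other one provides, while unexpected keys are never tolerated.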