validate_state_dict_for_lora
- torchtune.modules.peft.validate_state_dict_for_lora(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool, apply_lora_to_output: bool, full_model_state_dict_keys: List[str], lora_state_dict_keys: Optional[List[str]] = None, base_model_state_dict_keys: Optional[List[str]] = None) -> None
Validate that the state dict keys for a LoRA model are as expected.
If lora_state_dict_keys are passed, this function will confirm that they exactly match the LoRA param names from the full model (as determined by lora_attn_modules, apply_lora_to_mlp, and apply_lora_to_output).
If base_model_state_dict_keys are passed, this function will confirm that they are exactly the complement of the LoRA param names from the full model.
If both lora_state_dict_keys and base_model_state_dict_keys are passed, this function will confirm that the full model’s params are exactly their disjoint union.
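How the LoRA config selects which full-model keys count as LoRA params can be sketched in plain Python. This is a minimal sketch under stated assumptions, not torchtune's implementation: the helpers `lora_module_names` and `split_keys` are hypothetical, the MLP layer names `w1`/`w2`/`w3` and the output layer name `output` are assumptions, and a key is assumed to be a LoRA param when it contains `<module>.lora` (as in `q_proj.lora_a.weight`).

```python
from typing import Iterable, List, Set, Tuple

# Assumed layer names (hypothetical for this sketch): MLP linears are
# "w1", "w2", "w3" and the final output projection is "output".
MLP_MODULES = ["w1", "w2", "w3"]
OUTPUT_MODULE = "output"


def lora_module_names(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> List[str]:
    """Return the names of every linear layer that receives LoRA adapters."""
    names = list(lora_attn_modules)
    if apply_lora_to_mlp:
        names.extend(MLP_MODULES)
    if apply_lora_to_output:
        names.append(OUTPUT_MODULE)
    return names


def split_keys(
    full_keys: Iterable[str], module_names: List[str]
) -> Tuple[Set[str], Set[str]]:
    """Partition full-model keys into (LoRA keys, non-LoRA keys).

    Assumption: a key belongs to LoRA if it contains "<module>.lora",
    e.g. "layers.0.attn.q_proj.lora_a.weight".
    """
    lora_keys: Set[str] = set()
    base_keys: Set[str] = set()
    for key in full_keys:
        if any(f"{name}.lora" in key for name in module_names):
            lora_keys.add(key)
        else:
            base_keys.add(key)
    return lora_keys, base_keys


# Example: LoRA on q_proj and v_proj only.
full_keys = [
    "layers.0.attn.q_proj.weight",
    "layers.0.attn.q_proj.lora_a.weight",
    "layers.0.attn.q_proj.lora_b.weight",
    "layers.0.attn.v_proj.weight",
    "layers.0.attn.v_proj.lora_a.weight",
    "layers.0.attn.v_proj.lora_b.weight",
    "layers.0.mlp.w1.weight",
    "output.weight",
]
names = lora_module_names(["q_proj", "v_proj"], apply_lora_to_mlp=False, apply_lora_to_output=False)
lora_keys, base_keys = split_keys(full_keys, names)

# Every key lands in exactly one set: the full keys are their disjoint union.
assert lora_keys.isdisjoint(base_keys)
assert lora_keys | base_keys == set(full_keys)
```

Under these assumptions, a valid LoRA state dict holds exactly `lora_keys` and a valid base model state dict holds exactly `base_keys`.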
- Parameters:
lora_attn_modules (List[LORA_ATTN_MODULES]) – list of which linear layers LoRA should be applied to in each self-attention block. Options are {"q_proj", "k_proj", "v_proj", "output_proj"}.
apply_lora_to_mlp (bool) – whether LoRA is applied to each linear layer in the MLP.
apply_lora_to_output (bool) – whether LoRA is applied to the final output projection.
full_model_state_dict_keys (List[str]) – List of keys in the full model state dict.
lora_state_dict_keys (Optional[List[str]]) – List of keys in the LoRA state dict. If None, LoRA state dict keys will not be validated.
base_model_state_dict_keys (Optional[List[str]]) – List of keys in the base model state dict. If None, base model keys will not be validated.
- Returns:
None
- Raises:
AssertionError – If base model state dict is missing any non-LoRA params from the full model.
AssertionError – If LoRA state dict is missing any LoRA params from the full model.
AssertionError – If base model state dict has any LoRA params.
AssertionError – If LoRA state dict has any non-LoRA params.
AssertionError – If base model and LoRA state dicts have overlapping keys.
AssertionError – If full model state dict is missing keys from either base model or LoRA state dict.
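The checks listed under Raises can be sketched as a standalone function over key sets. This is an illustrative re-implementation, not the library code: `validate_lora_keys` and its `expected_lora_keys` argument (the full-model keys already identified as LoRA params from the config) are hypothetical, and the error messages are written for this example only.

```python
from typing import Iterable, Optional, Set


def validate_lora_keys(
    full_keys: Iterable[str],
    expected_lora_keys: Set[str],
    lora_keys: Optional[Iterable[str]] = None,
    base_keys: Optional[Iterable[str]] = None,
) -> None:
    """Raise AssertionError if the LoRA/base key split is inconsistent.

    expected_lora_keys: the full-model keys that are LoRA params,
    derived from the LoRA config beforehand. Passing None for
    lora_keys or base_keys skips the checks that involve it.
    """
    full = set(full_keys)
    lora = None if lora_keys is None else set(lora_keys)
    base = None if base_keys is None else set(base_keys)

    # Per-key checks against the full model (sorted for deterministic errors).
    for key in sorted(full):
        if key in expected_lora_keys:
            if lora is not None and key not in lora:
                raise AssertionError(f"LoRA key {key!r} missing from LoRA state dict")
            if base is not None and key in base:
                raise AssertionError(f"LoRA key {key!r} found in base model state dict")
        else:
            if base is not None and key not in base:
                raise AssertionError(f"Non-LoRA key {key!r} missing from base model state dict")
            if lora is not None and key in lora:
                raise AssertionError(f"Non-LoRA key {key!r} found in LoRA state dict")

    # With both key lists, the full model must be exactly their disjoint union.
    if lora is not None and base is not None:
        if lora & base:
            raise AssertionError("Base model and LoRA state dicts have overlapping keys")
        if lora | base != full:
            raise AssertionError("Full model keys are not the union of base and LoRA keys")


full_keys = ["attn.q_proj.weight", "attn.q_proj.lora_a.weight", "attn.q_proj.lora_b.weight"]
expected = {"attn.q_proj.lora_a.weight", "attn.q_proj.lora_b.weight"}

# Consistent split: returns None without raising.
validate_lora_keys(full_keys, expected, lora_keys=expected, base_keys=["attn.q_proj.weight"])

# A base key that is absent from the full model breaks the disjoint union.
try:
    validate_lora_keys(
        full_keys, expected, lora_keys=expected,
        base_keys=["attn.q_proj.weight", "extra.weight"],
    )
except AssertionError as err:
    print(err)
```

Leaving both key lists as None makes every check a no-op, consistent with the parameter descriptions. Note that the per-key loop only inspects keys present in the full model, so stray extra keys are caught only when both key lists are supplied.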