
validate_state_dict_for_lora

torchtune.modules.peft.validate_state_dict_for_lora(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool, apply_lora_to_output: bool, full_model_state_dict_keys: List[str], lora_state_dict_keys: Optional[List[str]] = None, base_model_state_dict_keys: Optional[List[str]] = None) -> None

Validate that the state dict keys for a LoRA model are as expected.

  1. If lora_state_dict_keys are passed, this function will confirm that they exactly match the LoRA param names from the full model (as determined by lora_attn_modules, apply_lora_to_mlp, and apply_lora_to_output).

  2. If base_model_state_dict_keys are passed, this function will confirm that they are exactly the complement of the LoRA param names from the full model.

  3. If both lora_state_dict_keys and base_model_state_dict_keys are passed, this function will confirm that the full model’s params are exactly their disjoint union.
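The three checks above amount to set comparisons. The sketch below is a hypothetical, pure-Python illustration, not torchtune's implementation: the name check_lora_partition and its lora_param_names argument are assumptions, and it takes the full model's LoRA param names as already known instead of deriving them from lora_attn_modules and the two flags. The example keys and the ".lora_" naming convention are also assumed for illustration.

```python
from typing import Iterable, Optional, Set


def check_lora_partition(
    full_keys: Iterable[str],
    lora_param_names: Set[str],
    lora_keys: Optional[Iterable[str]] = None,
    base_keys: Optional[Iterable[str]] = None,
) -> None:
    """Hypothetical set-based version of checks 1-3 (not torchtune's code)."""
    full = set(full_keys)
    # The LoRA params of the full model; everything else counts as a base param.
    expected_lora = full & set(lora_param_names)
    expected_base = full - expected_lora

    if lora_keys is not None and set(lora_keys) != expected_lora:
        # Check 1: LoRA keys must be exactly the full model's LoRA param names.
        raise AssertionError("LoRA keys do not match the full model's LoRA params")
    if base_keys is not None and set(base_keys) != expected_base:
        # Check 2: base keys must be exactly the complement of the LoRA params.
        raise AssertionError("base keys are not the complement of the LoRA params")
    if lora_keys is not None and base_keys is not None:
        lora, base = set(lora_keys), set(base_keys)
        # Check 3: disjoint union equal to the full model. In this set-based
        # sketch it already follows from checks 1 and 2; kept to mirror item 3.
        if (lora & base) or (lora | base) != full:
            raise AssertionError("LoRA and base keys are not a disjoint cover of the model")


full = [
    "tok_embeddings.weight",
    "layers.0.attn.q_proj.weight",
    "layers.0.attn.q_proj.lora_a.weight",
    "layers.0.attn.q_proj.lora_b.weight",
]
# Assumption for this example: LoRA params contain ".lora_" in their names.
lora_names = {k for k in full if ".lora_" in k}
check_lora_partition(
    full,
    lora_names,
    lora_keys=sorted(lora_names),
    base_keys=["tok_embeddings.weight", "layers.0.attn.q_proj.weight"],
)  # passes: the two lists split the full model's keys exactly
```

Comparing whole sets rather than looping key by key keeps each numbered check to a single expression; passing None for either key list simply skips the checks that need it.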

Parameters:
  • lora_attn_modules (List[LORA_ATTN_MODULES]) – list of which linear layers LoRA should be applied to in each self-attention block. Options are {"q_proj", "k_proj", "v_proj", "output_proj"}.

  • apply_lora_to_mlp (bool) – whether LoRA is applied to each linear layer in the MLP.

  • apply_lora_to_output (bool) – whether LoRA is applied to the final output projection.

  • full_model_state_dict_keys (List[str]) – List of keys in the full model state dict.

  • lora_state_dict_keys (Optional[List[str]]) – List of keys in the LoRA state dict. If None, LoRA state dict keys will not be validated.

  • base_model_state_dict_keys (Optional[List[str]]) – List of keys in the base model state dict. If None, base model keys will not be validated.
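The first three parameters together decide which full-model keys count as LoRA params. The following is a hypothetical sketch of that mapping, not torchtune's API: the helper names lora_module_names and is_lora_key are invented here, and the module names "w1", "w2", "w3" (MLP) and "output" (final projection), as well as the "<module>.lora_*" key pattern, are assumptions made for illustration.

```python
from typing import List


def lora_module_names(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> List[str]:
    """Hypothetical helper: collect the names of modules that receive LoRA."""
    names = list(lora_attn_modules)
    if apply_lora_to_mlp:
        # Assumption: the MLP linear layers are named w1, w2 and w3.
        names += ["w1", "w2", "w3"]
    if apply_lora_to_output:
        # Assumption: the final output projection is named "output".
        names.append("output")
    return names


def is_lora_key(key: str, module_names: List[str]) -> bool:
    """Assumption: LoRA params are stored as "...<module>.lora_*...".

    A key counts as LoRA if some path component is a targeted module name
    and the component right after it starts with "lora".
    """
    parts = key.split(".")
    return any(
        parts[i] in module_names and parts[i + 1].startswith("lora")
        for i in range(len(parts) - 1)
    )
```

Under these assumptions, with lora_attn_modules=["q_proj", "v_proj"] and apply_lora_to_output=True, a key such as layers.0.attn.q_proj.lora_a.weight is classified as a LoRA param, while layers.0.attn.q_proj.weight stays with the base model. Matching whole path components rather than substrings keeps "output" from accidentally matching "output_proj".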

Returns:
  None

Raises:
  • AssertionError – If base model state dict is missing any non-LoRA params from the full model.

  • AssertionError – If LoRA state dict is missing any LoRA params from the full model.

  • AssertionError – If base model state dict has any LoRA params.

  • AssertionError – If LoRA state dict has any non-LoRA params.

  • AssertionError – If base model and LoRA state dicts have overlapping keys.

  • AssertionError – If full model state dict is missing keys from either base model or LoRA state dict.
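Because an AssertionError ends the call at the first failed condition, it can help during debugging to see every condition above that a pair of key lists violates. The sketch below is a hypothetical diagnostic, not part of torchtune: lora_key_violations, its lora_param_names argument, and the message strings are assumptions, and it takes the full model's LoRA param names as already known. Each check mirrors one Raises entry, in the same order.

```python
from typing import List, Optional, Set


def lora_key_violations(
    full_keys: List[str],
    lora_param_names: Set[str],
    lora_keys: Optional[List[str]] = None,
    base_keys: Optional[List[str]] = None,
) -> List[str]:
    """Hypothetical diagnostic: list every Raises condition the keys violate."""
    full = set(full_keys)
    expected_lora = full & lora_param_names
    expected_base = full - expected_lora
    lora = set(lora_keys) if lora_keys is not None else None
    base = set(base_keys) if base_keys is not None else None
    problems: List[str] = []

    if base is not None and expected_base - base:
        problems.append("base model state dict is missing non-LoRA params")
    if lora is not None and expected_lora - lora:
        problems.append("LoRA state dict is missing LoRA params")
    if base is not None and base & expected_lora:
        problems.append("base model state dict has LoRA params")
    if lora is not None and lora & expected_base:
        problems.append("LoRA state dict has non-LoRA params")
    if base is not None and lora is not None:
        if base & lora:
            problems.append("base model and LoRA state dicts overlap")
        if (base | lora) - full:
            # Keys listed in either state dict that the full model does not have.
            problems.append("full model state dict is missing keys from base or LoRA state dict")
    return problems
```

An empty list means that none of the listed conditions is violated for the key lists that were supplied.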
