
lora_gemma

torchtune.models.gemma.lora_gemma(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool = False, *, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, num_kv_heads: int, embed_dim: int, intermediate_dim: int, max_seq_len: int, attn_dropout: float = 0.0, norm_eps: float = 1e-06, rope_base: int = 10000, norm_embeddings: bool = True, lora_rank: int, lora_alpha: float, lora_dropout: float = 0.0, quantize_base: bool = False) → GemmaTransformerDecoder

Return a version of Gemma with LoRA applied based on the passed-in configuration. Note: LoRA is not supported on the final output projection because its weights are tied to the token embeddings.
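
As a rough illustration of what "LoRA applied" means for a single linear projection, the sketch below uses the standard LoRA formulation: the frozen weight W is augmented with a trainable low-rank update B·A, scaled by lora_alpha / lora_rank. It is plain Python on nested lists, omits lora_dropout, and the helpers matvec and lora_forward are hypothetical; this is not torchtune's own implementation.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(
    x: Vector, w: Matrix, a: Matrix, b: Matrix, lora_alpha: float, lora_rank: int
) -> Vector:
    """Compute W x + (lora_alpha / lora_rank) * B (A x).

    w: frozen base weight, shape (out_dim, in_dim)
    a: trainable down-projection, shape (lora_rank, in_dim)
    b: trainable up-projection, shape (out_dim, lora_rank)
    Dropout on the LoRA path is omitted in this sketch.
    """
    scaling = lora_alpha / lora_rank
    base = matvec(w, x)
    update = matvec(b, matvec(a, x))
    return [y + scaling * u for y, u in zip(base, update)]


# Tiny example: in_dim = 3, out_dim = 2, lora_rank = 1.
x = [1.0, 2.0, 3.0]
w = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
a = [[1.0, 1.0, 1.0]]      # A x = [6.0]
b_zero = [[0.0], [0.0]]    # zero-initialized B: no change to the base output
b = [[0.5], [-0.5]]

print(lora_forward(x, w, a, b_zero, lora_alpha=2.0, lora_rank=1))  # [1.0, 2.0]
print(lora_forward(x, w, a, b, lora_alpha=2.0, lora_rank=1))       # [7.0, -4.0]
```

When B is zero (as in the original LoRA paper's initialization), the adapted layer reproduces the base layer exactly; only A and B, lora_rank · (in_dim + out_dim) values per adapted projection, are trained.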

Parameters:
  • lora_attn_modules (List[LORA_ATTN_MODULES]) – list of which linear layers LoRA should be applied to in each self-attention block. Options are {"q_proj", "k_proj", "v_proj", "output_proj"}.

  • apply_lora_to_mlp (bool) – whether to apply LoRA to the MLP in each transformer layer. Default: False

  • vocab_size (int) – number of tokens in vocabulary.

  • num_layers (int) – number of layers in the transformer decoder.

  • num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value.

  • head_dim (int) – dimension of each attention head.

  • num_kv_heads (int) – number of key and value heads.

  • embed_dim (int) – embedding dimension for self-attention

  • intermediate_dim (int) – intermediate dimension for MLP

  • max_seq_len (int) – maximum sequence length the model will be run with.

  • attn_dropout (float) – dropout value passed onto scaled_dot_product_attention. Default: 0.0

  • norm_eps (float) – epsilon in RMS norms. Default: 1e-6

  • rope_base (int) – base for the rotary positional embeddings. Default: 10_000

  • norm_embeddings (bool) – whether to apply layer norm before the self-attention and mlp layers. Default: True

  • lora_rank (int) – rank of each low-rank approximation

  • lora_alpha (float) – scaling factor for the low-rank approximation

  • lora_dropout (float) – LoRA dropout probability. Default: 0.0

  • quantize_base (bool) – whether to quantize base model weights. Only applied to base weights within linear layers that LoRA is applied to. The final output linear projection is not currently supported for quantization. Default: False

Returns:

Instantiation of a Gemma model with LoRA applied to a subset of the attention projections (and, optionally, the MLP) in each layer.

Return type:

GemmaTransformerDecoder
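
To see how lora_attn_modules, apply_lora_to_mlp, num_heads, num_kv_heads, head_dim and lora_rank together determine the adapter size, the sketch below counts trainable LoRA parameters. It assumes each adapted linear layer gets an A matrix (lora_rank x in_dim) and a B matrix (out_dim x lora_rank) with no bias, that the MLP is gated with three projections (two embed_dim -> intermediate_dim, one back), and that num_heads must be divisible by num_kv_heads. The function lora_param_count is hypothetical, and the config values in the example are illustrative assumptions, not taken from this page.

```python
from typing import Dict, List, Tuple

ATTN_OPTIONS = {"q_proj", "k_proj", "v_proj", "output_proj"}


def lora_param_count(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool,
    *,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    embed_dim: int,
    intermediate_dim: int,
    lora_rank: int,
) -> int:
    """Count trainable LoRA parameters (A and B matrices only, no bias)."""
    unknown = set(lora_attn_modules) - ATTN_OPTIONS
    if unknown:
        raise ValueError(f"Unsupported LoRA modules: {sorted(unknown)}")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")

    # (in_dim, out_dim) of each attention projection; keys and values
    # use num_kv_heads, so they shrink under grouped/multi-query attention.
    attn_shapes: Dict[str, Tuple[int, int]] = {
        "q_proj": (embed_dim, num_heads * head_dim),
        "k_proj": (embed_dim, num_kv_heads * head_dim),
        "v_proj": (embed_dim, num_kv_heads * head_dim),
        "output_proj": (num_heads * head_dim, embed_dim),
    }
    # An adapter on an in_dim -> out_dim linear adds lora_rank * (in_dim + out_dim).
    per_layer = sum(
        lora_rank * (attn_shapes[m][0] + attn_shapes[m][1])
        for m in set(lora_attn_modules)
    )
    if apply_lora_to_mlp:
        # Assumed gated MLP: three projections between embed_dim and intermediate_dim.
        per_layer += 3 * lora_rank * (embed_dim + intermediate_dim)
    return num_layers * per_layer


# Illustrative (assumed) config with a single key/value head.
cfg = dict(
    num_layers=18, num_heads=8, head_dim=256, num_kv_heads=1,
    embed_dim=2048, intermediate_dim=16384, lora_rank=8,
)
print(lora_param_count(["q_proj", "v_proj"], False, **cfg))  # 921600
print(lora_param_count(["q_proj", "v_proj"], True, **cfg))   # 8884224
```

With num_kv_heads = 1, the v_proj adapter is much smaller than the q_proj adapter because its output is only head_dim wide, while enabling apply_lora_to_mlp dominates the total because intermediate_dim is large.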
