.. _qlora_finetune_label:

=============================
Fine-Tuning Llama2 with QLoRA
=============================

In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be used within torchtune to finetune a Llama2-7b model in less than 10 GB of memory.
It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in torchtune<lora_finetune_label>`.


.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

      * How QLoRA saves memory over LoRA finetuning
      * An overview of QLoRA in torchtune
      * How to run a QLoRA finetune in torchtune

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

      * Be familiar with :ref:`torchtune<overview_label>`
      * Make sure to :ref:`install torchtune<install_label>`
      * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`
      * Be familiar with :ref:`LoRA in torchtune<lora_finetune_label>`

What is QLoRA?
---------------

QLoRA builds on top of LoRA to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are parameters that are part of
the original model. In vanilla LoRA-style training, both of these are held in the same precision (typically fp32 or `bf16 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#bfloat16_floating-point_format>`_), and
therefore activations and intermediate gradients are computed in fp32/bf16.

QLoRA further quantizes the base model parameters into a bespoke 4-bit NormalFloat (`NF4 <https://www.youtube.com/watch?v=TPcXVJ1VSRI&t=563s>`_) data type, resulting in 4-8x less parameter memory usage while
largely retaining model accuracy. As a result, the vast majority of parameters take up only 4 bits (as opposed to the 16 or 32 bits of the bf16/fp32 dtypes). This
quantization is done through the method highlighted in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`_. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve
accuracy.

The QLoRA authors introduce two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. torchtune uses
the `NF4Tensor <https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py#L153>`_ abstraction from the `torchao library <https://github.com/pytorch-labs/ao>`_ to build QLoRA components as specified in the paper.
torchao is a PyTorch-native library that allows you to quantize and prune your models.
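
To make this concrete, here's a minimal sketch (not taken from the paper or the torchtune recipes) of quantizing a single weight tensor with torchao's ``to_nf4`` utility, the same API torchtune's QLoRA components use under the hood:

.. code-block:: python

  import torch
  from torchao.dtypes.nf4tensor import to_nf4

  # A stand-in for a frozen base model weight, held in bf16.
  weight = torch.randn(512, 512, dtype=torch.bfloat16)

  # Quantize the weight to the 4-bit NormalFloat (NF4) representation.
  nf4_weight = to_nf4(weight)
  print(type(nf4_weight))  # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>

Only the frozen base weights are stored this way; adapter parameters, activations, and gradients remain in bf16/fp32 as described above.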


.. _qlora_core_highlevel:

Using QLoRA to save memory
----------------------------------------

In this section, we'll give an overview of how to apply QLoRA to a :class:`~torchtune.modules.peft.LoRALinear` layer in torchtune. For a deep dive into the details of QLoRA in torchtune and the underlying abstractions,
please see the :ref:`QLoRA in torchtune deepdive <qlora_deepdive_label>` section of this tutorial.

A core idea of QLoRA is the distinction between compute and storage datatypes (dtypes). Specifically, QLoRA stores base model parameters in 4-bit precision (i.e. the storage dtype), and runs
computation in the original, higher precision (the compute dtype), generally either fp32 or bf16. As a first step, QLoRA needs to quantize these base model parameters to 4-bit precision
and store them.

To quantize a :class:`~torchtune.modules.peft.LoRALinear` layer in the QLoRA style, simply pass ``quantize_base=True`` when constructing the :class:`~torchtune.modules.peft.LoRALinear` layer. With this flag set,
the base model weights are quantized and backed by the ``NF4Tensor`` dtype, and forward passes are handled automatically: the ``NF4`` base weight is de-quantized to the compute precision,
the activation is computed, and only the 4-bit parameter is stored for gradient computation in the backward pass, avoiding the extra memory usage
that storing the higher-precision compute dtype would incur.

Here's an example of creating a quantized ``LoRALinear`` layer alongside an unquantized ``LoRALinear`` layer for comparison. As the printed memory stats show, the quantized layer consumes
roughly 6x less memory than its unquantized counterpart.

.. code-block:: python

  import torch
  from torchtune.modules.peft import LoRALinear

  torch.set_default_device("cuda")
  qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True)
  print(torch.cuda.memory_allocated())  # 177,152 bytes
  del qlora_linear
  torch.cuda.empty_cache()
  lora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=False)
  print(torch.cuda.memory_allocated()) # 1,081,344 bytes
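
Continuing from the snippet above (with CUDA still set as the default device), we can also run a forward pass through the quantized layer. This is a small sketch to illustrate the compute vs. storage dtype distinction: the ``NF4`` base weight is de-quantized on the fly, so the output comes back in the regular compute precision rather than in 4-bit precision.

.. code-block:: python

  qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True)
  x = torch.randn(2, 512)  # activations in the compute precision (fp32 here)
  out = qlora_linear(x)
  # We expect the output in the compute precision, e.g. torch.Size([2, 512]) torch.float32
  print(out.shape, out.dtype)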


Using QLoRA in torchtune
----------------------------

We'll now cover how you can initialize a QLoRA-enabled Llama2-7b model as well as some details around
checkpointing with QLoRA.

With torchtune, you can use a simple builder similar to the LoRA builder (:func:`lora_llama2_7b <torchtune.models.llama2.lora_llama2_7b>`) to apply QLoRA to Llama2 models. Here's a simple example of
initializing a Llama2-7b model with QLoRA enabled:

.. code-block:: python

  from torchtune.models.llama2 import qlora_llama2_7b

  qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

Under the hood, this will apply LoRA to the ``q_proj`` and ``v_proj`` matrices in all attention layers, and further quantize the base parameters
in these matrices to the ``NF4`` dtype. Note that quantization of base model parameters is only applied to layers that are configured to have
LoRA adapters added. For example, in this case, ``k_proj`` and ``output_proj`` in the attention layers don't have LoRA applied to them, so their
base model parameters are not quantized. We can see this by printing the base model parameter dtypes for a particular attention layer:

.. code-block:: python

  attn = qlora_model.layers[0].attn
  print(type(attn.q_proj.weight))  # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
  print(type(attn.k_proj.weight))  # <class 'torch.nn.parameter.Parameter'>
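
The LoRA adapter weights themselves are left as regular parameters in the original precision; only the frozen base weights of the adapted layers are backed by ``NF4Tensor``. Assuming the adapter submodules are named ``lora_a`` and ``lora_b`` (as in the minimal implementation in the deep dive below), this can be checked directly:

.. code-block:: python

  # Adapter weights remain ordinary parameters in the original precision.
  print(type(attn.q_proj.lora_a.weight))  # <class 'torch.nn.parameter.Parameter'>
  print(attn.q_proj.lora_a.weight.dtype)  # e.g. torch.float32, depending on the model's dtype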


Next, there are a couple of details essential to checkpointing (i.e. ``state_dict``) of QLoRA-enabled models.
To integrate well with torchtune's :ref:`checkpointing <checkpointing_label>`, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within
torchtune and beyond (e.g. post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
in a typical LoRA training flow.

To achieve this, when using torchtune's :func:`lora_llama2_7b <torchtune.models.llama2.lora_llama2_7b>` builder, we automatically register a hook,
:func:`reparametrize_as_dtype_state_dict_post_hook <torchtune.modules.common_utils.reparametrize_as_dtype_state_dict_post_hook>`,
that runs after calling ``.state_dict()`` on the top-level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading the
converted tensors to the CPU. Without this offloading, we would have to hold an entire bf16/fp32 copy of the ``state_dict``
on GPU, spiking peak memory.
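
As a quick sanity check, the sketch below (assuming the ``state_dict`` key follows the module path used earlier) confirms that the checkpoint no longer contains ``NF4Tensors`` after the hook has run:

.. code-block:: python

  sd = qlora_model.state_dict()
  # The post hook has converted the NF4 base weight back to a regular tensor and offloaded it to CPU.
  print(type(sd["layers.0.attn.q_proj.weight"]))   # <class 'torch.Tensor'>
  print(sd["layers.0.attn.q_proj.weight"].device)  # cpu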


.. _qlora_compile_label:

Putting it all together: QLoRA finetune
-----------------------------------------

Putting it all together, we can now finetune a model using torchtune's :ref:`LoRA single-device finetuning <lora_finetune_recipe_label>` recipe,
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.

Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.
You can then run the following command to perform a QLoRA finetune of Llama2-7B on a single GPU.

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device

.. note::
    Make sure to correctly point to the location of your Llama2 weights and tokenizer. This can be done
    either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`
    or by directly modifying the :code:`7B_qlora_single_device.yaml` file. See our ":ref:`config_tutorial_label`" recipe
    for more details on how you can easily clone and modify torchtune configs.

By default, this run should log peak memory stats at model initialization time and every 100
iterations during training. Let's understand the memory savings enabled by QLoRA on top of LoRA training. LoRA training
can be run as follows:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_lora_single_device

You should see the memory usage printed out during model initialization and training. An example log for LoRA model initialization is as follows:

.. code-block:: python

  Memory Stats after model init:
  GPU peak memory allocation: 13.96 GB
  GPU peak memory reserved: 13.98 GB
  GPU peak memory active: 13.96 GB

The following table compares QLoRA's peak memory reserved during model initialization and training against vanilla LoRA's.
We can see that QLoRA reduces peak memory by about 35% during model initialization, and by about 40% during training:

==================  ==================================  ================================
Finetuning method    Peak memory reserved, model init    Peak memory reserved, training
==================  ==================================  ================================
LoRA                   13.98 GB                            15.57 GB
QLoRA                  9.13 GB                             9.29 GB
==================  ==================================  ================================

From the logs, one can see that the out-of-the-box training performance is fairly slow, at only around 1 iteration per
second:

.. code-block:: python

  1|149|Loss: 0.9157477021217346:   1%|          | 149/25880 [02:08<6:14:19,  1.15it/s]

To speed things up, we can leverage `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_ to compile our model and run the compiled result. To use
``torch.compile`` with QLoRA training, a nightly build of PyTorch must be used. To update PyTorch to the latest nightly,
please see `the installation instructions <https://pytorch.org/get-started/locally/>`_. Once updated,
you can set the compile flag to ``True`` via a config override:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device compile=True

From the logs, we can see roughly a 3x speed-up, from ~1.15 it/s to ~3.95 it/s (after a few hundred iterations, once the training has stabilized):

.. code-block:: python

  1|228|Loss: 0.8158286809921265:   1%|          | 228/25880 [11:59<1:48:16,  3.95it/s]

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.

.. image:: /_static/img/qlora_exp.png

.. note::
    The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.training.metric_logging.WandBLogger`
    to generate similar loss curves, but you will need to install W&B and set up an account separately. For more details on
    using W&B in torchtune, see our ":ref:`wandb_logging`" recipe.

As an exercise, you can also try running some evaluation tasks or manually inspecting generations
output by your saved checkpoints (which can be found in :code:`output_dir`).

In the final section, we'll take a deep dive into how a QLoRA component can be built from a LoRA component.

.. _qlora_deepdive_label:

Deep-dive: Building QLoRA from LoRA
-----------------------------------------

This deep-dive section resumes from the :ref:`Using QLoRA to save memory<qlora_core_highlevel>` portion of this tutorial and dives into how quantization is done with ``NF4Tensor`` and handled appropriately in the forward pass.

We'll begin with a minimal vanilla LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_label>` and augmented to support quantization (the highlighted lines show the
quantization-related additions):

.. code-block:: python
  :emphasize-lines: 4, 14, 17, 21, 22, 23, 42, 43, 44

  import torch
  from torch import nn
  import torch.nn.functional as F
  from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

  class LoRALinear(nn.Module):
    def __init__(
      self,
      in_dim: int,
      out_dim: int,
      rank: int,
      alpha: float,
      dropout: float,
      quantize_base: bool
    ):
      super().__init__()
      self._quantize_base = quantize_base
      # These are the weights from the original pretrained model
      self.linear = nn.Linear(in_dim, out_dim, bias=False)
      self.linear_weight = self.linear.weight
      # Use torchao's to_nf4 API to quantize the base weight if needed.
      if quantize_base:
        self.linear_weight = to_nf4(self.linear_weight)
      # These are the new LoRA params. In general rank << in_dim, out_dim
      self.lora_a = nn.Linear(in_dim, rank, bias=False)
      self.lora_b = nn.Linear(rank, out_dim, bias=False)

      # Rank and alpha are commonly-tuned hyperparameters
      self.rank = rank
      self.alpha = alpha

      # Most implementations also include some dropout
      self.dropout = nn.Dropout(p=dropout)

      # The original params are frozen, and only LoRA params are trainable.
      self.linear.weight.requires_grad = False
      self.lora_a.weight.requires_grad = True
      self.lora_b.weight.requires_grad = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
      # frozen_out would be the output of the original model
      if self._quantize_base:
        # Call into torchao's linear_nf4 to run the linear forward pass w/ the quantized weight.
        frozen_out = linear_nf4(x, self.linear_weight)
      else:
        frozen_out = F.linear(x, self.linear_weight)

      # lora_a projects inputs down to the much smaller self.rank,
      # then lora_b projects back up to the output dimension
      lora_out = self.lora_b(self.lora_a(self.dropout(x)))

      # Finally, scale by the alpha parameter (normalized by rank)
      # and add to the original model's outputs
      return frozen_out + (self.alpha / self.rank) * lora_out

As mentioned above, torchtune takes a dependency on torchao for some of the core components required for QLoRA. This includes the
``NF4Tensor``, as well as helpful utilities including ``to_nf4`` and ``linear_nf4``.

The key changes on top of the LoRA layer are the usage of the ``to_nf4`` and ``linear_nf4`` APIs.

``to_nf4`` accepts an unquantized (bf16 or fp32) tensor and produces an ``NF4`` representation of the weight. See the `implementation <https://github.com/pytorch-labs/ao/blob/c40358072f99b50cd7e58ec11e0e8d90440e3e25/torchao/dtypes/nf4tensor.py#L587>`_ of ``to_nf4`` for more details.
``linear_nf4`` handles the forward pass and autograd when running with quantized base model weights. It computes the forward pass as a regular
``F.linear`` with the incoming activation and the de-quantized weight. The quantized weight, rather than its unquantized counterpart, is saved for the backward pass, avoiding the extra
memory usage of storing higher-precision tensors for gradient computation. See `linear_nf4 <https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/nf4tensor.py#L577>`_ for more details.
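
As a quick check of the layer defined above (a minimal sketch, assuming a CUDA device as in the earlier memory example), we can run a forward and backward pass and confirm that gradients flow only to the LoRA parameters, while the quantized base weight stays frozen:

.. code-block:: python

  torch.set_default_device("cuda")
  layer = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=0.1, dropout=0.0, quantize_base=True)

  x = torch.randn(2, 512)
  out = layer(x)
  out.sum().backward()

  # Only the adapter parameters receive gradients; the frozen, quantized base weight does not.
  print(layer.lora_a.weight.grad is not None)  # True
  print(layer.lora_b.weight.grad is not None)  # True
  print(layer.linear.weight.grad is None)      # True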