.. currentmodule:: torcheval

Use Metrics in TorchEval
========================

PyTorch evaluation metrics are one of the core offerings of TorchEval.
For most metrics, we offer both stateful class-based interfaces that only accumulate necessary data until told to compute the metric, and pure functional interfaces.


Class Metrics
--------------

The class metrics keeps track of metric states, which enables them to be able to calculate values through accumulations and
synchronizations across multiple processes. The base class is :obj:`torcheval.metrics.Metric`.

The core APIs of class metrics are ``update()``, ``compute()`` and ``reset()``.

- ``update()``: Update the metric states with input data. This is often used when new data needs to be added for metric computation.
- ``compute()``: Compute the metric values from the metric state, which are updated by previous ``update()`` calls. The compute frequency can be less than the update frequency.
- ``reset()``: Reset the metric state variables to their default value. Usually this is called at the end of every epoch to clean up metric states.

.. note::

    Class metrics keep track of internal states that are updated by input data passed to ``update()`` calls. This means that metric states should be moved to the
    same device as the input data. You can directly pass in device on initialization or use the ``to(device)`` API. The ``.device`` property shows the device of the metric states.

Below is an example of using class metric in a simple training script.

.. code-block:: python

    import torch
    from torcheval.metrics import MulticlassAccuracy

    device = "cuda" if torch.cuda.is_available() else "cpu"
    metric = MulticlassAccuracy(device=device)
    num_epochs, num_batches, batch_size = 4, 8, 10
    num_classes = 3

    # number of batches between metric computations
    compute_frequency = 2

    for epoch in range(num_epochs):
        for batch_idx in range(num_batches):
            input = torch.randint(high=num_classes, size=(batch_size,), device=device)
            target = torch.randint(high=num_classes, size=(batch_size,), device=device)

            # metric.update() updates the metric state with new data
            metric.update(input, target)

            if (batch_idx + 1) % compute_frequency == 0:
                    print(
                        "Epoch {}/{}, Batch {}/{} --- acc: {:.4f}".format(
                            epoch + 1,
                            num_epochs,
                            batch_idx + 1,
                            num_batches,
                            # metric.compute() returns metric value from all seen data
                            metric.compute(),
                        )
                    )

        # metric.reset() reset metric states. It's typically called after the epoch completes.
        metric.reset()

Save and Load Metrics
^^^^^^^^^^^^^^^^^^^^^

Class metrics also implements the stateful protocol, ``.state_dict()`` and ``.load_state_dict()``. Those functions can be used to save and load metrics.

.. code-block:: python

    import torch
    from torcheval.metrics import MulticlassAccuracy

    metric = MulticlassAccuracy()
    input = torch.tensor([0, 2, 1, 3])
    target = torch.tensor([0, 1, 2, 3])
    metric.update(input, target)

    state_dict = metric.state_dict()
    loaded_metric = MulticlassAccuracy()
    loaded_metric.load_state_dict(state_dict)

    # returns torch.tensor(0.5)
    loaded_metric.compute()


Functional Metrics
------------------


Functional metrics are simple python functions that calculate the metric value from input data. They are light-weighted and relatively faster since they don't need to keep and operate on metric states.
The example below shows calculating metric value with the functional version.

.. code-block:: python

    import torch
    from torcheval.metrics.functional import multiclass_accuracy

    input = torch.tensor([0, 2, 1, 3])
    target = torch.tensor([0, 1, 2, 3])
    # returns torch.tensor(0.5)
    multiclass_accuracy(input, target)