
Use Metrics in TorchEval

PyTorch evaluation metrics are one of the core offerings of TorchEval. For most metrics, we offer both a stateful class-based interface, which accumulates only the data it needs until you ask it to compute the metric, and a pure functional interface.

Class Metrics

Class metrics keep track of metric states, which lets them compute values by accumulating data over many update calls and by synchronizing states across multiple processes. The base class is torcheval.metrics.Metric.
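Synchronization across processes works because a metric's state holds raw counts rather than a finished value. The sketch below uses plain Python and hypothetical per-process `(num_correct, num_total)` counts; `merge_states` and `accuracy_from_state` are illustrative helpers, not torcheval's API. It shows why merging states and then computing differs from averaging per-process results.

```python
# Hypothetical (num_correct, num_total) states held by two processes.
# Illustrative numbers only; not produced by torcheval.
process_states = [(3, 4), (1, 6)]


def merge_states(states):
    """Combine accumulated counts from several processes by summing them."""
    total_correct = sum(correct for correct, _ in states)
    total_seen = sum(seen for _, seen in states)
    return total_correct, total_seen


def accuracy_from_state(state):
    """Turn an accumulated (correct, seen) state into an accuracy value."""
    correct, seen = state
    return correct / seen


merged = merge_states(process_states)
global_acc = accuracy_from_state(merged)  # 4 / 10 = 0.4

# Averaging per-process accuracies gives a different answer when
# processes see different numbers of samples: (0.75 + 0.1667) / 2.
naive_acc = sum(accuracy_from_state(s) for s in process_states) / len(process_states)

print(f"merged-state accuracy: {global_acc:.4f}")  # 0.4000
print(f"averaged accuracy:     {naive_acc:.4f}")   # 0.4583
```

Keeping counts in the state is what makes the merged result equal to the accuracy over the combined data.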

The core APIs of class metrics are update(), compute() and reset().

  • update(): Update the metric states with input data. Call it whenever new data should be included in the metric computation.
  • compute(): Compute the metric value from the metric states accumulated by previous update() calls. compute() can be called less often than update().
  • reset(): Reset the metric state variables to their default values. This is usually called at the end of every epoch to clear the metric states.
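To make the update/compute/reset cycle concrete, here is a minimal sketch in plain Python. `ToyAccuracy` is a hypothetical class that only mirrors the pattern described above; it is not torcheval's Metric base class, and its choice to return 0.0 when no data has been seen is an arbitrary assumption of the sketch.

```python
class ToyAccuracy:
    """Hypothetical stateful metric mirroring update/compute/reset.

    Illustration only; not torcheval's Metric base class.
    """

    def __init__(self):
        self.reset()

    def update(self, preds, targets):
        # Accumulate only what compute() needs: counts, not the raw data.
        self.num_correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.num_total += len(targets)
        return self

    def compute(self):
        # Value over all data seen since the last reset().
        # Returning 0.0 for "no data" is a choice made for this toy.
        if self.num_total == 0:
            return 0.0
        return self.num_correct / self.num_total

    def reset(self):
        # Restore the default state values.
        self.num_correct = 0
        self.num_total = 0
        return self


metric = ToyAccuracy()
metric.update([0, 1, 2], [0, 1, 1])  # 2 of 3 correct
metric.update([2, 2], [2, 0])        # 1 of 2 correct
print(metric.compute())              # 3 / 5 = 0.6
metric.reset()
print(metric.compute())              # no data since reset -> 0.0
```

Note that compute() reports the accuracy over both batches together, not just the most recent one.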

Note

Class metrics keep track of internal states that are updated by input data passed to update() calls. This means the metric states should be on the same device as the input data. You can pass device directly at initialization or use the to(device) API. The .device property shows the device of the metric states.

Below is an example of using a class metric in a simple training script.

import torch
from torcheval.metrics import MulticlassAccuracy

device = "cuda" if torch.cuda.is_available() else "cpu"
metric = MulticlassAccuracy(device=device)
num_epochs, num_batches, batch_size = 4, 8, 10
num_classes = 3

# number of batches between metric computations
compute_frequency = 2

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        input = torch.randint(high=num_classes, size=(batch_size,), device=device)
        target = torch.randint(high=num_classes, size=(batch_size,), device=device)

        # metric.update() updates the metric state with new data
        metric.update(input, target)

        if (batch_idx + 1) % compute_frequency == 0:
            print(
                "Epoch {}/{}, Batch {}/{} --- acc: {:.4f}".format(
                    epoch + 1,
                    num_epochs,
                    batch_idx + 1,
                    num_batches,
                    # metric.compute() returns metric value from all seen data
                    metric.compute(),
                )
            )

    # metric.reset() resets the metric states. It's typically called after each epoch completes.
    metric.reset()

Save and Load Metrics

Class metrics also implement the stateful protocol, .state_dict() and .load_state_dict(), which can be used to save and load metrics.

import torch
from torcheval.metrics import MulticlassAccuracy

metric = MulticlassAccuracy()
input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
metric.update(input, target)

state_dict = metric.state_dict()
loaded_metric = MulticlassAccuracy()
loaded_metric.load_state_dict(state_dict)

# returns torch.tensor(0.5)
loaded_metric.compute()
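The round trip above can be sketched without PyTorch to show what "saving a metric" means: the accumulated state is exported, serialized, and loaded into a fresh instance, which can then keep accumulating. `ToyAccuracy` here is hypothetical, and JSON is used only to keep the sketch dependency-free; for a real torcheval metric, whose states are tensors, something like torch.save is the more likely choice.

```python
import json


class ToyAccuracy:
    """Hypothetical metric with state_dict()/load_state_dict() (illustration only)."""

    def __init__(self):
        self.num_correct = 0
        self.num_total = 0

    def update(self, preds, targets):
        self.num_correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.num_total += len(targets)

    def compute(self):
        return self.num_correct / self.num_total if self.num_total else 0.0

    def state_dict(self):
        # Expose the accumulated state as a plain dict.
        return {"num_correct": self.num_correct, "num_total": self.num_total}

    def load_state_dict(self, state):
        # Overwrite this instance's state with the saved values.
        self.num_correct = state["num_correct"]
        self.num_total = state["num_total"]


metric = ToyAccuracy()
metric.update([0, 2, 1, 3], [0, 1, 2, 3])  # 2 of 4 correct

# Serialize to JSON text (e.g. for a checkpoint file), then restore.
saved = json.dumps(metric.state_dict())
restored = ToyAccuracy()
restored.load_state_dict(json.loads(saved))

# The restored metric keeps accumulating where the original left off.
restored.update([1], [1])
print(restored.compute())  # (2 + 1) / (4 + 1) = 0.6
```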

Functional Metrics

Functional metrics are simple Python functions that calculate the metric value from input data. They are lightweight and relatively fast because they don't need to maintain and operate on metric states. The example below calculates the metric value with the functional version.

import torch
from torcheval.metrics.functional import multiclass_accuracy

input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
# returns torch.tensor(0.5)
multiclass_accuracy(input, target)
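The trade-off of the stateless form is that each call sees only the data passed to it. A minimal plain-Python sketch, using a hypothetical `toy_multiclass_accuracy` (not torcheval's implementation) that computes correct predictions divided by total, reproduces the 0.5 above and shows that covering several batches means gathering them into one call.

```python
def toy_multiclass_accuracy(preds, targets):
    """Hypothetical stateless accuracy: correct predictions / total."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    if not targets:
        raise ValueError("accuracy is undefined for empty input")
    correct = sum(int(p == t) for p, t in zip(preds, targets))
    return correct / len(targets)


# Same data as the example above: positions 0 and 3 match.
print(toy_multiclass_accuracy([0, 2, 1, 3], [0, 1, 2, 3]))  # 0.5

# No state is kept between calls, so accuracy over several batches
# requires the caller to concatenate them first.
batches = [([0, 2], [0, 1]), ([1, 3], [2, 3])]
all_preds = [p for preds, _ in batches for p in preds]
all_targets = [t for _, targets in batches for t in targets]
print(toy_multiclass_accuracy(all_preds, all_targets))  # 0.5
```

When data arrives batch by batch and cannot be held in memory, the class interface's accumulated state avoids this concatenation step.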
