ignite.metrics

Metrics provide a way to compute various quantities of interest in an online fashion without having to store the entire output history of a model.

Attach Engine API

As stated above, metrics are computed in an online fashion: the metric instance accumulates internal counters on each iteration, and the metric value is computed once the epoch ends. Internal counters are reset at the start of every epoch. In practice, this is done with three methods: reset(), update() and compute().

Therefore, a user needs to attach the metric instance to the engine so that these three methods are triggered on certain Events. The reset() method is triggered on the EPOCH_STARTED event and is responsible for resetting the metric to its initial state. The update() method is triggered on the ITERATION_COMPLETED event and updates the state of the metric using the output of the current batch. The compute() method is triggered on the EPOCH_COMPLETED event and computes the metric from its accumulated state. The metric value is computed using the output of the engine's process_function:

from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy()
metric.attach(engine, "accuracy")
# ...
state = engine.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")
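To make the event wiring above concrete, here is a plain-Python sketch of what an attached epoch-wise metric goes through during engine.run. ToyEngine and ToyAccuracy are hypothetical stand-ins written for illustration, not ignite classes; the real Engine dispatches these calls through its event system rather than through hard-coded loops.

```python
class ToyAccuracy:
    """Hypothetical metric: keeps two counters instead of every prediction."""

    def attach(self, engine, name):
        engine.attached.append((self, name))

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, output):
        y_pred, y = output  # (predicted labels, true labels) from the process function
        self.correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self.total += len(y)

    def compute(self):
        return self.correct / self.total


class ToyEngine:
    """Hypothetical engine that calls attached metrics in epoch-wise event order."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.attached = []
        self.metrics = {}

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            for metric, _ in self.attached:      # EPOCH_STARTED -> reset()
                metric.reset()
            for batch in data:
                output = self.process_function(self, batch)
                for metric, _ in self.attached:  # ITERATION_COMPLETED -> update()
                    metric.update(output)
            for metric, name in self.attached:   # EPOCH_COMPLETED -> compute()
                self.metrics[name] = metric.compute()
        return self.metrics


def process_function(engine, batch):
    return batch  # toy batches already hold (y_pred, y)


data = [([1, 0, 1], [1, 1, 1]), ([0, 0], [0, 1])]
engine = ToyEngine(process_function)
metric = ToyAccuracy()
metric.attach(engine, "accuracy")
metrics = engine.run(data, max_epochs=2)
print(metrics["accuracy"])  # 0.6 (3 correct out of 5)
print(metric.total)         # 5, not 10: counters were reset when the second epoch started
```

Only the two counters survive between iterations, which is what lets a metric avoid storing the model's full output history.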

If the engine’s output is not in the format (y_pred, y) or {'y_pred': y_pred, 'y': y, ...}, the user can use the output_transform argument to transform it:

from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return {'y_pred': y_pred, 'y_true': y}  # plus any other entries you need

engine = Engine(process_function)

def output_transform(output):
    # `output` variable is returned by above `process_function`
    y_pred = output['y_pred']
    y = output['y_true']
    return y_pred, y  # output format is according to `Accuracy` docs

metric = Accuracy(output_transform=output_transform)
metric.attach(engine, "accuracy")
# ...
state = engine.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")

Warning

Be careful when using lambda functions to set up output_transform for multiple metrics:

# Wrong
# metrics_group = [Accuracy(output_transform=lambda output: output[name]) for name in names]
# Python closures look up `name` late, so every `output_transform` uses the last value of `name`

# A correct way. For example, using functools.partial
from functools import partial

def ot_func(output, name):
    return output[name]

metrics_group = [Accuracy(output_transform=partial(ot_func, name=name)) for name in names]

For more details, see here
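The pitfall in the warning comes from Python's late-binding closures rather than from ignite itself. This stdlib-only snippet, with a hypothetical `output` dictionary standing in for an engine output, shows the difference between the two approaches:

```python
from functools import partial

names = ["a", "b", "c"]
output = {"a": 1, "b": 2, "c": 3}  # hypothetical engine output keyed by name

# Wrong: each lambda looks up `name` when it is called, i.e. after the loop has finished.
late_bound = [lambda out: out[name] for name in names]
print([f(output) for f in late_bound])  # [3, 3, 3]


def ot_func(out, name):
    return out[name]


# Right: partial stores the current value of `name` when each callable is created.
bound = [partial(ot_func, name=name) for name in names]
print([f(output) for f in bound])  # [1, 2, 3]
```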

Note

Most of the implemented metrics support distributed computation: they reduce their internal states across supported devices before computing the metric value. This is helpful for running evaluation on multiple nodes/GPU instances/TPUs with a distributed data sampler. The following code snippet shows how to use metrics in this setting:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# `local_rank`, `model`, `test_dataset`, `batch_size` and `num_workers` come from your distributed setup
device = f"cuda:{local_rank}"
model = torch.nn.parallel.DistributedDataParallel(model.to(device),
                                                  device_ids=[local_rank, ],
                                                  output_device=local_rank)
test_sampler = DistributedSampler(test_dataset)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    sampler=test_sampler,
    num_workers=num_workers,
    pin_memory=True
)

evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy()}, device=device)

Note

Metrics cannot be serialized with the pickle module because their implementation relies on lambda functions. Use the third-party library dill to work around this limitation of pickle.
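The stdlib-only sketch below reproduces the limitation with a hypothetical class that stores a lambda as an attribute, much as a metric may store its output_transform. Depending on the Python version, pickle reports the failure as a PicklingError or an AttributeError, so both are caught. dill is not used here because it is a third-party package.

```python
import pickle


class HoldsLambda:
    """Hypothetical object that keeps a lambda attribute, like a default output_transform."""

    def __init__(self):
        self.output_transform = lambda output: output


try:
    pickle.dumps(HoldsLambda())
    pickled = True
except (pickle.PicklingError, AttributeError) as exc:
    pickled = False
    print(f"{type(exc).__name__}: {exc}")

print(pickled)  # False
```

pickle serializes functions by reference to an importable name, and a lambda defined inside a method has no such name, which is why the dump fails.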

Reset, Update, Compute API

Users can also call the following methods directly on the metric:

  • reset() : resets internal variables and accumulators

  • update() : updates internal variables and accumulators with the provided batch output (y_pred, y)

  • compute() : computes the custom metric and returns the result

This API gives finer-grained control over how a metric is computed. For example:

from ignite.metrics import Precision

# Define the metric
precision = Precision()

# Start accumulation:
for x, y in data:
    y_pred = model(x)
    precision.update((y_pred, y))

# Compute the result
print("Precision: ", precision.compute())

# Reset metric
precision.reset()

# Start new accumulation:
for x, y in data:
    y_pred = model(x)
    precision.update((y_pred, y))

# Compute new result
print("Precision: ", precision.compute())
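The same manual protocol applies to any object that follows it. As a stdlib-only sketch (ToyBinaryPrecision is hypothetical, and it raises ValueError where ignite would raise NotComputableError), a binary precision metric needs only two counters, and reset() makes the second accumulation independent of the first:

```python
class ToyBinaryPrecision:
    """Hypothetical binary precision metric that keeps two counters."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.true_positives = 0
        self.predicted_positives = 0

    def update(self, output):
        y_pred, y = output  # binary labels: 1 = positive, 0 = negative
        for p, t in zip(y_pred, y):
            if p == 1:
                self.predicted_positives += 1
                self.true_positives += int(t == 1)

    def compute(self):
        if self.predicted_positives == 0:
            raise ValueError("Precision needs at least one positive prediction.")
        return self.true_positives / self.predicted_positives


precision = ToyBinaryPrecision()

# First accumulation: 2 true positives out of 3 positive predictions.
precision.update(([1, 1, 0], [1, 0, 0]))
precision.update(([1, 0], [1, 1]))
first = precision.compute()

# After reset(), the counters start from zero again.
precision.reset()
precision.update(([1, 1], [1, 1]))
second = precision.compute()

print(first, second)  # 0.6666666666666666 1.0
```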

Metric arithmetics

Metrics can be combined to form new metrics: through arithmetic, such as metric1 + metric2; through PyTorch operators, such as (metric1 + metric2).pow(2).mean(); or through a lambda function, such as MetricsLambda(lambda a, b: torch.mean(a + b), metric1, metric2).

For example:

from ignite.metrics import Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
F1 = (precision * recall * 2 / (precision + recall)).mean()

Note

This example computes the mean of F1 across classes. When combining precision and recall into F1 or other F-beta metrics, make sure to set average=False, i.e. use the unaveraged (per-class) precision and recall; otherwise the result is not an F-beta metric.
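A small worked example shows why the order matters. With hypothetical per-class precision and recall values for two classes, averaging per-class F1 (the average=False route) and computing F1 from already-averaged precision and recall give different numbers:

```python
# Hypothetical per-class values, as they would be returned with average=False.
precision = [1.0, 0.5]
recall = [0.5, 1.0]

# Combine per class first, then take the mean (what the F1 expression above does).
f1_per_class = [2 * p * r / (p + r) for p, r in zip(precision, recall)]
mean_f1 = sum(f1_per_class) / len(f1_per_class)

# Averaging precision and recall first, then combining, gives a different value.
p_avg = sum(precision) / len(precision)
r_avg = sum(recall) / len(recall)
f1_of_averages = 2 * p_avg * r_avg / (p_avg + r_avg)

print(round(mean_f1, 4), round(f1_of_averages, 4))  # 0.6667 0.75
```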

Metrics also support indexing (if the metric's result is a vector/matrix/tensor). This is useful, for example, to compute a mean metric (e.g. precision, recall or IoU) while ignoring the background class:

from ignite.metrics import ConfusionMatrix, IoU

cm = ConfusionMatrix(num_classes=10)
iou_metric = IoU(cm)
iou_no_bg_metric = iou_metric[:9]  # We assume that the background index is 9
mean_iou_no_bg_metric = iou_no_bg_metric.mean()
# mean_iou_no_bg_metric.compute() -> tensor(0.12345)
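For reference, per-class IoU can be read off a confusion matrix as TP / (TP + FP + FN). The plain-Python sketch below uses a hypothetical 3-class matrix (rows are true classes, columns are predictions) with the background at the last index, so slicing [:2] plays the role of [:9] above:

```python
# Hypothetical 3-class confusion matrix: rows = true class, columns = predicted class.
cm = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 1, 7],
]
background = 2  # assume the background is the last class index


def iou_per_class(cm):
    ious = []
    for k in range(len(cm)):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                 # true class k, predicted as another class
        fp = sum(row[k] for row in cm) - tp  # predicted as k, true class is another class
        ious.append(tp / (tp + fp + fn))
    return ious


ious = iou_per_class(cm)
no_bg = ious[:background]  # drop the background entry
mean_iou_no_bg = sum(no_bg) / len(no_bg)
print(ious[0], ious[1], mean_iou_no_bg)  # 0.625 0.375 0.5
```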

How to create a custom metric

To create a custom metric, create a new class that inherits from Metric and override three methods:

  • reset() : resets internal variables and accumulators

  • update() : updates internal variables and accumulators with the provided batch output (y_pred, y)

  • compute() : computes the custom metric and returns the result

For illustration, suppose we want to implement a multi-class accuracy metric with a specific condition (e.g. ignoring a user-defined class):

import torch

from ignite.metrics import Metric
from ignite.exceptions import NotComputableError

# These decorators help with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced


class CustomAccuracy(Metric):

    def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
        self.ignored_class = ignored_class
        self._num_correct = None
        self._num_examples = None
        super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        self._num_correct = torch.tensor(0, device=self._device)
        self._num_examples = 0
        super(CustomAccuracy, self).reset()

    @reinit__is_reduced
    def update(self, output):
        y_pred, y = output[0].detach(), output[1].detach()

        indices = torch.argmax(y_pred, dim=1)

        mask = (y != self.ignored_class)
        mask &= (indices != self.ignored_class)
        y = y[mask]
        indices = indices[mask]
        correct = torch.eq(indices, y).view(-1)

        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]

    @sync_all_reduce("_num_examples", "_num_correct:SUM")
    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
        return self._num_correct.item() / self._num_examples

We imported the necessary classes (Metric and NotComputableError) and the decorators that adapt the metric to the distributed setting. In the reset method, we reset the internal variables _num_correct and _num_examples, which are used to compute the custom metric. In the update method, we define how to update these internal variables. Finally, in the compute method, we compute the metric value.

Notice that _num_correct is a tensor, since update accumulates tensor values, while _num_examples is a Python scalar, since it accumulates plain integers. For differentiable metrics, you must detach the accumulated values before adding them to the internal variables.

We can check this implementation in a simple case:

import torch
torch.manual_seed(8)

m = CustomAccuracy(ignored_class=3)

batch_size = 4
num_classes = 5

y_pred = torch.rand(batch_size, num_classes)
y = torch.randint(0, num_classes, size=(batch_size, ))

m.update((y_pred, y))
res = m.compute()

print(y, torch.argmax(y_pred, dim=1))
# Out: tensor([2, 2, 2, 3]) tensor([2, 1, 0, 0])

print(m._num_correct, m._num_examples, res)
# Out: tensor(1) 3 0.3333333333333333
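The printed counts can be checked by hand. This stdlib-only trace copies the printed labels and predictions and applies the same masking rule as update(): a sample is kept only if neither its target nor its predicted class equals ignored_class.

```python
# Labels and argmax predictions copied from the printed output above.
y = [2, 2, 2, 3]
indices = [2, 1, 0, 0]
ignored_class = 3

# Same rule as the mask in update(): drop samples whose target or prediction is ignored.
kept = [(p, t) for p, t in zip(indices, y) if t != ignored_class and p != ignored_class]
num_correct = sum(int(p == t) for p, t in kept)
num_examples = len(kept)
accuracy = num_correct / num_examples
print(num_correct, num_examples, accuracy)  # 1 3 0.3333333333333333
```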

Metrics and their usages

By default, metrics are epoch-wise, which means that

  • reset() is triggered on every EPOCH_STARTED (see Events).

  • update() is triggered on every ITERATION_COMPLETED.

  • compute() is triggered on every EPOCH_COMPLETED.

Custom usages can be defined by creating a class that inherits from MetricUsage. See the list of usages below.
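To see how the usage changes what compute() returns, this plain-Python sketch replays the event order for an epoch-wise usage and for a batch-wise one. It assumes that a batch-wise usage resets on ITERATION_STARTED and computes on ITERATION_COMPLETED; ToyMean and replay are hypothetical helpers, not part of ignite.

```python
class ToyMean:
    """Hypothetical metric: arithmetic mean of the values passed to update()."""

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count


def replay(epochs, usage):
    """Replay an engine's event order and collect every value compute() returns."""
    metric, results = ToyMean(), []
    for epoch in epochs:
        if usage == "epoch_wise":
            metric.reset()                        # EPOCH_STARTED
        for value in epoch:
            if usage == "batch_wise":
                metric.reset()                    # ITERATION_STARTED
            metric.update(value)                  # ITERATION_COMPLETED
            if usage == "batch_wise":
                results.append(metric.compute())  # ITERATION_COMPLETED
        if usage == "epoch_wise":
            results.append(metric.compute())      # EPOCH_COMPLETED
    return results


epochs = [[1.0, 3.0], [5.0, 7.0]]    # two epochs of two iterations each
print(replay(epochs, "epoch_wise"))  # [2.0, 6.0]
print(replay(epochs, "batch_wise"))  # [1.0, 3.0, 5.0, 7.0]
```

The metric class is identical in both runs; only the events that trigger its three methods differ, which is exactly what a MetricUsage describes.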

Complete list of usages

Metrics and distributed computations

In the above example, the reset and update methods of CustomAccuracy are decorated with reinit__is_reduced(), and its compute method with sync_all_reduce(). These decorators adapt metrics to distributed computations on supported backends and devices (see ignite.distributed for more details). More precisely, adding @sync_all_reduce("_num_examples", "_num_correct:SUM") to the compute method means that, when compute is called, the metric's internal variables self._num_examples and self._num_correct are summed over all participating devices. For _num_correct, the reduction operation is given explicitly by the ":SUM" suffix; for _num_examples, no operation is given, so the default, SUM, is used. Four reduction operations are currently supported (SUM, MAX, MIN, PRODUCT). Once collected, these internal variables can be used to compute the final metric value.
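The reduction step can be simulated in a single process. In this sketch, rank_states holds hypothetical per-process counters for CustomAccuracy, and simulated_all_reduce is a hypothetical helper that mimics the "name" / "name:OP" convention; in a real run the communication goes through ignite.distributed collectives. It also shows why the counters, rather than per-process accuracies, are reduced:

```python
import math

# Hypothetical internal states of CustomAccuracy on three processes.
rank_states = [
    {"_num_examples": 4, "_num_correct": 3},
    {"_num_examples": 5, "_num_correct": 2},
    {"_num_examples": 3, "_num_correct": 3},
]

REDUCTIONS = {"SUM": sum, "MAX": max, "MIN": min, "PRODUCT": math.prod}


def simulated_all_reduce(states, *attrs):
    """Reduce each attribute across ranks: 'name' defaults to SUM, 'name:OP' selects OP."""
    reduced = {}
    for spec in attrs:
        name, _, op = spec.partition(":")
        reduced[name] = REDUCTIONS[op or "SUM"]([s[name] for s in states])
    return reduced


totals = simulated_all_reduce(rank_states, "_num_examples", "_num_correct:SUM")
accuracy = totals["_num_correct"] / totals["_num_examples"]

# Averaging per-process accuracies would weight processes equally, whatever their sample counts.
naive_mean = sum(s["_num_correct"] / s["_num_examples"] for s in rank_states) / len(rank_states)
print(totals, round(accuracy, 4), round(naive_mean, 4))
```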

Complete list of metrics

Average

Helper class to compute arithmetic average of a single variable.

GeometricAverage

Helper class to compute geometric average of a single variable.

VariableAccumulation

Single variable accumulator helper to compute (arithmetic, geometric, harmonic) average of a single variable.

Accuracy

Calculates the accuracy for binary, multiclass and multilabel data.

confusion_matrix.ConfusionMatrix

Calculates confusion matrix for multi-class data.

ClassificationReport

Build a text report showing the main classification metrics.

DiceCoefficient

Calculates Dice Coefficient for a given ConfusionMatrix metric.

JaccardIndex

Calculates the Jaccard Index using ConfusionMatrix metric.

IoU

Calculates Intersection over Union using ConfusionMatrix metric.

mIoU

Calculates mean Intersection over Union using ConfusionMatrix metric.

EpochMetric

Class for metrics that should be computed on the entire output history of a model.

Fbeta

Calculates F-beta score.

Frequency

Provides metrics for the number of examples processed per second.

Loss

Calculates the average loss according to the passed loss_fn.

MeanAbsoluteError

Calculates the mean absolute error.

MeanPairwiseDistance

Calculates the mean PairwiseDistance.

MeanSquaredError

Calculates the mean squared error.

metric.Metric

Base class for all Metrics.

metric_group.MetricGroup

A class for grouping metrics so that users can manage them more easily.

metrics_lambda.MetricsLambda

Apply a function to other metrics to obtain a new metric.

MultiLabelConfusionMatrix

Calculates a confusion matrix for multi-labelled, multi-class data.

MutualInformation

Calculates the mutual information between input X and prediction Y.

precision.Precision

Calculates precision for binary, multiclass and multilabel data.

PSNR

Computes average Peak signal-to-noise ratio (PSNR).

recall.Recall

Calculates recall for binary, multiclass and multilabel data.

RootMeanSquaredError

Calculates the root mean squared error.

RunningAverage

Computes the running average of a metric or the output of the process function.

SSIM

Computes Structural Similarity Index Measure (SSIM).

TopKCategoricalAccuracy

Calculates the top-k categorical accuracy.

Bleu

Calculates the BLEU score.

Rouge

Calculates the Rouge score for multiple Rouge-N and Rouge-L metrics.

RougeL

Calculates the Rouge-L score.

RougeN

Calculates the Rouge-N score.

InceptionScore

Calculates Inception Score.

FID

Calculates Frechet Inception Distance.

CosineSimilarity

Calculates the mean of the cosine similarity.

Entropy

Calculates the mean of entropy.

KLDivergence

Calculates the mean of Kullback-Leibler (KL) divergence.

JSDivergence

Calculates the mean of Jensen-Shannon (JS) divergence.

MaximumMeanDiscrepancy

Calculates the mean of maximum mean discrepancy (MMD).

AveragePrecision

Computes Average Precision by accumulating predictions and the ground-truth during an epoch and applying sklearn.metrics.average_precision_score.

CohenKappa

Computes different types of Cohen's Kappa: Non-Weighted, Linear, Quadratic.

GpuInfo

Provides GPU information: a) used memory percentage, b) GPU utilization percentage, as Metric values on each iteration.

PrecisionRecallCurve

Computes precision-recall pairs for different probability thresholds for a binary classification task by accumulating predictions and the ground-truth during an epoch and applying sklearn.metrics.precision_recall_curve.

RocCurve

Computes the Receiver Operating Characteristic (ROC) curve for a binary classification task by accumulating predictions and the ground-truth during an epoch and applying sklearn.metrics.roc_curve.

ROC_AUC

Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) by accumulating predictions and the ground-truth during an epoch and applying sklearn.metrics.roc_auc_score.

regression.CanberraMetric

Calculates the Canberra Metric.

regression.FractionalAbsoluteError

Calculates the Fractional Absolute Error.

regression.FractionalBias

Calculates the Fractional Bias.

regression.GeometricMeanAbsoluteError

Calculates the Geometric Mean Absolute Error.

regression.GeometricMeanRelativeAbsoluteError

Calculates the Geometric Mean Relative Absolute Error.

regression.ManhattanDistance

Calculates the Manhattan Distance.

regression.MaximumAbsoluteError

Calculates the Maximum Absolute Error.

regression.MeanAbsoluteRelativeError

Calculates the Mean Absolute Relative Error (MARE), also known as Mean Absolute Percentage Error (MAPE).

regression.MeanError

Calculates the Mean Error.

regression.MeanNormalizedBias

Calculates the Mean Normalized Bias.

regression.MedianAbsoluteError

Calculates the Median Absolute Error.

regression.MedianAbsolutePercentageError

Calculates the Median Absolute Percentage Error.

regression.MedianRelativeAbsoluteError

Calculates the Median Relative Absolute Error.

regression.PearsonCorrelation

Calculates the Pearson correlation coefficient.

regression.R2Score

Calculates the R-Squared, the coefficient of determination.

regression.WaveHedgesDistance

Calculates the Wave Hedges Distance.

Note

Module ignite.metrics.regression provides implementations of metrics useful for regression tasks. Definitions of metrics are based on Botchkarev 2018, page 30 “Appendix 2. Metrics mathematical definitions”.
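As a quick illustration of a few of the listed regression metrics, the sketch below computes them in plain Python for actual values A_j and predictions P_j. The formulas follow commonly used definitions (for example, the Canberra metric sums |A_j - P_j| / (|A_j| + |P_j|)); the sign convention A_j - P_j for the mean error is an assumption here, so check each metric's own documentation before relying on it.

```python
def regression_metrics(actual, predicted):
    """Compute a few regression metrics for paired lists of actual A_j and predicted P_j."""
    errors = [a - p for a, p in zip(actual, predicted)]  # assumed sign convention: A_j - P_j
    abs_errors = [abs(e) for e in errors]
    mean_a = sum(actual) / len(actual)
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((a - mean_a) ** 2 for a in actual)
    return {
        "manhattan_distance": sum(abs_errors),
        "maximum_absolute_error": max(abs_errors),
        "mean_error": sum(errors) / len(errors),
        # Undefined when A_j and P_j are both 0; the sample data below avoids that case.
        "canberra_metric": sum(abs(a - p) / (abs(a) + abs(p)) for a, p in zip(actual, predicted)),
        "r2_score": 1 - ss_res / ss_tot,
    }


actual = [1.0, 2.0, 3.0, 4.0]
predicted = [1.0, 3.0, 2.0, 4.0]
results = regression_metrics(actual, predicted)
print(results)
```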

Helpers for customizing metrics

MetricUsage

class ignite.metrics.metric.MetricUsage(started, completed, iteration_completed)

Base class for all usages of metrics.

A usage of metric defines the events when a metric starts to compute, updates and completes. Valid events are from Events.

Parameters

EpochWise

class ignite.metrics.metric.EpochWise

Epoch-wise usage of Metrics. It’s the default and most common usage of metrics.

Metric’s methods are triggered on the following engine events:

usage_name

usage name string

Type

str

RunningEpochWise

class ignite.metrics.metric.RunningEpochWise

Running epoch-wise usage of Metrics. It’s the running version of the EpochWise metric usage. A metric with such a usage most likely accompanies an EpochWise one to compute a running measure of it e.g. running average.

Metric’s methods are triggered on the following engine events:

usage_name

usage name string

Type

str

BatchWise

class ignite.metrics.metric.BatchWise

Batch-wise usage of Metrics.

Metric’s methods are triggered on the following engine events:

usage_name

usage name string

Type

str

RunningBatchWise

class ignite.metrics.metric.RunningBatchWise

Running batch-wise usage of Metrics. It's the running version of the BatchWise metric usage. A metric with such a usage could, for example, accompany a BatchWise one to compute a running measure of it, e.g. a running average.

Metric’s methods are triggered on the following engine events:

usage_name

usage name string

Type

str

SingleEpochRunningBatchWise

class ignite.metrics.metric.SingleEpochRunningBatchWise

Running batch-wise usage of Metrics in a single epoch. It's like the RunningBatchWise metric usage, with the difference that it is used during a single epoch.

Metric’s methods are triggered on the following engine events:

usage_name

usage name string

Type

str

BatchFiltered

class ignite.metrics.metric.BatchFiltered(*args, **kwargs)

Batch-filtered usage of Metrics. This usage is similar to epoch-wise, but the update event is filtered.

Metric’s methods are triggered on the following engine events:

Parameters

reinit__is_reduced

ignite.metrics.metric.reinit__is_reduced(func)

Helper decorator for distributed configuration.

See ignite.metrics on how to use it.

Parameters

func (Callable) – A callable to reinit.

Return type

Callable

sync_all_reduce

ignite.metrics.metric.sync_all_reduce(*attrs)

Helper decorator for distributed configuration to collect instance attribute value across all participating processes and apply the specified reduction operation.

See ignite.metrics on how to use it.

Parameters

attrs (Any) – attribute names of decorated class

Return type

Callable

Changed in version 0.4.5: Ability to handle different reduction operations (SUM, MAX, MIN, PRODUCT).