ignite.metrics
Metrics provide a way to compute various quantities of interest in an online fashion without having to store the entire output history of a model.
In practice, a user needs to attach the metric instance to an engine. The metric value is then computed using the output of the engine’s process_function:
```python
from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy()
metric.attach(engine, "accuracy")
# ...
state = engine.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")
```
If the engine’s output is not in the format (y_pred, y) or {'y_pred': y_pred, 'y': y, ...}, the user can use the output_transform argument to transform it:
```python
from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return {'y_pred': y_pred, 'y_true': y}  # plus any other outputs

engine = Engine(process_function)

def output_transform(output):
    # `output` variable is returned by above `process_function`
    y_pred = output['y_pred']
    y = output['y_true']
    return y_pred, y  # output format is according to `Accuracy` docs

metric = Accuracy(output_transform=output_transform)
metric.attach(engine, "accuracy")
# ...
state = engine.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")
```
Warning
Be careful when using lambda functions to set up multiple output_transform functions for multiple metrics:
```python
# Wrong
# metrics_group = [Accuracy(output_transform=lambda output: output[name]) for name in names]
# Each lambda looks up `name` only when it is called, after the loop has finished,
# so every `output_transform` would use the last `name`

# A correct way, for example using functools.partial
from functools import partial

from ignite.metrics import Accuracy

def ot_func(output, name):
    return output[name]

metrics_group = [Accuracy(output_transform=partial(ot_func, name=name)) for name in names]
```
For more details, see here
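The pitfall above is Python’s late binding of closures rather than anything specific to ignite. The framework-free sketch below uses made-up `names` and a plain dict standing in for an engine output to show the difference between the two approaches:

```python
from functools import partial

# Hypothetical output names and a dict standing in for an engine output.
names = ["out_a", "out_b", "out_c"]
output = {"out_a": 1, "out_b": 2, "out_c": 3}

# Wrong: each lambda reads `name` when it is called, after the comprehension
# has finished, so every transform sees the last value, "out_c".
late_bound = [lambda out: out[name] for name in names]
print([f(output) for f in late_bound])  # [3, 3, 3]

# Right: partial stores the current value of `name` inside each callable.
def ot_func(out, name):
    return out[name]

bound = [partial(ot_func, name=name) for name in names]
print([f(output) for f in bound])  # [1, 2, 3]
```

A default argument (`lambda out, name=name: out[name]`) binds the value in the same way; `partial` keeps the transform as a named, reusable function.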
Note
Most of the implemented metrics support distributed computations: they reduce their internal states across supported devices before computing the metric value. This can be helpful to run the evaluation on multiple nodes/GPU instances/TPUs with a distributed data sampler. The following code snippet shows how to use metrics in this setting:
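```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# `local_rank`, `model`, `test_dataset`, `batch_size` and `num_workers`
# are assumed to be defined by the surrounding distributed setup.
device = f"cuda:{local_rank}"
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], output_device=local_rank
)

test_sampler = DistributedSampler(test_dataset)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    sampler=test_sampler,
    num_workers=num_workers,
    pin_memory=True
)

evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy()}, device=device)
```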
Note
Metrics cannot be serialized using the pickle module because the implementation is based on lambda functions. Therefore, use the third-party library dill to overcome this limitation of pickle.
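The limitation comes from pickle itself, which serializes a function by reference to an importable name; a lambda has none. The stdlib-only sketch below does not involve ignite; `operator.itemgetter` appears only as an example of a callable that pickle can handle:

```python
import operator
import pickle

# A lambda has no importable qualified name, so pickle refuses to serialize it.
transform = lambda output: output["y_pred"]
try:
    pickle.dumps(transform)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError, TypeError):
    lambda_pickled = False
print(lambda_pickled)  # False

# operator.itemgetter performs the same lookup and survives a pickle round trip.
getter = pickle.loads(pickle.dumps(operator.itemgetter("y_pred")))
print(getter({"y_pred": 0.9, "y": 1}))  # 0.9
```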
Reset, Update, Compute API
Users can also call the following methods directly on the metric:
- reset(): resets internal variables and accumulators
- update(): updates internal variables and accumulators with the provided batch output (y_pred, y)
- compute(): computes the metric and returns the result
This API gives more fine-grained control over how a metric is computed. For example:
```python
from ignite.metrics import Precision

# Define the metric
precision = Precision()

# Start accumulation:
for x, y in data:
    y_pred = model(x)
    precision.update((y_pred, y))

# Compute the result
print("Precision: ", precision.compute())

# Reset metric
precision.reset()

# Start new accumulation:
for x, y in data:
    y_pred = model(x)
    precision.update((y_pred, y))

# Compute new result
print("Precision: ", precision.compute())
```
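To make the lifecycle concrete without depending on ignite, here is a minimal sketch of the same reset/update/compute pattern for binary precision. The class `BinaryPrecisionSketch` is hypothetical and handles only plain lists of 0/1 labels; it raises `ValueError` where an ignite metric would raise `NotComputableError`:

```python
class BinaryPrecisionSketch:
    """Hypothetical, framework-free metric following reset/update/compute."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulators before a new pass over the data.
        self._true_positives = 0
        self._predicted_positives = 0

    def update(self, output):
        # `output` is a (y_pred, y) pair of equal-length lists of 0/1 labels.
        y_pred, y = output
        for p, t in zip(y_pred, y):
            if p == 1:
                self._predicted_positives += 1
                self._true_positives += int(t == 1)

    def compute(self):
        if self._predicted_positives == 0:
            raise ValueError("No positive predictions seen yet.")
        return self._true_positives / self._predicted_positives


metric = BinaryPrecisionSketch()
for batch in [([1, 1, 0], [1, 0, 0]), ([1, 0], [1, 1])]:
    metric.update(batch)
print(metric.compute())  # 2 correct out of 3 positive predictions
```

Accumulating counts gives the exact precision over all batches (2/3 here), whereas averaging the per-batch precisions (0.5 and 1.0) would give 0.75.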
Metric arithmetics
Metrics can be combined to form new metrics. This can be done through arithmetic operations, such as metric1 + metric2; with PyTorch operators, such as (metric1 + metric2).pow(2).mean(); or with a lambda function, such as MetricsLambda(lambda a, b: torch.mean(a + b), metric1, metric2).
For example:
```python
from ignite.metrics import Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
F1 = (precision * recall * 2 / (precision + recall)).mean()
```
Note
This example computes the mean of F1 across classes. To combine precision and recall into F1 or other F metrics, we have to make sure that average=False, i.e. use the unaveraged (per-class) precision and recall; otherwise we will not be computing F-beta metrics.
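A small numeric check of why the order matters, using hypothetical per-class values for a two-class problem in plain Python:

```python
# Hypothetical per-class values, as unaveraged precision and recall
# would provide them for a 2-class problem.
precision = [0.5, 1.0]
recall = [1.0, 0.5]

def f1(p, r):
    return 2 * p * r / (p + r)

# Combine per class first, then average: the mean F1 across classes.
mean_f1 = sum(f1(p, r) for p, r in zip(precision, recall)) / len(precision)

# Averaging precision and recall first gives a different number,
# which is not the mean of the per-class F1 scores.
mean_p = sum(precision) / len(precision)
mean_r = sum(recall) / len(recall)
f1_of_means = f1(mean_p, mean_r)

print(round(mean_f1, 4), round(f1_of_means, 4))  # 0.6667 0.75
```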
Metrics also support indexing (if the metric’s result is a vector/matrix/tensor). For example, this can be useful to compute a mean metric (e.g. precision, recall or IoU) ignoring the background:
```python
from ignite.metrics import ConfusionMatrix, IoU

cm = ConfusionMatrix(num_classes=10)
iou_metric = IoU(cm)
iou_no_bg_metric = iou_metric[:9]  # We assume that the background index is 9
mean_iou_no_bg_metric = iou_no_bg_metric.mean()
# mean_iou_no_bg_metric.compute() -> tensor(0.12345)
```
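For reference, here is a plain-Python sketch of per-class IoU computed from a confusion matrix, followed by the same “drop the background, then average” step. The matrix values are made up, rows are assumed to index ground truth and columns predictions, and the background is placed at the last index as in the example above:

```python
# Hypothetical 3-class confusion matrix: rows are ground-truth classes,
# columns are predicted classes. Class 2 plays the role of the background.
cm = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 1, 7],
]

def per_class_iou(cm):
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        row = sum(cm[c])                       # samples whose true class is c
        col = sum(cm[r][c] for r in range(n))  # samples predicted as c
        ious.append(tp / (row + col - tp))     # intersection / union
    return ious

iou = per_class_iou(cm)
iou_no_bg = iou[:2]  # drop the last (background) class
mean_iou_no_bg = sum(iou_no_bg) / len(iou_no_bg)
print([round(v, 4) for v in iou], round(mean_iou_no_bg, 4))  # [0.625, 0.375, 0.7778] 0.5
```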
How to create a custom metric
To create a custom metric, one needs to create a new class inheriting from Metric and override three methods:
- reset(): resets internal variables and accumulators
- update(): updates internal variables and accumulators with the provided batch output (y_pred, y)
- compute(): computes the custom metric and returns the result
For example, suppose we would like to implement, for illustration purposes, a multi-class accuracy metric with a specific condition (e.g. ignoring a user-defined class):
```python
import torch

from ignite.metrics import Metric
from ignite.exceptions import NotComputableError

# These decorators help with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced


class CustomAccuracy(Metric):

    def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
        self.ignored_class = ignored_class
        self._num_correct = None
        self._num_examples = None
        super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        self._num_correct = torch.tensor(0, device=self._device)
        self._num_examples = 0
        super(CustomAccuracy, self).reset()

    @reinit__is_reduced
    def update(self, output):
        y_pred, y = output[0].detach(), output[1].detach()

        indices = torch.argmax(y_pred, dim=1)

        # Drop samples whose target or prediction is the ignored class
        mask = (y != self.ignored_class)
        mask &= (indices != self.ignored_class)
        y = y[mask]
        indices = indices[mask]
        correct = torch.eq(indices, y).view(-1)

        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]

    @sync_all_reduce("_num_examples", "_num_correct")
    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
        return self._num_correct.item() / self._num_examples
```
We imported the necessary classes, Metric and NotComputableError, and the decorators that adapt the metric to a distributed setting. In the reset method, we reset the internal variables _num_correct and _num_examples, which are used to compute the custom metric. In the update method, we define how to update the internal variables. Finally, in the compute method, we compute the metric value.
Notice that _num_correct is a tensor, since in update we accumulate tensor values. _num_examples is a Python scalar, since we accumulate normal integers. For differentiable metrics, you must detach the accumulated values before adding them to the internal variables.
We can check this implementation in a simple case:
```python
import torch

torch.manual_seed(8)

m = CustomAccuracy(ignored_class=3)

batch_size = 4
num_classes = 5

y_pred = torch.rand(batch_size, num_classes)
y = torch.randint(0, num_classes, size=(batch_size, ))

m.update((y_pred, y))
res = m.compute()

print(y, torch.argmax(y_pred, dim=1))
# Out: tensor([2, 2, 2, 3]) tensor([2, 1, 0, 0])

print(m._num_correct, m._num_examples, res)
# Out: tensor(1) 3 0.3333333333333333
```
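Metrics and their usages
By default, metrics are epoch-wise, which means that reset() is called at the start of each epoch, update() is called after each iteration, and compute() is called at the end of each epoch.
Users can define their own usages by creating a class inheriting from MetricUsage. See the list of usages below.
Complete list of usages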
Metrics and distributed computations
In the above example, CustomAccuracy has its reset and update methods decorated with reinit__is_reduced() and its compute method decorated with sync_all_reduce(). The purpose of these decorators is to adapt metrics to distributed computations on supported backends and devices (see ignite.distributed for more details). More precisely, in the above example we added @sync_all_reduce("_num_examples", "_num_correct") over the compute method. This means that when compute is called, the metric’s internal variables self._num_examples and self._num_correct are summed up over all participating devices. Therefore, once collected, these internal variables can be used to compute the final metric value.
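The sketch below imitates, in plain Python, what summing the internal variables across processes amounts to. The two-process split and the counts are hypothetical; the point is that summing `_num_correct` and `_num_examples` before dividing gives the accuracy over the whole dataset, while averaging per-process results does not:

```python
# Hypothetical per-process accumulators after update() has run on each
# process's shard of the data (a 2-process run is assumed).
per_rank = [
    {"_num_correct": 3, "_num_examples": 4},  # rank 0
    {"_num_correct": 1, "_num_examples": 1},  # rank 1
]

# Sum the internal variables over all processes, then compute once.
total_correct = sum(r["_num_correct"] for r in per_rank)
total_examples = sum(r["_num_examples"] for r in per_rank)
global_accuracy = total_correct / total_examples

# Averaging per-rank accuracies would weight ranks equally,
# regardless of how many examples each one saw.
mean_of_rank_accuracies = sum(
    r["_num_correct"] / r["_num_examples"] for r in per_rank
) / len(per_rank)

print(global_accuracy, mean_of_rank_accuracies)  # 0.8 0.875
```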
Complete list of metrics

| Metric | Description |
| --- | --- |
| Metric | Base class for all Metrics. |
| Accuracy | Calculates the accuracy for binary, multiclass and multilabel data. |
| Loss | Calculates the average loss according to the passed loss_fn. |
| MetricsLambda | Apply a function to other metrics to obtain a new metric. |
| MeanAbsoluteError | Calculates the mean absolute error. |
| MeanPairwiseDistance | Calculates the mean PairwiseDistance. |
| MeanSquaredError | Calculates the mean squared error. |
| ConfusionMatrix | Calculates confusion matrix for multi-class data. |
| TopKCategoricalAccuracy | Calculates the top-k categorical accuracy. |
| Average | Helper class to compute arithmetic average of a single variable. |
| DiceCoefficient | Calculates Dice Coefficient for a given ConfusionMatrix metric. |
| EpochMetric | Class for metrics that should be computed on the entire output history of a model. |
| Fbeta | Calculates F-beta score. |
| GeometricAverage | Helper class to compute geometric average of a single variable. |
| IoU | Calculates Intersection over Union using ConfusionMatrix metric. |
| mIoU | Calculates mean Intersection over Union using ConfusionMatrix metric. |
| Precision | Calculates precision for binary and multiclass data. |
| PSNR | Computes average Peak signal-to-noise ratio (PSNR). |
| Recall | Calculates recall for binary and multiclass data. |
| RootMeanSquaredError | Calculates the root mean squared error. |
| RunningAverage | Compute running average of a metric or the output of process function. |
| VariableAccumulation | Single variable accumulator helper to compute (arithmetic, geometric, harmonic) average of a single variable. |
| Frequency | Provides metrics for the number of examples processed per second. |
| SSIM | Computes Structural Similarity Index Measure. |
- class ignite.metrics.Accuracy(output_transform=<function Accuracy.<lambda>>, is_multilabel=False, device=device(type='cpu'))
Calculates the accuracy for binary, multiclass and multilabel data.
$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
where TP is true positives, TN is true negatives, FP is false positives and FN is false negatives.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).
y must be in the following shape (batch_size, …).
For multilabel cases, y and y_pred must both have the shape (batch_size, num_categories, …), and num_categories must be greater than 1.
In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:
```python
import torch
from ignite.metrics import Accuracy

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

binary_accuracy = Accuracy(thresholded_output_transform)
```
- Parameters
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
is_multilabel (bool, optional) – flag to use in multilabel case. By default, False.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
- class ignite.metrics.Average(output_transform=<function Average.<lambda>>, device=device(type='cpu'))
Helper class to compute arithmetic average of a single variable.
update must receive output of the form x. x can be a number or a torch.Tensor.
Note
Number of samples is updated following the rule:
+1 if input is a number
+1 if input is a 1D torch.Tensor
+batch_size if input is an ND torch.Tensor. Batch size is the first dimension (shape[0]).
For input x being an ND torch.Tensor with N > 1, the first dimension is seen as the number of samples and is summed up and added to the accumulator: accumulator += x.sum(dim=0)
Examples:
```python
from ignite.metrics import Average

evaluator = ...

custom_var_mean = Average(output_transform=lambda output: output['custom_var'])
custom_var_mean.attach(evaluator, 'mean_custom_var')

state = evaluator.run(dataset)
# state.metrics['mean_custom_var'] -> average of output['custom_var']
```
- Parameters
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
- class ignite.metrics.ConfusionMatrix(num_classes, average=None, output_transform=<function ConfusionMatrix.<lambda>>, device=device(type='cpu'))
Calculates confusion matrix for multi-class data.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
y_pred must contain logits and have the following shape (batch_size, num_classes, …). If you are doing binary classification, see Note for an example on how to get this.
y should have the following shape (batch_size, …) and contain ground-truth class indices with or without the background class. During the computation, the argmax of y_pred is taken to determine predicted classes.
- Parameters
num_classes (int) – Number of classes, should be > 1. See notes for more details.
average (str, optional) – confusion matrix values averaging schema: None, "samples", "recall", "precision". Default is None. If average="samples", then confusion matrix values are normalized by the number of seen samples. If average="recall", then confusion matrix values are normalized such that diagonal values represent class recalls. If average="precision", then confusion matrix values are normalized such that diagonal values represent class precisions.
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
Note
When the targets y are in (batch_size, …) format, only target indices between 0 and num_classes - 1 contribute to the confusion matrix; others are ignored. For example, if num_classes=20 and a target index equal to 255 is encountered, it is filtered out.
If you are doing binary classification with a single output unit, you may have to transform your network output so that you have one value for each class. For example, you can transform your network output into a one-hot vector with:
```python
import torch
import ignite.utils
from ignite.engine import create_supervised_evaluator
from ignite.metrics import ConfusionMatrix

def binary_one_hot_output_transform(output):
    y_pred, y = output
    y_pred = torch.sigmoid(y_pred).round().long()
    y_pred = ignite.utils.to_onehot(y_pred, 2)
    y = y.long()
    return y_pred, y

metrics = {
    "confusion_matrix": ConfusionMatrix(2, output_transform=binary_one_hot_output_transform),
}

evaluator = create_supervised_evaluator(
    model, metrics=metrics, output_transform=lambda x, y, y_pred: (y_pred, y)
)
```
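A plain-Python sketch of the three normalizations described for the average parameter, assuming rows index the ground truth and columns the predictions (the matrix values are hypothetical):

```python
# Hypothetical 2-class confusion matrix: rows are ground truth, columns predictions.
cm = [
    [8, 2],  # true class 0: 8 predicted as 0, 2 predicted as 1
    [1, 9],  # true class 1: 1 predicted as 0, 9 predicted as 1
]
n = len(cm)

# "samples": divide every entry by the total number of samples.
total = sum(sum(row) for row in cm)
samples = [[v / total for v in row] for row in cm]

# "recall": divide each row by its sum, so the diagonal holds per-class recall.
recall_norm = [[v / sum(row) for v in row] for row in cm]

# "precision": divide each column by its sum, so the diagonal holds per-class precision.
col_sums = [sum(cm[r][c] for r in range(n)) for c in range(n)]
precision_norm = [[cm[r][c] / col_sums[c] for c in range(n)] for r in range(n)]

print([recall_norm[i][i] for i in range(n)])     # [0.8, 0.9]
print([precision_norm[i][i] for i in range(n)])  # 8/9 and 9/11
```

With the "samples" normalization, the diagonal sums to the overall accuracy (17/20 here).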
- ignite.metrics.DiceCoefficient(cm, ignore_index=None)
Calculates Dice Coefficient for a given ConfusionMatrix metric.
- Parameters
cm (ConfusionMatrix) – instance of confusion matrix metric
ignore_index (int, optional) – index to ignore, e.g. background index
- Return type
MetricsLambda
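A common definition of the per-class Dice coefficient, written in terms of the confusion matrix, is Dice = 2·TP / (2·TP + FP + FN). The plain-Python sketch below assumes that definition, ignores any numerical-stability epsilon the library may add, and uses rows for ground truth and columns for predictions (hypothetical values). It also checks the standard relation between Dice and IoU:

```python
# Hypothetical 2-class confusion matrix (rows: ground truth, columns: predictions).
cm = [
    [6, 2],
    [2, 10],
]

def dice(cm, c):
    n = len(cm)
    tp = cm[c][c]
    row = sum(cm[c])                       # TP + FN
    col = sum(cm[r][c] for r in range(n))  # TP + FP
    return 2 * tp / (row + col)

def iou(cm, c):
    n = len(cm)
    tp = cm[c][c]
    row = sum(cm[c])
    col = sum(cm[r][c] for r in range(n))
    return tp / (row + col - tp)

for c in range(len(cm)):
    # Dice and IoU are monotonically related: dice = 2 * iou / (1 + iou).
    print(c, dice(cm, c), iou(cm, c))
```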
- class ignite.metrics.EpochMetric(compute_fn, output_transform=<function EpochMetric.<lambda>>, check_compute_fn=True, device=device(type='cpu'))
Class for metrics that should be computed on the entire output history of a model. The model’s output and targets are restricted to be of shape (batch_size, n_classes). Output datatype should be float32. Target datatype should be long.
Warning
The current implementation stores all input data (output and target) as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger than available RAM.
In a distributed configuration, all stored data (output and target) is gathered across all processes using an all-gather collective operation. This can potentially lead to a memory error. The compute method executes compute_fn on the rank-zero process only, and the final result is broadcast to all processes.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
If the target shape is (batch_size, n_classes) and n_classes > 1, then it should be binary: e.g. [[0, 1, 0, 1], ].
- Parameters
compute_fn (callable) – a callable with the signature (torch.tensor, torch.tensor) that takes predictions and targets as input and returns a scalar. Input tensors will be on the specified device (see arg below).
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
check_compute_fn (bool) – if True, compute_fn is run on the first batch of data to ensure there are no issues. If issues exist, the user is warned that there might be an issue with the compute_fn. Default, True.
device (str or torch.device, optional) – optional device specification for internal storage.
Warning
EpochMetricWarning: the user is warned that there are issues with compute_fn on a processed batch of data. To disable the warning, set check_compute_fn=False.
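A median is a typical example of a quantity that cannot be assembled from per-batch results, which is the case EpochMetric targets. The hypothetical `EpochStoreSketch` below mimics the idea in plain Python by storing every prediction and target and calling `compute_fn` once; it does not reproduce EpochMetric’s shape checks or distributed gathering:

```python
import statistics


class EpochStoreSketch:
    """Hypothetical stand-in for the EpochMetric idea: keep every prediction
    and target seen during the epoch and apply compute_fn once at the end."""

    def __init__(self, compute_fn):
        self.compute_fn = compute_fn
        self.reset()

    def reset(self):
        self._predictions = []
        self._targets = []

    def update(self, output):
        y_pred, y = output
        self._predictions.extend(y_pred)
        self._targets.extend(y)

    def compute(self):
        return self.compute_fn(self._predictions, self._targets)


def median_absolute_error(y_pred, y):
    # A median cannot be built from per-batch medians, so the full history is needed.
    return statistics.median(abs(p - t) for p, t in zip(y_pred, y))


m = EpochStoreSketch(median_absolute_error)
m.update(([1.0, 2.0], [1.5, 2.0]))
m.update(([4.0], [1.0]))
print(m.compute())  # errors are 0.5, 0.0 and 3.0, so the median is 0.5
```

Averaging the per-batch medians here (0.25 and 3.0) would give 1.625 instead.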
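- ignite.metrics.Fbeta(beta, average=True, precision=None, recall=None, output_transform=None, device=device(type='cpu'))
Calculates F-beta score.
$$F_\beta = \left(1 + \beta^2\right) \cdot \frac{P \cdot R}{\beta^2 P + R}$$
where $\beta$ is a positive real factor, $P$ is precision and $R$ is recall.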
- Parameters
beta (float) – weight in the harmonic mean; values of beta greater than 1 give recall more weight than precision
average (bool, optional) – if True, F-beta score is computed as the unweighted average (across all classes in multiclass case), otherwise, returns a tensor with F-beta score for each class in multiclass case.
precision (Precision, optional) – precision object metric with average=False to compute F-beta score
recall (Recall, optional) – recall object metric with average=False to compute F-beta score
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. It is used only if precision or recall are not provided.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
- Returns
MetricsLambda, F-beta metric
- Return type
MetricsLambda
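The F-beta score is F_β = (1 + β²)·P·R / (β²·P + R) for precision P and recall R. The plain-Python sketch below evaluates it for a single hypothetical class to show how beta shifts the balance; returning 0.0 for the 0/0 case is a choice made here, not necessarily the library’s behavior:

```python
def fbeta(p, r, beta):
    # F-beta for one class; returns 0.0 when precision and recall are both 0.
    denom = beta ** 2 * p + r
    return 0.0 if denom == 0 else (1 + beta ** 2) * p * r / denom


p, r = 0.5, 1.0  # hypothetical precision and recall for one class
for beta in (0.5, 1, 2):
    print(beta, round(fbeta(p, r, beta), 4))
# Prints 0.5556, 0.6667 and 0.8333: a larger beta moves the score toward recall.
```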
- class ignite.metrics.GeometricAverage(output_transform=<function GeometricAverage.<lambda>>, device=device(type='cpu'))
Helper class to compute geometric average of a single variable.
update must receive output of the form x. x can be a positive number or a positive torch.Tensor, such that torch.log(x) is not nan.
Note
Number of samples is updated following the rule:
+1 if input is a number
+1 if input is a 1D torch.Tensor
+batch_size if input is an ND torch.Tensor. Batch size is the first dimension (shape[0]).
For input x being an ND torch.Tensor with N > 1, the first dimension is seen as the number of samples and is aggregated and added to the accumulator: accumulator *= prod(x, dim=0)
- Parameters
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
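The sketch below computes a geometric mean in plain Python by accumulating log(x) and exponentiating the mean at the end, which is equivalent to taking the N-th root of the product but is less prone to overflow or underflow. This log-domain formulation is swapped in for illustration; `GeometricMeanSketch` is hypothetical and not the library’s implementation:

```python
import math


class GeometricMeanSketch:
    """Hypothetical accumulator: sums log(x) instead of multiplying x directly."""

    def __init__(self):
        self._log_sum = 0.0
        self._num_samples = 0

    def update(self, values):
        for x in values:
            if x <= 0:
                raise ValueError("geometric mean needs positive values")
            self._log_sum += math.log(x)
            self._num_samples += 1

    def compute(self):
        # exp(mean(log x)) equals the N-th root of the product of all x.
        return math.exp(self._log_sum / self._num_samples)


g = GeometricMeanSketch()
g.update([1.0, 2.0])
g.update([4.0])
print(g.compute())  # close to 2.0, the cube root of 1 * 2 * 4
```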
- ignite.metrics.IoU(cm, ignore_index=None)
Calculates Intersection over Union using ConfusionMatrix metric.
- Parameters
cm (ConfusionMatrix) – instance of confusion matrix metric
ignore_index (int, optional) – index to ignore, e.g. background index
- Returns
MetricsLambda
- Return type
MetricsLambda
Examples:
```python
from ignite.metrics import ConfusionMatrix, IoU

train_evaluator = ...

cm = ConfusionMatrix(num_classes=num_classes)
IoU(cm, ignore_index=0).attach(train_evaluator, 'IoU')

state = train_evaluator.run(train_dataset)
# state.metrics['IoU'] -> tensor of shape (num_classes - 1, )
```
- ignite.metrics.mIoU(cm, ignore_index=None)
Calculates mean Intersection over Union using ConfusionMatrix metric.
- Parameters
cm (ConfusionMatrix) – instance of confusion matrix metric
ignore_index (int, optional) – index to ignore, e.g. background index
- Returns
MetricsLambda
- Return type
MetricsLambda
Examples:
```python
from ignite.metrics import ConfusionMatrix, mIoU

train_evaluator = ...

cm = ConfusionMatrix(num_classes=num_classes)
mIoU(cm, ignore_index=0).attach(train_evaluator, 'mean IoU')

state = train_evaluator.run(train_dataset)
# state.metrics['mean IoU'] -> scalar
```
- class ignite.metrics.Loss(loss_fn, output_transform=<function Loss.<lambda>>, batch_size=<function Loss.<lambda>>, device=device(type='cpu'))
Calculates the average loss according to the passed loss_fn.
- Parameters
loss_fn (callable) – a callable taking a prediction tensor, a target tensor, and optionally other arguments, and returning the average loss over all observations in the batch.
output_transform (callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. The output is expected to be a tuple (prediction, target) or (prediction, target, kwargs), where kwargs is a dictionary of extra keyword arguments. If extra keyword arguments are provided, they are passed to loss_fn.
batch_size (callable) – a callable taking a target tensor that returns the first dimension size (usually the batch size).
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
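Because loss_fn returns a mean over the batch, a correct average over the epoch has to weight each batch by its size, which is presumably why Loss takes a batch_size callable. A plain-Python sketch with hypothetical numbers, where the last batch is smaller:

```python
# Hypothetical per-batch outputs of loss_fn (mean over the batch) and batch sizes.
batch_losses = [1.0, 4.0]
batch_sizes = [4, 2]  # the last batch is smaller

# Weight each batch mean by its size, then divide by the total number of samples.
total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
num_samples = sum(batch_sizes)
average_loss = total / num_samples

naive = sum(batch_losses) / len(batch_losses)  # ignores batch sizes
print(average_loss, naive)  # 2.0 2.5
```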
- class ignite.metrics.MeanAbsoluteError(output_transform=<function Metric.<lambda>>, device=device(type='cpu'))
Calculates the mean absolute error.
$$\text{MAE} = \frac{1}{N} \sum_{i=1}^{N} \left| y_i - x_i \right|$$
where $y_i$ is the prediction tensor and $x_i$ is the ground truth tensor.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
- class ignite.metrics.MeanPairwiseDistance(p=2, eps=1e-06, output_transform=<function MeanPairwiseDistance.<lambda>>, device=device(type='cpu'))
Calculates the mean PairwiseDistance. Average of pairwise distances computed on provided batches.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
- class ignite.metrics.MeanSquaredError(output_transform=<function Metric.<lambda>>, device=device(type='cpu'))
Calculates the mean squared error.
$$\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - x_i \right)^2$$
where $y_i$ is the prediction tensor and $x_i$ is the ground truth tensor.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
- class ignite.metrics.Metric(output_transform=<function Metric.<lambda>>, device=device(type='cpu'))
Base class for all Metrics.
- Parameters
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as (y_pred, y) or {'y_pred': y_pred, 'y': y}.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
- required_output_keys
a tuple of keys that must be found in engine.state.output if the latter is a dictionary. Default, ("y_pred", "y"). This is useful with custom metrics that require arguments other than predictions y_pred and targets y. See notes below for an example.
- Type
tuple
Note
Let’s implement a custom metric that requires y_pred, y and x as input for the update function. In the example below we show how to set up a standard metric like Accuracy and the custom metric using an evaluator created with the create_supervised_evaluator() method.
```python
# https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5

import torch
import torch.nn as nn

from ignite.metrics import Metric, Accuracy
from ignite.engine import create_supervised_evaluator

class CustomMetric(Metric):

    required_output_keys = ("y_pred", "y", "x")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def update(self, output):
        y_pred, y, x = output
        # ...

    def reset(self):
        # ...
        pass

    def compute(self):
        # ...
        pass

model = ...

metrics = {
    "Accuracy": Accuracy(),
    "CustomMetric": CustomMetric()
}

evaluator = create_supervised_evaluator(
    model, metrics=metrics, output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
)

res = evaluator.run(data)
```
Changed in version 0.4.2: required_output_keys became a public attribute.
- attach(engine, name, usage=<ignite.metrics.metric.EpochWise object>)
Attaches the current metric to the provided engine. At the end of the engine’s run, the engine.state.metrics dictionary will contain the computed metric’s value under the provided name.
- Parameters
engine (Engine) – the engine to which the metric must be attached
name (str) – the name of the metric to attach
usage (str or MetricUsage, optional) – the usage of the metric. Valid string values should be ignite.metrics.EpochWise.usage_name (default) or ignite.metrics.BatchWise.usage_name.
- Return type
None
Example:
```python
metric = ...
metric.attach(engine, "mymetric")

assert "mymetric" in engine.run(data).metrics

assert metric.is_attached(engine)
```
Example with usage:
```python
metric = ...
metric.attach(engine, "mymetric", usage=BatchWise.usage_name)

assert "mymetric" in engine.run(data).metrics

assert metric.is_attached(engine, usage=BatchWise.usage_name)
```
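- completed(engine, name)
Helper method to compute the metric’s value and put it into the engine. It is automatically attached to the engine with attach().
- Parameters
engine (Engine) – the engine to which the metric is attached
name (str) – the name under which the metric’s value is stored in engine.state.metrics
- Return type
None
Changed in version 0.4.3: Added dict in metrics results.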
- abstract compute()
Computes the metric based on its accumulated state.
By default, this is called at the end of each epoch.
- Returns
the actual quantity of interest. However, if a Mapping is returned, it will be (shallow) flattened into engine.state.metrics when completed() is called.
- Return type
Any
- Raises
NotComputableError – raised when the metric cannot be computed.
- detach(engine, usage=<ignite.metrics.metric.EpochWise object>)
Detaches the current metric from the engine, so that no metric computation is done during the run. This method, in conjunction with attach(), can be useful if several metrics need to be computed with different periods. For example, one metric is computed every training epoch and another metric (e.g. a more expensive one) is computed every n-th training epoch.
- Parameters
engine (Engine) – the engine from which the metric must be detached
usage (str or MetricUsage, optional) – the usage of the metric. Valid string values should be ‘epoch_wise’ (default) or ‘batch_wise’.
- Return type
None
Example:
```python
metric = ...
engine = ...
metric.detach(engine)

assert "mymetric" not in engine.run(data).metrics

assert not metric.is_attached(engine)
```
Example with usage:
```python
metric = ...
engine = ...
metric.detach(engine, usage="batch_wise")

assert "mymetric" not in engine.run(data).metrics

assert not metric.is_attached(engine, usage="batch_wise")
```
- is_attached(engine, usage=<ignite.metrics.metric.EpochWise object>)
Checks if the current metric is attached to the provided engine. If attached, the metric’s computed value is written to the engine.state.metrics dictionary.
- Parameters
engine (Engine) – the engine to check
usage (str or MetricUsage, optional) – the usage of the metric. Valid string values should be ‘epoch_wise’ (default) or ‘batch_wise’.
- Return type
bool
- iteration_completed(engine)
Helper method to update the metric’s computation. It is automatically attached to the engine with attach().
- Parameters
engine (Engine) – the engine to which the metric must be attached
- Return type
None
- abstract reset()
Resets the metric to its initial state.
By default, this is called at the start of each epoch.
- Return type
None
- class ignite.metrics.MetricsLambda(f, *args, **kwargs)
Apply a function to other metrics to obtain a new metric. The result of the new metric is defined to be the result of applying the function to the result of argument metrics.
When updated, this metric does not recursively update the metrics it depends on. When reset, all its dependency metrics are reset. When attached, all its dependency metrics are attached automatically (but partially, e.g. is_attached() will return False).
- Parameters
f (callable) – the function that defines the computation
args (sequence) – Sequence of other metrics or something else that will be fed to f as arguments.
kwargs (Any) –
Example:
```python
import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)

def Fbeta(r, p, beta):
    return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

F1 = MetricsLambda(Fbeta, recall, precision, 1)
F2 = MetricsLambda(Fbeta, recall, precision, 2)
F3 = MetricsLambda(Fbeta, recall, precision, 3)
F4 = MetricsLambda(Fbeta, recall, precision, 4)
```
When checking whether the metric is attached, if one of its dependency metrics is detached, the metric is considered detached too.
```python
engine = ...
precision = Precision(average=False)

aP = precision.mean()

aP.attach(engine, "aP")

assert aP.is_attached(engine)
# partially attached
assert not precision.is_attached(engine)

precision.detach(engine)

assert not aP.is_attached(engine)
# fully detached
assert not precision.is_attached(engine)
```
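- class ignite.metrics.PSNR(data_range, output_transform=<function PSNR.<lambda>>, device=device(type='cpu'))
Computes average Peak signal-to-noise ratio (PSNR).
$$\text{PSNR}(I, J) = 10 \cdot \log_{10}\left(\frac{MAX_{I}^2}{\text{MSE}}\right)$$
where MSE is the mean squared error and $MAX_{I}$ is the data range.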
y_pred and y must have (batch_size, …) shape.
y_pred and y must have same dtype and same shape.
- Parameters
data_range (int or float) – The data range of the target image (distance between minimum and maximum possible values). For other data types, please set the data range, otherwise an exception will be raised.
output_transform (callable, optional) – A callable that is used to transform the Engine’s process_function’s output into the form expected by the metric.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
Example:
To use with Engine and process_function, simply attach the metric instance to the engine. The output of the engine’s process_function needs to be in the format of (y_pred, y) or {'y_pred': y_pred, 'y': y, ...}.
```python
from ignite.engine import Engine
from ignite.metrics import PSNR

def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
psnr = PSNR(data_range=1.0)
psnr.attach(engine, "psnr")
# ...
state = engine.run(data)
print(f"PSNR: {state.metrics['psnr']}")
```
This metric by default accepts grayscale or RGB images. But if you have YCbCr or YUV images, only the Y channel is needed for computing PSNR, and this can be done with output_transform. For instance,
```python
def get_y_channel(output):
    y_pred, y = output
    # y_pred and y are (B, 3, H, W) and YCbCr or YUV images
    # let's select y channel
    return y_pred[:, 0, ...], y[:, 0, ...]

psnr = PSNR(data_range=219, output_transform=get_y_channel)
psnr.attach(engine, "psnr")
# ...
state = engine.run(data)
print(f"PSNR: {state.metrics['psnr']}")
```
New in version 0.4.3.
- class ignite.metrics.Precision(output_transform=<function Precision.<lambda>>, average=False, is_multilabel=False, device=device(type='cpu'))
Calculates precision for binary and multiclass data.
$$\text{Precision} = \frac{TP}{TP + FP}$$
where TP is the number of true positives and FP is the number of false positives.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).
y must be in the following shape (batch_size, …).
In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:
import torch
from ignite.metrics import Precision

def thresholded_output_transform(output):
    y_pred, y = output
    # round predicted probabilities to 0 or 1
    y_pred = torch.round(y_pred)
    return y_pred, y

precision = Precision(output_transform=thresholded_output_transform)
In multilabel cases, the average parameter should be True. However, if the user would like to compute, for example, the F1 metric, the average parameter should be False. This can be done as shown below:
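import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False, is_multilabel=True)
recall = Recall(average=False, is_multilabel=True)
# element-wise F1; 1e-20 avoids division by zero
F1 = precision * recall * 2 / (precision + recall + 1e-20)
# reduce the F1 tensor to a single Python float
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)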
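The F1 expression above combines the unaveraged precision and recall tensors element-wise and then takes their mean. The arithmetic can be checked in plain Python, assuming those tensors are already available as lists of matching length. This is a simplification, and `f1_from_precision_recall` is a hypothetical helper, not part of Ignite.

```python
def f1_from_precision_recall(precision, recall, eps=1e-20):
    """Element-wise F1 = 2 * p * r / (p + r + eps), then the unweighted mean."""
    f1 = [2 * p * r / (p + r + eps) for p, r in zip(precision, recall)]
    return sum(f1) / len(f1)

# Per-entry F1 values are 0.667, 0.8 and 1.0.
print(f1_from_precision_recall([0.5, 1.0, 1.0], [1.0, 2 / 3, 1.0]))  # about 0.822
```

The small `eps` keeps an entry with zero precision and zero recall at F1 = 0 instead of raising a division error.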
Warning
In multilabel cases, if average is False, the current implementation stores all input data (output and target) as tensors before computing the metric. This can lead to a memory error if the input data is larger than the available RAM.
Warning
In multilabel cases, if average is False, the current implementation does not work with distributed computations. Results are not reduced across GPUs, and the computed result corresponds to the local rank’s (single-GPU) result.
- Parameters
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
average (bool, optional) – if True, precision is computed as the unweighted average (across all classes in the multiclass case); otherwise, a tensor with the precision (for each class in the multiclass case) is returned.
is_multilabel (bool, optional) – should be True for multilabel data; in that case, the average is computed across samples instead of classes. By default, False.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
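For the multiclass case, the sketch below shows the difference between average=False and average=True. It assumes y_pred has already been reduced to predicted class indices (for example, by an argmax over num_categories). Plain lists stand in for tensors, and `precision_scores` is a hypothetical helper, not Ignite's implementation. Returning 0.0 for a class that is never predicted is also an assumption of this sketch.

```python
def precision_scores(pred_labels, true_labels, num_classes):
    """Per-class precision TP / (TP + FP) from predicted and true class indices."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(pred_labels, true_labels) if p == c and t == c)
        fp = sum(1 for p, t in zip(pred_labels, true_labels) if p == c and t != c)
        # Assumption: a class that is never predicted gets precision 0.0.
        scores.append(tp / (tp + fp) if tp + fp > 0 else 0.0)
    return scores

per_class = precision_scores([0, 0, 1, 1, 2], [0, 1, 1, 1, 2], num_classes=3)
macro = sum(per_class) / len(per_class)  # unweighted mean, as with average=True
print(per_class, macro)  # [0.5, 1.0, 1.0] and about 0.833
```
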
- class ignite.metrics.Recall(output_transform=<function Recall.<lambda>>, average=False, is_multilabel=False, device=device(type='cpu'))
Calculates recall for binary and multiclass data.
$$\text{Recall} = \frac{TP}{TP + FN}$$
where TP is the number of true positives and FN is the number of false negatives.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).
y must be in the following shape (batch_size, …).
In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:
import torch
from ignite.metrics import Recall

def thresholded_output_transform(output):
    y_pred, y = output
    # round predicted probabilities to 0 or 1
    y_pred = torch.round(y_pred)
    return y_pred, y

recall = Recall(output_transform=thresholded_output_transform)
In multilabel cases, the average parameter should be True. However, if the user would like to compute, for example, the F1 metric, the average parameter should be False. This can be done as shown below:
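import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False, is_multilabel=True)
recall = Recall(average=False, is_multilabel=True)
# element-wise F1; 1e-20 avoids division by zero
F1 = precision * recall * 2 / (precision + recall + 1e-20)
# reduce the F1 tensor to a single Python float
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)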
Warning
In multilabel cases, if average is False, the current implementation stores all input data (output and target) as tensors before computing the metric. This can lead to a memory error if the input data is larger than the available RAM.
Warning
In multilabel cases, if average is False, the current implementation does not work with distributed computations. Results are not reduced across GPUs, and the computed result corresponds to the local rank’s (single-GPU) result.
- Parameters
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
average (bool, optional) – if True, recall is computed as the unweighted average (across all classes in the multiclass case); otherwise, a tensor with the recall (for each class in the multiclass case) is returned.
is_multilabel (bool, optional) – should be True for multilabel data; in that case, the average is computed across samples instead of classes. By default, False.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
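Recall differs from precision only in the denominator: it counts missed true samples (false negatives) rather than wrong predictions. The sketch below uses the same assumptions as a multiclass precision check. y_pred is reduced to class indices, lists stand in for tensors, and `recall_scores` is a hypothetical helper. Returning 0.0 for a class that never appears in the targets is likewise an assumption.

```python
def recall_scores(pred_labels, true_labels, num_classes):
    """Per-class recall TP / (TP + FN) from predicted and true class indices."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(pred_labels, true_labels) if p == c and t == c)
        fn = sum(1 for p, t in zip(pred_labels, true_labels) if p != c and t == c)
        # Assumption: a class absent from the targets gets recall 0.0.
        scores.append(tp / (tp + fn) if tp + fn > 0 else 0.0)
    return scores

per_class = recall_scores([0, 0, 1, 1, 2], [0, 1, 1, 1, 2], num_classes=3)
macro = sum(per_class) / len(per_class)  # unweighted mean, as with average=True
print(per_class, macro)
```
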
- class ignite.metrics.RootMeanSquaredError(output_transform=<function Metric.<lambda>>, device=device(type='cpu'))
Calculates the root mean squared error.
$$\text{RMSE}(y, x) = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(y_{i} - x_{i}\right)^{2}}$$
where y is the prediction tensor and x is the ground truth tensor.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
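Because the metric is computed online, only a running sum of squared errors and an element count need to be kept. The result over several batches is therefore the RMSE over all elements, not the mean of per-batch RMSE values. The sketch below illustrates that idea in pure Python, treating each list element as one example. `StreamingRMSE` is a hypothetical class, not Ignite's implementation.

```python
import math

class StreamingRMSE:
    """Accumulates squared errors batch by batch; only two numbers are stored."""

    def __init__(self):
        self.sum_of_squared_errors = 0.0
        self.num_examples = 0

    def update(self, y_pred, y):
        self.sum_of_squared_errors += sum((p - t) ** 2 for p, t in zip(y_pred, y))
        self.num_examples += len(y)

    def compute(self):
        return math.sqrt(self.sum_of_squared_errors / self.num_examples)

rmse = StreamingRMSE()
rmse.update([1.0, 2.0], [1.0, 2.0])  # first batch: no error
rmse.update([3.0, 4.0], [3.0, 8.0])  # second batch: one error of 4
print(rmse.compute())  # sqrt(16 / 4) = 2.0
```

Averaging the two per-batch values (0.0 and about 2.83) would give about 1.41 instead, which is why the sums are accumulated first.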
- class ignite.metrics.RunningAverage(src=None, alpha=0.98, output_transform=None, epoch_bound=True, device=None)
Computes the running average of a metric or of the output of the process function.
- Parameters
src (Metric or None) – input source: an instance of Metric or None. The latter corresponds to engine.state.output, which holds the output of the process function.
alpha (float, optional) – running average decay factor, default 0.98.
output_transform (callable, optional) – a function used to transform the output if src is None, i.e. when the running average is computed over the output of the process function. Otherwise, it should be None.
epoch_bound (boolean, optional) – whether the running average should be reset after each epoch (defaults to True).
device (str or torch.device, optional) – specifies which device updates are accumulated on. Should be None when src is an instance of Metric, as the running average will use the src’s device. Otherwise, defaults to CPU. Only applicable when the computed value from the metric is a tensor.
Examples:
from ignite.engine import Events
from ignite.metrics import Accuracy, RunningAverage

# assumes the trainer's process function returns (loss, y_pred, y)
alpha = 0.98
acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
acc_metric.attach(trainer, 'running_avg_accuracy')

avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
avg_output.attach(trainer, 'running_avg_loss')

@trainer.on(Events.ITERATION_COMPLETED)
def log_running_avg_metrics(engine):
    print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
    print("running avg loss:", engine.state.metrics['running_avg_loss'])
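The decay factor alpha weights the previous average against the newest value. The sketch below shows that exponential moving-average rule in pure Python. It assumes the first observed value initializes the average. `running_average` is a hypothetical helper; with epoch_bound=True, the metric would additionally reset this state after each epoch.

```python
def running_average(values, alpha=0.98):
    """Exponential moving average; the first value initializes the average."""
    averaged, current = [], None
    for v in values:
        current = v if current is None else current * alpha + (1.0 - alpha) * v
        averaged.append(current)
    return averaged

print(running_average([1.0, 0.0, 0.0], alpha=0.5))  # [1.0, 0.5, 0.25]
```

A value of alpha close to 1, such as the default 0.98, makes the average smooth but slow to follow changes.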
- class ignite.metrics.SSIM(data_range, kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True, output_transform=<function SSIM.<lambda>>, device=device(type='cpu'))
Computes the Structural Similarity Index Measure (SSIM).
- Parameters
data_range (int or float) – Range of the image. Typically, 1.0 or 255.
kernel_size (int or list or tuple of int) – Size of the kernel. Default: (11, 11)
sigma (float or list or tuple of float) – Standard deviation of the Gaussian kernel. Used only if gaussian=True. Default: (1.5, 1.5)
k1 (float) – Parameter of SSIM. Default: 0.01
k2 (float) – Parameter of SSIM. Default: 0.03
gaussian (bool) – True to use a Gaussian kernel, False to use a uniform kernel. Default: True
output_transform (callable, optional) – A callable that is used to transform the Engine’s process_function’s output into the form expected by the metric.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
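In the standard SSIM formula, k1, k2 and data_range (written L) enter through the stabilizing constants C1 = (k1 * L)^2 and C2 = (k2 * L)^2. The sketch below evaluates that formula once over whole 1-D signals, which amounts to a single uniform window covering the entire input. This is a deliberate simplification: as its kernel_size and gaussian parameters indicate, the metric slides a local kernel over the image instead. `global_ssim` is a hypothetical helper.

```python
def global_ssim(x, y, data_range, k1=0.01, k2=0.03):
    """SSIM over whole signals given as flat lists (no sliding window)."""
    n = len(x)
    c1 = (k1 * data_range) ** 2  # stabilizes the luminance term
    c2 = (k2 * data_range) ** 2  # stabilizes the contrast/structure term
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator

print(global_ssim([0.1, 0.5, 0.9], [0.1, 0.5, 0.9], data_range=1.0))  # identical: 1.0
print(global_ssim([0.1, 0.5, 0.9], [0.9, 0.5, 0.1], data_range=1.0))  # reversed: below 1
```
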
Example:
To use with Engine and process_function, simply attach the metric instance to the engine. The output of the engine’s process_function needs to be in the format of (y_pred, y) or {'y_pred': y_pred, 'y': y, ...}.
y_pred and y can be un-normalized or normalized image tensors. Depending on that, the user might need to adjust data_range. y_pred and y should have the same shape.

def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
metric = SSIM(data_range=1.0)
metric.attach(engine, "ssim")
New in version 0.4.2.
- class ignite.metrics.TopKCategoricalAccuracy(k=5, output_transform=<function TopKCategoricalAccuracy.<lambda>>, device=device(type='cpu'))
Calculates the top-k categorical accuracy.
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
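The sketch below assumes y_pred holds one score per class for each sample and y holds the true class index. Under that assumption, top-k accuracy counts a sample as correct when its true class is among the k highest-scoring classes. Lists stand in for tensors, and `top_k_accuracy` is a hypothetical helper.

```python
def top_k_accuracy(scores, targets, k=5):
    """Fraction of samples whose true class is among the k highest scores."""
    correct = 0
    for row, target in zip(scores, targets):
        # class indices ordered from highest to lowest score
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += target in top_k
    return correct / len(targets)

scores = [[0.1, 0.5, 0.4], [0.6, 0.3, 0.1]]
print(top_k_accuracy(scores, targets=[2, 2], k=2))  # 0.5
```

With k=1, this reduces to ordinary accuracy on the argmax prediction.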
- class ignite.metrics.VariableAccumulation(op, output_transform=<function VariableAccumulation.<lambda>>, device=device(type='cpu'))
Single variable accumulator helper to compute (arithmetic, geometric, harmonic) average of a single variable.
update must receive output of the form x, where x can be a number or a torch.Tensor.
Note
The class stores its input in two public variables: accumulator and num_examples. The number of samples is updated according to the following rule:
+1 if input is a number
+1 if input is a 1D torch.Tensor
+batch_size if input is a ND torch.Tensor. Batch size is the first dimension (shape[0]).
- Parameters
op (callable) – a callable to update accumulator. Method’s signature is (accumulator, output). For example, to compute arithmetic mean value, op = lambda a, x: a + x.
output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
device (str or torch.device) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
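The sketch below combines the op parameter with the sample-count rule from the note. It shows how one accumulator yields both an arithmetic mean and a geometric mean, the latter by accumulating logarithms. It is pure Python: nested lists stand in for N-D tensors. `VariableAccumulationSketch` and `count_increment` are hypothetical names, not Ignite's implementation.

```python
import math

def count_increment(x):
    """Sample-count rule from the note: a number or 1-D input adds 1,
    an N-D input adds its first dimension (the batch size)."""
    if isinstance(x, (int, float)):
        return 1
    if len(x) > 0 and isinstance(x[0], list):  # nested list: first dim is batch
        return len(x)
    return 1  # a flat list plays the role of a 1-D tensor

class VariableAccumulationSketch:
    """Hypothetical pure-Python stand-in for an accumulator driven by op."""

    def __init__(self, op, initial=0.0):
        self.op = op
        self.accumulator = initial
        self.num_examples = 0

    def update(self, x):
        self.accumulator = self.op(self.accumulator, x)
        self.num_examples += count_increment(x)

# Arithmetic mean: op = lambda a, x: a + x, then divide by num_examples.
arith = VariableAccumulationSketch(lambda a, x: a + x)
# Geometric mean: accumulate logarithms, then exponentiate their mean.
geo = VariableAccumulationSketch(lambda a, x: a + math.log(x))
for value in [1.0, 4.0, 16.0]:
    arith.update(value)
    geo.update(value)

print(arith.accumulator / arith.num_examples)        # 7.0
print(math.exp(geo.accumulator / geo.num_examples))  # approximately 4.0
```
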
- class ignite.metrics.MetricUsage(started, completed, iteration_completed)
Base class for all usages of metrics.
A usage of a metric defines the events on which the metric starts to compute, updates and completes. Valid events are from Events.
- Parameters
started (Events) – event when the metric starts to compute. This event will be associated with started().
completed (Events) – event when the metric completes. This event will be associated with completed().
iteration_completed (CallableEventWithFilter) – event when the metric updates. This event will be associated with iteration_completed().
- class ignite.metrics.EpochWise
Epoch-wise usage of Metrics. It’s the default and most common usage of metrics.
Metric’s methods are triggered on the following engine events:
iteration_completed() on every ITERATION_COMPLETED.
completed() on every EPOCH_COMPLETED.
- class ignite.metrics.BatchWise
Batch-wise usage of Metrics.
Metric’s methods are triggered on the following engine events:
iteration_completed() on every ITERATION_COMPLETED.
completed() on every ITERATION_COMPLETED.
- class ignite.metrics.BatchFiltered(*args, **kwargs)
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
Metric’s methods are triggered on the following engine events:
iteration_completed() on filtered ITERATION_COMPLETED.
completed() on every EPOCH_COMPLETED.
- Parameters
*args (Any) – Positional arguments to set up the filtered event ITERATION_COMPLETED(*args, **kwargs).
**kwargs (Any) – Keyword arguments to set up the filtered event ITERATION_COMPLETED(*args, **kwargs), which is handled by iteration_completed().
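The three usages differ only in which engine events trigger iteration_completed() and completed(). The toy simulation below counts those calls over 2 epochs of 4 iterations each. It is not Ignite code: usages are named by hypothetical strings. It also assumes the BatchFiltered filter is ITERATION_COMPLETED(every=...) evaluated on the global iteration count.

```python
def simulate(num_epochs, iters_per_epoch, usage, every=1):
    """Count how often a metric would be updated and completed under a usage."""
    updates = completions = 0
    global_iteration = 0
    for _ in range(num_epochs):
        for _ in range(iters_per_epoch):
            global_iteration += 1
            # only BatchFiltered skips iterations rejected by its filter
            if usage != "batch_filtered" or global_iteration % every == 0:
                updates += 1  # iteration_completed()
            if usage == "batch_wise":
                completions += 1  # completed() on every ITERATION_COMPLETED
        if usage in ("epoch_wise", "batch_filtered"):
            completions += 1  # completed() on every EPOCH_COMPLETED
    return updates, completions

for usage in ("epoch_wise", "batch_wise", "batch_filtered"):
    print(usage, simulate(2, 4, usage, every=2))
```

Under these assumptions, epoch-wise usage gives (8, 2), batch-wise usage gives (8, 8), and batch-filtered usage with every=2 gives (4, 2).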
- ignite.metrics.metric.reinit__is_reduced(func)
Helper decorator for distributed configuration.
See ignite.metrics on how to use it.
- ignite.metrics.metric.sync_all_reduce(*attrs)
Helper decorator for distributed configuration to collect instance attribute value across all participating processes.
See ignite.metrics on how to use it.