import numbers
from typing import Callable, Optional, Sequence, Tuple, Union
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.metrics_lambda import MetricsLambda
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["ConfusionMatrix", "mIoU", "IoU", "DiceCoefficient", "cmAccuracy", "cmPrecision", "cmRecall", "JaccardIndex"]
class ConfusionMatrix(Metric):
    """Calculates confusion matrix for multi-class data.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y_pred` must contain logits and has the following shape (batch_size, num_classes, ...).
      If you are doing binary classification, see the Examples section on how to get this.
    - `y` should have the following shape (batch_size, ...) and contains ground-truth class indices
      with or without the background class. During the computation, argmax of `y_pred` is taken to determine
      predicted classes.

    Args:
        num_classes: Number of classes, should be > 1. See notes for more details.
        average: confusion matrix values averaging schema: None, "samples", "recall", "precision".
            Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
            samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
            represent class recalls. If `average="precision"` then confusion matrix values are normalized such that
            diagonal values represent class precisions.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        device: specifies which device updates are accumulated on. Setting the metric's
            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
            default, CPU.

    Note:
        The confusion matrix is formatted such that columns are predictions and rows are targets.
        For example, if you were to plot the matrix, you could correctly assign to the horizontal axis
        the label "predicted values" and to the vertical axis the label "actual values".

    Note:
        In case of the targets `y` in `(batch_size, ...)` format, only target indices in the range
        `[0, num_classes)` contribute to the confusion matrix and others are neglected. For example, if
        `num_classes=20` and target index equal 255 is encountered, then it is filtered out.

    Note:
        The accumulated counts are always kept as an ``int64`` tensor. When ``average`` is set, ``compute``
        returns a normalized floating point copy and leaves the accumulator untouched, so calling ``compute``
        between updates does not degrade the precision of later updates.

    Examples:
        If you are doing binary classification with a single output unit, you may have to transform your network
        output, so that you have one value for each class. E.g. you can transform your network output into a
        one-hot vector with:

        .. code-block:: python

            def binary_one_hot_output_transform(output):
                y_pred, y = output
                y_pred = torch.sigmoid(y_pred).round().long()
                y_pred = ignite.utils.to_onehot(y_pred, 2)
                y = y.long()
                return y_pred, y

            metrics = {
                "confusion_matrix": ConfusionMatrix(2, output_transform=binary_one_hot_output_transform),
            }

            evaluator = create_supervised_evaluator(
                model, metrics=metrics, output_transform=lambda x, y, y_pred: (y_pred, y)
            )
    """

    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = None,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        if average is not None and average not in ("samples", "recall", "precision"):
            raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'")

        if num_classes <= 1:
            raise ValueError("Argument num_classes needs to be > 1")

        # These attributes are assigned before delegating to the base constructor;
        # ``reset`` reads ``self.num_classes``, and the base class presumably calls
        # ``reset`` during initialisation (NOTE(review): confirm against ``Metric``).
        self.num_classes = num_classes
        self._num_examples = 0
        self.average = average
        super(ConfusionMatrix, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Zero the accumulated ``num_classes x num_classes`` count matrix and the example counter."""
        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes, dtype=torch.int64, device=self._device)
        self._num_examples = 0

    def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
        """Validate that ``(y_pred, y)`` have compatible shapes.

        Raises:
            ValueError: if ``y_pred`` has fewer than two dimensions, if its second dimension differs from
                ``num_classes``, if ``y`` does not have exactly one dimension less than ``y_pred``, or if
                ``y`` does not match ``y_pred`` with the class dimension removed.
        """
        y_pred, y = output[0].detach(), output[1].detach()

        if y_pred.ndimension() < 2:
            raise ValueError(
                f"y_pred must have shape (batch_size, num_classes (currently set to {self.num_classes}), ...), "
                f"but given {y_pred.shape}"
            )

        if y_pred.shape[1] != self.num_classes:
            raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[1]} vs {self.num_classes}")

        if not (y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError(
                f"y_pred must have shape (batch_size, num_classes (currently set to {self.num_classes}), ...) "
                "and y must have shape of (batch_size, ...), "
                f"but given {y.shape} vs {y_pred.shape}."
            )

        # The rank relation is guaranteed by the check above, so the expected target
        # shape is simply y_pred's shape with the class dimension (dim 1) dropped.
        expected_y_shape = (y_pred.shape[0],) + y_pred.shape[2:]
        if y.shape != expected_y_shape:
            raise ValueError("y and y_pred must have compatible shapes.")

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Accumulate the confusion counts of one batch ``(y_pred, y)``."""
        self._check_shape(output)
        y_pred, y = output[0].detach(), output[1].detach()

        self._num_examples += y_pred.shape[0]

        # target is (batch_size, ...)
        y_pred = torch.argmax(y_pred, dim=1).flatten()
        y = y.flatten()

        # Drop targets outside [0, num_classes), e.g. "ignore" labels such as 255.
        target_mask = (y >= 0) & (y < self.num_classes)
        y = y[target_mask]
        y_pred = y_pred[target_mask]

        # Encode each (target, prediction) pair as a single flat index so one bincount
        # call yields the whole matrix: rows are targets, columns are predictions.
        indices = self.num_classes * y + y_pred
        m = torch.bincount(indices, minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
        self.confusion_matrix += m.to(self.confusion_matrix)

    @sync_all_reduce("confusion_matrix", "_num_examples")
    def compute(self) -> torch.Tensor:
        """Return the accumulated confusion matrix, normalized according to ``average``.

        Returns:
            The raw ``int64`` count matrix when ``average`` is falsy, otherwise a floating point
            tensor normalized by the number of seen examples ("samples") or via :meth:`normalize`.

        Raises:
            NotComputableError: if no example has been seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")

        if not self.average:
            return self.confusion_matrix

        # Normalise a float *copy*: the accumulator itself must stay int64 so that
        # updates following this call keep counting exactly.
        float_matrix = self.confusion_matrix.float()
        if self.average == "samples":
            return float_matrix / self._num_examples
        return self.normalize(float_matrix, self.average)

    @staticmethod
    def normalize(matrix: torch.Tensor, average: str) -> torch.Tensor:
        """Normalize given `matrix` with given `average`.

        Args:
            matrix: confusion matrix with targets as rows and predictions as columns.
            average: "recall" divides every row by its sum, "precision" divides every column by its sum.
                A small epsilon is added to the sums to avoid division by zero.

        Raises:
            ValueError: if ``average`` is neither "recall" nor "precision".
        """
        if average == "recall":
            return matrix / (matrix.sum(dim=1).unsqueeze(1) + 1e-15)
        elif average == "precision":
            return matrix / (matrix.sum(dim=0) + 1e-15)
        else:
            raise ValueError("Argument average should be one of 'samples', 'recall', 'precision'")
def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    r"""Calculates Intersection over Union using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    .. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Raises:
        TypeError: if ``cm`` is not a :class:`ConfusionMatrix`.
        ValueError: if ``cm.average`` is neither None nor "samples", or if ``ignore_index`` is not an
            integer in ``[0, cm.num_classes)``.

    Examples:

    .. code-block:: python

        train_evaluator = ...

        cm = ConfusionMatrix(num_classes=num_classes)
        IoU(cm, ignore_index=0).attach(train_evaluator, 'IoU')

        state = train_evaluator.run(train_dataset)
        # state.metrics['IoU'] -> tensor of shape (num_classes - 1, )
    """
    if not isinstance(cm, ConfusionMatrix):
        raise TypeError(f"Argument cm should be instance of ConfusionMatrix, but given {type(cm)}")

    if cm.average not in (None, "samples"):
        raise ValueError("ConfusionMatrix should have average attribute either None or 'samples'")

    if ignore_index is not None:
        index_is_valid = isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes
        if not index_is_valid:
            raise ValueError(f"ignore_index should be non-negative integer, but given {ignore_index}")

    # Work on a CPU double-precision view of the matrix for better accuracy.
    cm = cm.type(torch.DoubleTensor)
    intersection = cm.diag()
    # |A| + |B| - |A and B|, with an epsilon guarding against empty classes.
    union = cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15
    iou = intersection / union  # type: MetricsLambda

    if ignore_index is None:
        return iou

    ignore_idx = ignore_index  # type: int  # separate name keeps the type checker happy in the closure

    def ignore_index_fn(iou_vector: torch.Tensor) -> torch.Tensor:
        size = len(iou_vector)
        if ignore_idx >= size:
            raise ValueError(f"ignore_index {ignore_idx} is larger than the length of IoU vector {size}")
        # Keep every class position except the ignored one.
        kept = [position for position in range(size) if position != ignore_idx]
        return iou_vector[kept]

    return MetricsLambda(ignore_index_fn, iou)
def mIoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    """Calculates mean Intersection over Union using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    The per-class IoU vector is built by :func:`IoU` (which also performs the argument validation)
    and then averaged.

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

    .. code-block:: python

        train_evaluator = ...

        cm = ConfusionMatrix(num_classes=num_classes)
        mIoU(cm, ignore_index=0).attach(train_evaluator, 'mean IoU')

        state = train_evaluator.run(train_dataset)
        # state.metrics['mean IoU'] -> scalar
    """
    per_class_iou = IoU(cm=cm, ignore_index=ignore_index)
    mean_iou = per_class_iou.mean()  # type: MetricsLambda
    return mean_iou
def cmAccuracy(cm: ConfusionMatrix) -> MetricsLambda:
    """Calculates accuracy using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    Accuracy is the sum of the diagonal of the confusion matrix divided by the sum of all its
    entries (plus a small epsilon to avoid division by zero).

    Args:
        cm: instance of confusion matrix metric

    Returns:
        MetricsLambda
    """
    # Work on a CPU double-precision view of the matrix for better accuracy.
    cm = cm.type(torch.DoubleTensor)
    correct = cm.diag().sum()
    total = cm.sum() + 1e-15
    accuracy = correct / total  # type: MetricsLambda
    return accuracy
def cmPrecision(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda:
    """Calculates precision using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    Per-class precision is the diagonal of the confusion matrix divided by the column sums
    (plus a small epsilon to avoid division by zero).

    Args:
        cm: instance of confusion matrix metric
        average: if True metric value is averaged over all classes

    Returns:
        MetricsLambda
    """
    # Work on a CPU double-precision view of the matrix for better accuracy.
    cm = cm.type(torch.DoubleTensor)
    predicted_totals = cm.sum(dim=0) + 1e-15
    precision = cm.diag() / predicted_totals  # type: MetricsLambda

    if not average:
        return precision

    mean = precision.mean()  # type: MetricsLambda
    return mean
def cmRecall(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda:
    """Calculates recall using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    Per-class recall is the diagonal of the confusion matrix divided by the row sums
    (plus a small epsilon to avoid division by zero).

    Args:
        cm: instance of confusion matrix metric
        average: if True metric value is averaged over all classes

    Returns:
        MetricsLambda
    """
    # Work on a CPU double-precision view of the matrix for better accuracy.
    cm = cm.type(torch.DoubleTensor)
    target_totals = cm.sum(dim=1) + 1e-15
    recall = cm.diag() / target_totals  # type: MetricsLambda

    if not average:
        return recall

    mean = recall.mean()  # type: MetricsLambda
    return mean
def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    """Calculates Dice Coefficient for a given :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    Per-class Dice is twice the diagonal of the confusion matrix divided by the sum of the row
    and column totals (plus a small epsilon to avoid division by zero).

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Raises:
        TypeError: if ``cm`` is not a :class:`ConfusionMatrix`.
        ValueError: if ``ignore_index`` is not an integer in ``[0, cm.num_classes)``.
    """
    if not isinstance(cm, ConfusionMatrix):
        raise TypeError(f"Argument cm should be instance of ConfusionMatrix, but given {type(cm)}")

    if ignore_index is not None:
        index_is_valid = isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes
        if not index_is_valid:
            raise ValueError(f"ignore_index should be non-negative integer, but given {ignore_index}")

    # Work on a CPU double-precision view of the matrix for better accuracy.
    cm = cm.type(torch.DoubleTensor)
    overlap = 2.0 * cm.diag()
    totals = cm.sum(dim=1) + cm.sum(dim=0) + 1e-15
    dice = overlap / totals  # type: MetricsLambda

    if ignore_index is None:
        return dice

    ignore_idx = ignore_index  # type: int  # separate name keeps the type checker happy in the closure

    def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor:
        size = len(dice_vector)
        if ignore_idx >= size:
            raise ValueError(
                f"ignore_index {ignore_idx} is larger than the length of Dice vector {size}"
            )
        # Keep every class position except the ignored one.
        kept = [position for position in range(size) if position != ignore_idx]
        return dice_vector[kept]

    return MetricsLambda(ignore_index_fn, dice)
def JaccardIndex(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    r"""Calculates the Jaccard Index using :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` metric.

    This is an alias: the call is forwarded unchanged to :meth:`~ignite.metrics.IoU`.

    .. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

    .. code-block:: python

        train_evaluator = ...

        cm = ConfusionMatrix(num_classes=num_classes)
        JaccardIndex(cm, ignore_index=0).attach(train_evaluator, 'JaccardIndex')

        state = train_evaluator.run(train_dataset)
        # state.metrics['JaccardIndex'] -> tensor of shape (num_classes - 1, )
    """
    jaccard = IoU(cm, ignore_index=ignore_index)
    return jaccard