Source code for ignite.metrics.multilabel_confusion_matrix
from typing import Callable, Sequence, Union
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
__all__ = ["MultiLabelConfusionMatrix"]
[docs]class MultiLabelConfusionMatrix(Metric):
"""Calculates a confusion matrix for multi-labelled, multi-class data.
- ``update`` must receive output of the form ``(y_pred, y)``.
- `y_pred` must contain 0s and 1s and has the following shape (batch_size, num_classes, ...).
For example, `y_pred[i, j]` = 1 denotes that the j'th class is one of the labels of the i'th sample as predicted.
- `y` should have the following shape (batch_size, num_classes, ...) with 0s and 1s. For example,
`y[i, j]` = 1 denotes that the j'th class is one of the labels of the i'th sample according to the ground truth.
- both `y` and `y_pred` must be torch Tensors having any of the following types:
{torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}. They must have the same dimensions.
- The confusion matrix 'M' is of dimension (num_classes, 2, 2).
* M[i, 0, 0] corresponds to count/rate of true negatives of class i
* M[i, 0, 1] corresponds to count/rate of false positives of class i
* M[i, 1, 0] corresponds to count/rate of false negatives of class i
* M[i, 1, 1] corresponds to count/rate of true positives of class i
- The classes present in M are indexed as 0, ... , num_classes-1 as can be inferred from above.
Args:
num_classes: Number of classes, should be > 1.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device: specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
normalized: whether to normalize confusion matrix by its sum or not.
.. versionadded:: 0.4.5
"""
def __init__(
self,
num_classes: int,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
normalized: bool = False,
):
if num_classes <= 1:
raise ValueError("Argument num_classes needs to be > 1")
self.num_classes = num_classes
self._num_examples = 0
self.normalized = normalized
super(MultiLabelConfusionMatrix, self).__init__(output_transform=output_transform, device=device)
[docs] @reinit__is_reduced
def reset(self) -> None:
self.confusion_matrix = torch.zeros(self.num_classes, 2, 2, dtype=torch.int64, device=self._device)
self._num_examples = 0
[docs] @reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
self._check_input(output)
y_pred, y = output[0].detach(), output[1].detach()
self._num_examples += y.shape[0]
y_reshaped = y.transpose(0, 1).reshape(self.num_classes, -1)
y_pred_reshaped = y_pred.transpose(0, 1).reshape(self.num_classes, -1)
y_total = y_reshaped.sum(dim=1)
y_pred_total = y_pred_reshaped.sum(dim=1)
tp = (y_reshaped * y_pred_reshaped).sum(dim=1)
fp = y_pred_total - tp
fn = y_total - tp
tn = y_reshaped.shape[1] - tp - fp - fn
self.confusion_matrix += torch.stack([tn, fp, fn, tp], dim=1).reshape(-1, 2, 2).to(self._device)
[docs] @sync_all_reduce("confusion_matrix", "_num_examples")
def compute(self) -> torch.Tensor:
if self._num_examples == 0:
raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")
if self.normalized:
conf = self.confusion_matrix.to(dtype=torch.float64)
sums = conf.sum(dim=(1, 2))
return conf / sums[:, None, None]
return self.confusion_matrix
def _check_input(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
if y_pred.ndimension() < 2:
raise ValueError(
f"y_pred must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
)
if y.ndimension() < 2:
raise ValueError(
f"y must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
)
if y_pred.shape[0] != y.shape[0]:
raise ValueError(f"y_pred and y have different batch size: {y_pred.shape[0]} vs {y.shape[0]}")
if y_pred.shape[1] != self.num_classes:
raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[1]} vs {self.num_classes}")
if y.shape[1] != self.num_classes:
raise ValueError(f"y does not have correct number of classes: {y.shape[1]} vs {self.num_classes}")
if y.shape != y_pred.shape:
raise ValueError("y and y_pred shapes must match.")
valid_types = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
if y_pred.dtype not in valid_types:
raise ValueError(f"y_pred must be of any type: {valid_types}")
if y.dtype not in valid_types:
raise ValueError(f"y must be of any type: {valid_types}")
if not torch.equal(y_pred, y_pred ** 2):
raise ValueError("y_pred must be a binary tensor")
if not torch.equal(y, y ** 2):
raise ValueError("y must be a binary tensor")