
torcheval.metrics.MulticlassConfusionMatrix

class torcheval.metrics.MulticlassConfusionMatrix(num_classes: int, *, normalize: str | None = None, device: device | None = None)

Compute the multi-class confusion matrix: a num_classes x num_classes matrix in which the element at position (i, j) is the number of examples with true class i that were predicted as class j.
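The counting rule can be checked with a framework-free sketch. It is a plain-Python illustration of the definition above, not torcheval's implementation, and the function name `confusion_matrix` is hypothetical. Rows index the true class and columns index the predicted class.

```python
from typing import List, Sequence


def confusion_matrix(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> List[List[int]]:
    """Hypothetical helper: count (true class i, predicted class j) pairs."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for pred, true in zip(preds, targets):
        # Row index is the true class, column index is the predicted class.
        matrix[true][pred] += 1
    return matrix


# Same data as the third example on this page.
print(confusion_matrix([0, 0, 1, 1, 1, 2, 1, 2], [2, 0, 2, 0, 1, 2, 1, 0], 3))
# [[1, 1, 1], [0, 2, 0], [1, 1, 1]]
```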

Parameters:
  • input (Tensor) – Tensor of label predictions, passed to update(). It can hold predicted labels with shape (n_sample,), or probabilities or logits with shape (n_sample, n_class); in the latter case torch.argmax over the class dimension converts input into predicted labels.

  • target (Tensor) – Tensor of ground truth labels with shape (n_sample,), passed to update().

  • num_classes (int) – Number of classes.

  • normalize (str) –

    • None [default]:

      Return raw counts ('none' is treated the same way).

    • 'pred':

      Normalize over the predicted class, i.e. divide each column by its sum so that the columns add to one.

    • 'true':

      Normalize over the true class (condition positive), i.e. divide each row by its sum so that the rows add to one.

    • 'all':

      Normalize over all examples, i.e. so that all matrix entries add to one.

  • device (torch.device) – Device for internal tensors.
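The normalize modes can be illustrated on a plain list-of-lists matrix. This is a sketch of the semantics shown in the examples on this page, not torcheval's code; `normalize_matrix` is a hypothetical name, and mapping an all-zero row or column to 0.0 is an assumption of this sketch that may differ from the library's handling of empty classes.

```python
from typing import List, Optional


def normalize_matrix(matrix: List[List[float]], normalize: Optional[str]) -> List[List[float]]:
    """Hypothetical helper; rows are true classes, columns are predicted classes."""
    n = len(matrix)
    if normalize is None or normalize == "none":
        # Raw counts, returned as floats.
        return [[float(v) for v in row] for row in matrix]
    if normalize == "true":
        # Divide each row by its sum so that every row adds to one.
        sums = [sum(row) for row in matrix]
        return [[v / sums[i] if sums[i] else 0.0 for v in row] for i, row in enumerate(matrix)]
    if normalize == "pred":
        # Divide each column by its sum so that every column adds to one.
        sums = [sum(matrix[i][j] for i in range(n)) for j in range(n)]
        return [[v / sums[j] if sums[j] else 0.0 for j, v in enumerate(row)] for row in matrix]
    if normalize == "all":
        # Divide every entry by the total number of examples.
        total = sum(sum(row) for row in matrix)
        return [[v / total if total else 0.0 for v in row] for row in matrix]
    raise ValueError(f"unsupported normalize mode: {normalize!r}")


counts = [[1, 1, 1], [0, 2, 0], [1, 1, 1]]
print(normalize_matrix(counts, "pred"))
# [[0.5, 0.25, 0.5], [0.0, 0.5, 0.0], [0.5, 0.25, 0.5]]
```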

Examples:

>>> import torch
>>> from torcheval.metrics import MulticlassConfusionMatrix
>>> input = torch.tensor([0, 2, 1, 3])
>>> target = torch.tensor([0, 1, 2, 3])
>>> metric = MulticlassConfusionMatrix(4)
>>> metric.update(input, target)
>>> metric.compute()
tensor([[1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])

>>> input = torch.tensor([0, 0, 1, 1, 1])
>>> target = torch.tensor([0, 0, 0, 0, 1])
>>> metric = MulticlassConfusionMatrix(2)
>>> metric.update(input, target)
>>> metric.compute()
tensor([[2, 2],
        [0, 1]])

>>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
>>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
>>> metric = MulticlassConfusionMatrix(3)
>>> metric.update(input, target)
>>> metric.compute()
tensor([[1, 1, 1],
        [0, 2, 0],
        [1, 1, 1]])

>>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
>>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
>>> metric = MulticlassConfusionMatrix(3)
>>> metric.update(input, target)
>>> metric.compute()
tensor([[1., 1., 1.],
        [0., 2., 0.],
        [1., 1., 1.]])
>>> metric.normalized("pred")
tensor([[0.5000, 0.2500, 0.5000],
        [0.0000, 0.5000, 0.0000],
        [0.5000, 0.2500, 0.5000]])
>>> metric.normalized("true")
tensor([[0.3333, 0.3333, 0.3333],
        [0.0000, 1.0000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
>>> metric.normalized("all")
tensor([[0.1250, 0.1250, 0.1250],
        [0.0000, 0.2500, 0.0000],
        [0.1250, 0.1250, 0.1250]])

>>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
>>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
>>> metric = MulticlassConfusionMatrix(3, normalize="true")
>>> metric.update(input, target)
>>> metric.compute()
tensor([[0.3333, 0.3333, 0.3333],
        [0.0000, 1.0000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
>>> metric.normalized(None)
tensor([[1., 1., 1.],
        [0., 2., 0.],
        [1., 1., 1.]])

>>> input = torch.tensor([0, 0, 1, 1, 1])
>>> target = torch.tensor([0, 0, 0, 0, 1])
>>> metric = MulticlassConfusionMatrix(4)
>>> metric.update(input, target)
>>> metric.compute()
tensor([[2, 2, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]])

>>> input = torch.tensor([[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]])
>>> target = torch.tensor([0, 1, 2, 3])
>>> metric = MulticlassConfusionMatrix(4)
>>> metric.update(input, target)
>>> metric.compute()
tensor([[1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])
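In the last example, each row of probabilities is reduced to the index of its largest entry before counting. A plain-Python sketch of that reduction follows; `argmax_labels` is a hypothetical helper standing in for torch.argmax over the class dimension, and it picks the first index when several entries tie.

```python
from typing import List, Sequence


def argmax_labels(scores: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper: index of the largest score in each row (first index on ties)."""
    return [max(range(len(row)), key=lambda j: row[j]) for row in scores]


probs = [[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]]
print(argmax_labels(probs))
# [0, 2, 1, 3]
```

The result equals the label input of the first example, which is why both examples produce the same confusion matrix.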

__init__(num_classes: int, *, normalize: str | None = None, device: device | None = None) → None

Initialize a metric object and its internal states.

Use self._add_state() to initialize state variables of your metric class. The state variables should be either torch.Tensor, a list of torch.Tensor, a dictionary with torch.Tensor as values, or a deque of torch.Tensor.
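The state of this metric is a matrix of counts that successive update() calls add to. The toy class below is a hypothetical, framework-free sketch of that accumulate/compute/reset cycle; it does not use torcheval's Metric base class or self._add_state(), and the name `ToyConfusionMatrix` is invented for illustration.

```python
from typing import List, Sequence


class ToyConfusionMatrix:
    """Hypothetical stand-in for a stateful metric; not torcheval's implementation."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        # The single state variable: a num_classes x num_classes grid of counts.
        self.counts: List[List[int]] = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        # Each call adds to the existing counts instead of replacing them.
        for pred, true in zip(preds, targets):
            self.counts[true][pred] += 1

    def compute(self) -> List[List[int]]:
        # Return a copy so callers cannot mutate the internal state.
        return [row[:] for row in self.counts]


# The second example on this page, split into two batches.
metric = ToyConfusionMatrix(2)
metric.update([0, 0, 1], [0, 0, 0])
metric.update([1, 1], [0, 1])
print(metric.compute())
# [[2, 2], [0, 1]]
```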

Methods

__init__(num_classes, *[, normalize, device])

Initialize a metric object and its internal states.

compute()

Return the confusion matrix.

load_state_dict(state_dict[, strict])

Loads metric state variables from state_dict.

merge_state(metrics)

Update the current metric's state variables to the merged states of the current metric and the input metrics.

normalized([normalize])

Return the normalized confusion matrix.

reset()

Reset the metric state variables to their default value.

state_dict()

Save metric state variables in state_dict.

to(device, *args, **kwargs)

Move tensors in metric state variables to device.

update(input, target)

Update the confusion matrix with a batch of predictions and targets.
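Because the examples show that the state keeps raw counts even when normalize is set (normalized(None) recovers them), merging the states of several metric instances, for example one per worker, plausibly reduces to an element-wise sum. The sketch below assumes exactly that; `merge_counts` is a hypothetical helper, not part of torcheval.

```python
from typing import List, Sequence


def merge_counts(matrices: Sequence[List[List[int]]]) -> List[List[int]]:
    """Hypothetical helper: sum equally sized count matrices element by element."""
    if not matrices:
        raise ValueError("need at least one matrix")
    n = len(matrices[0])
    merged = [[0] * n for _ in range(n)]
    for matrix in matrices:
        for i in range(n):
            for j in range(n):
                merged[i][j] += matrix[i][j]
    return merged


# Counts from the first and last four samples of the third example.
worker_a = [[1, 1, 0], [0, 0, 0], [1, 1, 0]]
worker_b = [[0, 0, 1], [0, 2, 0], [0, 0, 1]]
print(merge_counts([worker_a, worker_b]))
# [[1, 1, 1], [0, 2, 0], [1, 1, 1]]
```

Summing counts first and normalizing afterwards keeps the result identical to a single-process run; averaging already-normalized matrices would not, whenever the workers saw different numbers of examples per class.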

Attributes

device

The last input device of Metric.to().
