Source code for torcheval.metrics.classification.confusion_matrix

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, TypeVar

import torch

from torcheval.metrics.functional.classification.confusion_matrix import (
    _binary_confusion_matrix_update,
    _confusion_matrix_compute,
    _confusion_matrix_param_check,
    _confusion_matrix_update,
)
from torcheval.metrics.metric import Metric


TMulticlassConfusionMatrix = TypeVar("TMulticlassConfusionMatrix")
TBinaryConfusionMatrix = TypeVar("TBinaryConfusionMatrix")


class MulticlassConfusionMatrix(Metric[torch.Tensor]):
    """
    Compute multi-class confusion matrix, a matrix of dimension num_classes x num_classes
    where each element at position `(i,j)` is the number of examples with true class `i`
    that were predicted to be class `j`.

    See also :class:`BinaryConfusionMatrix <BinaryConfusionMatrix>`

    Args:
        input (Tensor): Tensor of label predictions.
            It could be the predicted labels, with shape of (n_sample, ).
            It could also be probabilities or logits with shape of (n_sample, n_class).
            ``torch.argmax`` will be used to convert input into predicted labels.
        target (Tensor): Tensor of ground truth labels with shape of (n_sample, ).
        num_classes (int): Number of classes.
        normalize (str):
            - ``None`` [default]: Give raw counts ('none' also defaults to this)
            - ``'pred'``: Normalize across the prediction class, i.e. such that the columns add to one.
            - ``'true'``: Normalize across the condition positive, i.e. such that the rows add to one.
            - ``'all'``: Normalize across all examples, i.e. such that all matrix entries add to one.
        device (torch.device): Device for internal tensors

    Examples::

        >>> import torch
        >>> from torcheval.metrics import MulticlassConfusionMatrix

        >>> input = torch.tensor([0, 2, 1, 3])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> metric = MulticlassConfusionMatrix(4)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1]])

        >>> input = torch.tensor([0, 0, 1, 1, 1])
        >>> target = torch.tensor([0, 0, 0, 0, 1])
        >>> metric = MulticlassConfusionMatrix(2)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[2, 2],
                [0, 1]])

        >>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
        >>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
        >>> metric = MulticlassConfusionMatrix(3)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[1, 1, 1],
                [0, 2, 0],
                [1, 1, 1]])

        >>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
        >>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
        >>> metric = MulticlassConfusionMatrix(3)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[1., 1., 1.],
                [0., 2., 0.],
                [1., 1., 1.]])
        >>> metric.normalized("pred")
        tensor([[0.5000, 0.2500, 0.5000],
                [0.0000, 0.5000, 0.0000],
                [0.5000, 0.2500, 0.5000]])
        >>> metric.normalized("true")
        tensor([[0.3333, 0.3333, 0.3333],
                [0.0000, 1.0000, 0.0000],
                [0.3333, 0.3333, 0.3333]])
        >>> metric.normalized("all")
        tensor([[0.1250, 0.1250, 0.1250],
                [0.0000, 0.2500, 0.0000],
                [0.1250, 0.1250, 0.1250]])

        >>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
        >>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
        >>> metric = MulticlassConfusionMatrix(3, normalize="true")
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[0.3333, 0.3333, 0.3333],
                [0.0000, 1.0000, 0.0000],
                [0.3333, 0.3333, 0.3333]])
        >>> metric.normalized(None)
        tensor([[1., 1., 1.],
                [0., 2., 0.],
                [1., 1., 1.]])

        >>> input = torch.tensor([0, 0, 1, 1, 1])
        >>> target = torch.tensor([0, 0, 0, 0, 1])
        >>> metric = MulticlassConfusionMatrix(4)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[2, 2, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0]])

        >>> input = torch.tensor([[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> metric = MulticlassConfusionMatrix(4)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1]])
    """

    def __init__(
        self: TMulticlassConfusionMatrix,
        num_classes: int,
        *,
        normalize: Optional[str] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        _confusion_matrix_param_check(num_classes, normalize)
        self.normalize = normalize
        self.num_classes = num_classes
        self._add_state(
            "confusion_matrix",
            torch.zeros([num_classes, num_classes], device=self.device),
        )

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TMulticlassConfusionMatrix, input: torch.Tensor, target: torch.Tensor
    ) -> TMulticlassConfusionMatrix:
        """
        Update Confusion Matrix.

        Args:
            input (Tensor): Tensor of label predictions.
                It could be the predicted labels, with shape of (n_sample, ).
                It could also be probabilities or logits with shape of (n_sample, n_class).
                ``torch.argmax`` will be used to convert input into predicted labels.
            target (Tensor): Tensor of ground truth labels with shape of (n_sample, ).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        self.confusion_matrix += _confusion_matrix_update(
            input, target, self.num_classes
        )
        return self

    @torch.inference_mode()
    def compute(self: TMulticlassConfusionMatrix) -> torch.Tensor:
        """
        Return the confusion matrix.
        """
        return _confusion_matrix_compute(
            self.confusion_matrix, normalize=self.normalize
        )

    @torch.inference_mode()
    def normalized(
        self: TMulticlassConfusionMatrix, normalize: Optional[str] = None
    ) -> torch.Tensor:
        """
        Return the normalized confusion matrix.

        Args:
            normalize (str): Normalization mode to use for this call, overriding the one set at construction.
                - ``None`` [default]: Give raw counts ('none' also defaults to this)
                - ``'pred'``: Normalize across the prediction class, i.e. such that the columns add to one.
                - ``'true'``: Normalize across the condition positive, i.e. such that the rows add to one.
                - ``'all'``: Normalize across all examples, i.e. such that all matrix entries add to one.
        """
        _confusion_matrix_param_check(self.num_classes, normalize)
        return _confusion_matrix_compute(self.confusion_matrix, normalize)

    @torch.inference_mode()
    def merge_state(
        self: TMulticlassConfusionMatrix, metrics: Iterable[TMulticlassConfusionMatrix]
    ) -> TMulticlassConfusionMatrix:
        for metric in metrics:
            self.confusion_matrix += metric.confusion_matrix.to(self.device)
        return self
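

# Illustrative usage sketch: accumulating counts over batches and merging the
# state of a second, independently updated instance (e.g. one metric per data
# shard). The tensors are made-up example data; ``merge_state`` folds the other
# instance's counts into this one in place.
#
#     >>> import torch
#     >>> from torcheval.metrics import MulticlassConfusionMatrix
#     >>> shard_a = MulticlassConfusionMatrix(3)
#     >>> shard_b = MulticlassConfusionMatrix(3)
#     >>> shard_a.update(torch.tensor([0, 1, 2]), torch.tensor([0, 2, 2]))
#     >>> shard_b.update(torch.tensor([1, 1, 0]), torch.tensor([1, 0, 0]))
#     >>> shard_a.merge_state([shard_b])
#     >>> shard_a.compute()            # merged raw counts: [[2, 1, 0], [0, 1, 0], [0, 1, 1]]
#     >>> shard_a.normalized("true")   # each true-class row rescaled to sum to one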


class BinaryConfusionMatrix(MulticlassConfusionMatrix):
    """
    Compute binary confusion matrix, a 2 by 2 tensor with counts
    ( (true positive, false negative) , (false positive, true negative) )

    See also :class:`MulticlassConfusionMatrix <MulticlassConfusionMatrix>`

    Args:
        input (Tensor): Tensor of label predictions with shape of (n_sample,).
            ``torch.where(input < threshold, 0, 1)`` will be applied to the input.
        target (Tensor): Tensor of ground truth labels with shape of (n_sample,).
        threshold (float, default 0.5): Threshold for converting input into predicted labels for each sample.
            ``torch.where(input < threshold, 0, 1)`` will be applied to the ``input``.
        normalize (str):
            - ``None`` [default]: Give raw counts ('none' also defaults to this)
            - ``'pred'``: Normalize across the prediction class, i.e. such that the columns add to one.
            - ``'true'``: Normalize across the condition positive, i.e. such that the rows add to one.
            - ``'all'``: Normalize across all examples, i.e. such that all matrix entries add to one.
        device (torch.device): Device for internal tensors

    Examples::

        >>> import torch
        >>> from torcheval.metrics import BinaryConfusionMatrix

        >>> input = torch.tensor([0, 1, 0.7, 0.6])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> metric = BinaryConfusionMatrix()
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[1, 1],
                [0, 2]])

        >>> input = torch.tensor([0, 1, 0.7, 0.6])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> metric = BinaryConfusionMatrix(threshold=1)
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[2, 0],
                [1, 1]])

        >>> input = torch.tensor([1, 1, 0, 0])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> metric = BinaryConfusionMatrix()
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[0., 1.],
                [2., 1.]])
        >>> metric.normalized("pred")
        tensor([[0.0000, 0.5000],
                [1.0000, 0.5000]])
        >>> metric.normalized("true")
        tensor([[0.0000, 1.0000],
                [0.6667, 0.3333]])
        >>> metric.normalized("all")
        tensor([[0.0000, 0.2500],
                [0.5000, 0.2500]])

        >>> input = torch.tensor([1, 1, 0, 0])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> metric = BinaryConfusionMatrix(normalize="true")
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([[0.0000, 1.0000],
                [0.6667, 0.3333]])
        >>> metric.normalized(None)
        tensor([[0., 1.],
                [2., 1.]])
    """

    def __init__(
        self: TBinaryConfusionMatrix,
        *,
        threshold: float = 0.5,
        normalize: Optional[str] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(num_classes=2, device=device, normalize=normalize)
        self.threshold = threshold

    @torch.inference_mode()
    def update(
        self: TBinaryConfusionMatrix, input: torch.Tensor, target: torch.Tensor
    ) -> TBinaryConfusionMatrix:
        """
        Update the confusion matrix.

        Args:
            input (Tensor): Tensor of label predictions with shape of (n_sample,).
                ``torch.where(input < threshold, 0, 1)`` will be applied to the input.
            target (Tensor): Tensor of ground truth labels with shape of (n_sample,).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        self.confusion_matrix += _binary_confusion_matrix_update(
            input, target, self.threshold
        )
        return self
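

# A minimal, illustrative sketch of deriving precision and recall from the
# binary confusion matrix by hand. It assumes the layout shown in the examples
# above (row index = true class, column index = predicted class) and treats
# class 1 as the positive class, so ``cm[1, 1]`` is the true-positive count,
# ``cm[0, 1]`` the false-positive count, and ``cm[1, 0]`` the false-negative
# count. The scores and labels are made-up example data.
if __name__ == "__main__":
    metric = BinaryConfusionMatrix(threshold=0.5)
    scores = torch.tensor([0.9, 0.2, 0.7, 0.4, 0.8])  # thresholded to [1, 0, 1, 0, 1]
    labels = torch.tensor([1, 0, 1, 1, 0])
    metric.update(scores, labels)

    cm = metric.compute()  # raw counts, since normalize=None
    tp = cm[1, 1].item()
    fp = cm[0, 1].item()
    fn = cm[1, 0].item()
    precision = tp / (tp + fp)  # 2 / 3
    recall = tp / (tp + fn)     # 2 / 3
    print(cm)
    print(f"precision={precision:.4f} recall={recall:.4f}")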
