Source code for torcheval.metrics.classification.accuracy

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, TypeVar

import torch

from torcheval.metrics.functional.classification.accuracy import (
    _accuracy_compute,
    _accuracy_param_check,
    _binary_accuracy_update,
    _multiclass_accuracy_update,
    _multilabel_accuracy_param_check,
    _multilabel_accuracy_update,
    _topk_multilabel_accuracy_param_check,
    _topk_multilabel_accuracy_update,
)
from torcheval.metrics.metric import Metric


TAccuracy = TypeVar("TAccuracy")
TBinaryAccuracy = TypeVar("TBinaryAccuracy")
TMultilabelAccuracy = TypeVar("TMultilabelAccuracy")
TTopKMultilabelAccuracy = TypeVar("TTopKMultilabelAccuracy")
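
# These TypeVars implement the "self-returning" pattern: annotating ``self``
# with the TypeVar lets ``update`` and ``merge_state`` return the concrete
# subclass type, so chained calls type-check against the subclass.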


class MulticlassAccuracy(Metric[torch.Tensor]):
    """
    Compute accuracy score, which is the frequency of input matching target.
    Its functional version is :func:`torcheval.metrics.functional.multiclass_accuracy`.
    See also :class:`BinaryAccuracy <BinaryAccuracy>`, :class:`MultilabelAccuracy <MultilabelAccuracy>`,
    :class:`TopKMultilabelAccuracy <TopKMultilabelAccuracy>`

    Args:
        average (str, Optional):
            - ``'micro'`` [default]: Calculate the metrics globally.
            - ``'macro'``: Calculate metrics for each class separately, and
              return their unweighted mean. Classes with 0 true instances are
              ignored.
            - ``None``: Calculate the metric for each class separately, and
              return the metric for every class. NaN is returned if a class
              has no sample in ``target``.
        num_classes: Number of classes. Required for ``'macro'`` and ``None``
            average methods.
        k: Number of top probabilities to be considered. k should be an
            integer greater than or equal to 1. If k > 1, the input tensor
            must contain probabilities or logits for every class.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import MulticlassAccuracy
        >>> metric = MulticlassAccuracy()
        >>> input = torch.tensor([0, 2, 1, 3])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.5)

        >>> metric = MulticlassAccuracy(average=None, num_classes=4)
        >>> input = torch.tensor([0, 2, 1, 3])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor([1., 0., 0., 1.])

        >>> metric = MulticlassAccuracy(average="macro", num_classes=2)
        >>> input = torch.tensor([0, 0, 1, 1, 1])
        >>> target = torch.tensor([0, 0, 0, 0, 1])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.75)

        >>> metric = MulticlassAccuracy()
        >>> input = torch.tensor([[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.5)
    """
    def __init__(
        self: TAccuracy,
        *,
        average: Optional[str] = "micro",
        num_classes: Optional[int] = None,
        k: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        _accuracy_param_check(average, num_classes, k)
        self.average = average
        self.num_classes = num_classes
        self.k = k
        if average == "micro":
            self._add_state("num_correct", torch.tensor(0.0, device=self.device))
            self._add_state("num_total", torch.tensor(0.0, device=self.device))
        else:
            # num_classes is verified to be not None when average != "micro"
            self._add_state(
                "num_correct",
                torch.zeros(num_classes or 0, device=self.device),
            )
            self._add_state(
                "num_total",
                torch.zeros(num_classes or 0, device=self.device),
            )
    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TAccuracy, input: torch.Tensor, target: torch.Tensor) -> TAccuracy:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of label predictions.
                It could be the predicted labels, with shape of (n_sample, ).
                It could also be probabilities or logits with shape of
                (n_sample, n_class). ``torch.argmax`` will be used to convert
                input into predicted labels.
            target (Tensor): Tensor of ground truth labels with shape of (n_sample, ).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        num_correct, num_total = _multiclass_accuracy_update(
            input, target, self.average, self.num_classes, self.k
        )
        self.num_correct += num_correct
        self.num_total += num_total
        return self

    @torch.inference_mode()
    def compute(self: TAccuracy) -> torch.Tensor:
        """
        Return the accuracy score.

        NaN is returned if no calls to ``update()`` are made before
        ``compute()`` is called.
        """
        return _accuracy_compute(self.num_correct, self.num_total, self.average)

    @torch.inference_mode()
    def merge_state(self: TAccuracy, metrics: Iterable[TAccuracy]) -> TAccuracy:
        for metric in metrics:
            self.num_correct += metric.num_correct.to(self.device)
            self.num_total += metric.num_total.to(self.device)
        return self
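
# A minimal usage sketch (not part of the torcheval source): it shows how the
# "micro" counters accumulate over batches and how ``merge_state`` folds in
# metric instances updated elsewhere, e.g. one per data-parallel worker. The
# ``_demo_`` helper name is hypothetical, added for illustration only.
def _demo_multiclass_accuracy_merge() -> torch.Tensor:
    metric_a = MulticlassAccuracy()
    metric_b = MulticlassAccuracy()
    # Each instance sees a different shard of the predictions.
    metric_a.update(torch.tensor([0, 2, 1, 3]), torch.tensor([0, 1, 2, 3]))
    metric_b.update(torch.tensor([1, 1]), torch.tensor([1, 1]))
    # Merging is equivalent to scoring the union of both shards:
    # (2 + 2) correct out of (4 + 2) samples -> tensor(0.6667).
    return metric_a.merge_state([metric_b]).compute()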
class BinaryAccuracy(MulticlassAccuracy):
    """
    Compute binary accuracy score, which is the frequency of input matching target.
    Its functional version is :func:`torcheval.metrics.functional.binary_accuracy`.
    See also :class:`MulticlassAccuracy <MulticlassAccuracy>`, :class:`MultilabelAccuracy <MultilabelAccuracy>`,
    :class:`TopKMultilabelAccuracy <TopKMultilabelAccuracy>`

    Args:
        threshold (float, default 0.5): Threshold for converting input into
            predicted labels for each sample.
            ``torch.where(input < threshold, 0, 1)`` will be applied to the
            ``input``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import BinaryAccuracy
        >>> metric = BinaryAccuracy()
        >>> input = torch.tensor([0, 0, 1, 1])
        >>> target = torch.tensor([1, 0, 1, 1])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.75)  # 3 / 4

        >>> metric = BinaryAccuracy(threshold=0.7)
        >>> input = torch.tensor([0, 0.2, 0.6, 0.7])
        >>> target = torch.tensor([1, 0, 1, 1])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.5)  # 2 / 4
    """
    def __init__(
        self: TBinaryAccuracy,
        *,
        threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        self.threshold = threshold
    @torch.inference_mode()
    def update(
        self: TBinaryAccuracy,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> TBinaryAccuracy:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of label predictions with shape of (n_sample,).
                ``torch.where(input < threshold, 0, 1)`` will be applied to the input.
            target (Tensor): Tensor of ground truth labels with shape of (n_sample,).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        num_correct, num_total = _binary_accuracy_update(input, target, self.threshold)
        self.num_correct += num_correct
        self.num_total += num_total
        return self
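
# A minimal usage sketch (not part of the torcheval source): BinaryAccuracy
# thresholds raw scores before comparing against targets, so probabilities or
# hard 0/1 labels can both be passed as ``input``. The ``_demo_`` helper name
# is hypothetical, added for illustration only.
def _demo_binary_accuracy_threshold() -> torch.Tensor:
    metric = BinaryAccuracy(threshold=0.5)
    scores = torch.tensor([0.1, 0.4, 0.6, 0.9])
    target = torch.tensor([0, 1, 1, 1])
    # torch.where(scores < 0.5, 0, 1) -> [0, 0, 1, 1]; 3 of 4 match the target.
    metric.update(scores, target)
    return metric.compute()  # tensor(0.75)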
class MultilabelAccuracy(MulticlassAccuracy):
    """
    Compute multilabel accuracy score, which is the frequency of input matching target.
    Its functional version is :func:`torcheval.metrics.functional.multilabel_accuracy`.
    See also :class:`MulticlassAccuracy <MulticlassAccuracy>`, :class:`BinaryAccuracy <BinaryAccuracy>`,
    :class:`TopKMultilabelAccuracy <TopKMultilabelAccuracy>`

    Args:
        threshold (float, Optional): Threshold for converting input into
            predicted labels for each sample.
            ``torch.where(input < threshold, 0, 1)`` will be applied to the
            ``input``.
        criteria (str, Optional):
            - ``'exact_match'`` [default]: The set of labels predicted for a
              sample must exactly match the corresponding set of labels in
              target. Also known as subset accuracy.
            - ``'hamming'``: Fraction of correct labels over total number of
              labels.
            - ``'overlap'``: The set of labels predicted for a sample must
              overlap with the corresponding set of labels in target.
            - ``'contain'``: The set of labels predicted for a sample must
              contain the corresponding set of labels in target.
            - ``'belong'``: The set of labels predicted for a sample must
              (fully) belong to the corresponding set of labels in target.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import MultilabelAccuracy
        >>> metric = MultilabelAccuracy()
        >>> input = torch.tensor([[0, 1], [1, 1], [0, 0], [0, 1]])
        >>> target = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.5)  # 2 / 4

        >>> metric = MultilabelAccuracy(criteria="hamming")
        >>> input = torch.tensor([[0, 1], [1, 1], [0, 0], [0, 1]])
        >>> target = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.75)  # 6 / 8

        >>> metric = MultilabelAccuracy(criteria="overlap")
        >>> input = torch.tensor([[0, 1], [1, 1], [0, 0], [0, 1]])
        >>> target = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(1)  # 4 / 4

        >>> metric = MultilabelAccuracy(criteria="contain")
        >>> input = torch.tensor([[0, 1], [1, 1], [0, 0], [0, 1]])
        >>> target = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.75)  # 3 / 4, input[0], input[1], input[2]

        >>> metric = MultilabelAccuracy(criteria="belong")
        >>> input = torch.tensor([[0, 1], [1, 1], [0, 0], [0, 1]])
        >>> target = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.75)  # 3 / 4, input[0], input[2], input[3]
    """
    def __init__(
        self: TMultilabelAccuracy,
        *,
        threshold: float = 0.5,
        criteria: str = "exact_match",
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        _multilabel_accuracy_param_check(criteria)
        self.threshold = threshold
        self.criteria = criteria
    @torch.inference_mode()
    def update(
        self: TMultilabelAccuracy,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> TMultilabelAccuracy:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of label predictions with shape of (n_sample, n_class).
                ``torch.where(input < threshold, 0, 1)`` will be applied to the input.
            target (Tensor): Tensor of ground truth labels with shape of (n_sample, n_class).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        num_correct, num_total = _multilabel_accuracy_update(
            input, target, self.threshold, self.criteria
        )
        self.num_correct += num_correct
        self.num_total += num_total
        return self
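
# A minimal usage sketch (not part of the torcheval source): the same
# predictions scored under every ``criteria`` option, from the strictest
# ('exact_match') down to per-label ('hamming'). The ``_demo_`` helper name
# is hypothetical, added for illustration only.
def _demo_multilabel_criteria() -> dict:
    input = torch.tensor([[0, 1], [1, 1], [0, 0], [0, 1]])
    target = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
    results = {}
    for criteria in ("exact_match", "hamming", "overlap", "contain", "belong"):
        metric = MultilabelAccuracy(criteria=criteria)
        metric.update(input, target)
        # e.g. 'exact_match' -> tensor(0.5000), 'hamming' -> tensor(0.7500)
        results[criteria] = metric.compute()
    return results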
class TopKMultilabelAccuracy(MulticlassAccuracy):
    """
    Compute multilabel accuracy score, which is the frequency of the top k
    predicted labels matching target.
    Its functional version is :func:`torcheval.metrics.functional.topk_multilabel_accuracy`.
    See also :class:`MulticlassAccuracy <MulticlassAccuracy>`, :class:`BinaryAccuracy <BinaryAccuracy>`,
    :class:`MultilabelAccuracy <MultilabelAccuracy>`

    Args:
        criteria (str):
            - ``'exact_match'`` [default]: The set of top-k labels predicted
              for a sample must exactly match the corresponding set of labels
              in target. Also known as subset accuracy.
            - ``'hamming'``: Fraction of top-k correct labels over total
              number of labels.
            - ``'overlap'``: The set of top-k labels predicted for a sample
              must overlap with the corresponding set of labels in target.
            - ``'contain'``: The set of top-k labels predicted for a sample
              must contain the corresponding set of labels in target.
            - ``'belong'``: The set of top-k labels predicted for a sample
              must (fully) belong to the corresponding set of labels in target.
        k (int): Number of top probabilities to be considered. k should be an
            integer greater than or equal to 1.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import TopKMultilabelAccuracy
        >>> metric = TopKMultilabelAccuracy(k=2)
        >>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0)  # 0 / 4

        >>> metric = TopKMultilabelAccuracy(criteria="hamming", k=2)
        >>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.583)  # 7 / 12

        >>> metric = TopKMultilabelAccuracy(criteria="overlap", k=2)
        >>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(1)  # 4 / 4

        >>> metric = TopKMultilabelAccuracy(criteria="contain", k=2)
        >>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.5)  # 2 / 4

        >>> metric = TopKMultilabelAccuracy(criteria="belong", k=2)
        >>> input = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1], [0.2, 0.4, 0.5], [0, 0.1, 0.9]])
        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 1, 1], [0, 1, 0]])
        >>> metric.update(input, target)
        >>> metric.compute()
        tensor(0.25)  # 1 / 4
    """
    def __init__(
        self: TTopKMultilabelAccuracy,
        *,
        criteria: str = "exact_match",
        k: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        _topk_multilabel_accuracy_param_check(criteria, k)
        self.criteria = criteria
        self.k = k
    @torch.inference_mode()
    def update(
        self: TTopKMultilabelAccuracy,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> TTopKMultilabelAccuracy:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of probabilities or logits with shape of
                (n_sample, n_class). The top ``k`` values along the class
                dimension are taken as the predicted label set.
            target (Tensor): Tensor of ground truth labels with shape of (n_sample, n_class).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        num_correct, num_total = _topk_multilabel_accuracy_update(
            input, target, self.criteria, self.k
        )
        self.num_correct += num_correct
        self.num_total += num_total
        return self
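
# A minimal usage sketch (not part of the torcheval source): with ``k=2`` the
# two highest scores per row form the predicted label set before the chosen
# ``criteria`` is applied. The ``_demo_`` helper name is hypothetical, added
# for illustration only.
def _demo_topk_multilabel_overlap() -> torch.Tensor:
    metric = TopKMultilabelAccuracy(criteria="overlap", k=2)
    scores = torch.tensor([[0.1, 0.5, 0.2], [0.3, 0.2, 0.1]])
    target = torch.tensor([[1, 1, 0], [0, 1, 0]])
    # Top-2 label sets are {1, 2} and {0, 1}; each overlaps its target set,
    # so every sample counts as correct.
    metric.update(scores, target)
    return metric.compute()  # tensor(1.)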
