• Docs >
  • Module code >
  • torcheval.metrics.classification.binned_precision_recall_curve
Shortcuts

Source code for torcheval.metrics.classification.binned_precision_recall_curve

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, List, Optional, Tuple, TypeVar, Union

import torch

from torcheval.metrics.functional.classification.binned_precision_recall_curve import (
    _binary_binned_precision_recall_curve_compute,
    _binary_binned_precision_recall_curve_update,
    _binned_precision_recall_curve_param_check,
    _multiclass_binned_precision_recall_curve_compute,
    _multiclass_binned_precision_recall_curve_update,
    _multilabel_binned_precision_recall_curve_compute,
    _multilabel_binned_precision_recall_curve_update,
    _optimization_param_check,
)
from torcheval.metrics.metric import Metric


TBinaryBinnedPrecisionRecallCurve = TypeVar("TBinaryBinnedPrecisionRecallCurve")
TMulticlassBinnedPrecisionRecallCurve = TypeVar("TMulticlassBinnedPrecisionRecallCurve")
TMultilabelBinnedPrecisionRecallCurve = TypeVar("TMultilabelBinnedPrecisionRecallCurve")


[docs]class BinaryBinnedPrecisionRecallCurve( Metric[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] ): """ Compute precision recall curve with given thresholds. Its functional version is :func:`torcheval.metrics.functional.binary_binned_precision_recall_curve`. See also :class:`MulticlassBinnedPrecisionRecallCurve <MulticlassBinnedPrecisionRecallCurve>` Args: threshold (Union[int, List[float], torch.Tensor], Optional): an integer representing number of bins, a list of thresholds, or a tensor of thresholds. Examples:: >>> import torch >>> from torcheval.metrics import BinaryBinnedPrecisionRecallCurve >>> input = torch.tensor([0.2, 0.8, 0.5, 0.9]) >>> target = torch.tensor([0, 1, 0, 1]) >>> threshold = 5 >>> metric = BinaryBinnedPrecisionRecallCurve(threshold) >>> metric.update(input, target) >>> metric.compute() (tensor([0.5000, 0.6667, 0.6667, 1.0000, 1.0000, 1.0000]), tensor([1., 1., 1., 1., 0., 0.]), tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) >>> input = torch.tensor([0.2, 0.3, 0.4, 0.5]) >>> target = torch.tensor([0, 0, 1, 1]) >>> threshold = torch.tensor([0.0000, 0.2500, 0.7500, 1.0000]) >>> metric = BinaryBinnedPrecisionRecallCurve(threshold) >>> metric.update(input, target) >>> metric.compute() (tensor([0.5000, 0.6667, 1.0000, 1.0000, 1.0000]), tensor([1., 1., 0., 0., 0.]), tensor([0.0000, 0.2500, 0.7500, 1.0000])) """
[docs] def __init__( self: TBinaryBinnedPrecisionRecallCurve, *, threshold: Union[int, List[float], torch.Tensor] = 100, device: Optional[torch.device] = None, ) -> None: super().__init__(device=device) threshold = ( torch.linspace(0, 1.0, threshold, device=self.device) if isinstance(threshold, int) else torch.tensor(threshold, device=self.device) ) _binned_precision_recall_curve_param_check(threshold) self.threshold = threshold self._add_state("num_tp", torch.zeros(len(threshold), device=self.device)) self._add_state("num_fp", torch.zeros(len(threshold), device=self.device)) self._add_state("num_fn", torch.zeros(len(threshold), device=self.device))
@torch.inference_mode() # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any def update( self: TBinaryBinnedPrecisionRecallCurve, input: torch.Tensor, target: torch.Tensor, ) -> TBinaryBinnedPrecisionRecallCurve: """ Update states with the ground truth labels and predictions. Args: input (Tensor): Tensor of label predictions It should be probabilities or logits with shape of (n_sample, ). target (Tensor): Tensor of ground truth labels with shape of (n_samples, ). """ input = input.to(self.device) target = target.to(self.device) num_tp, num_fp, num_fn = _binary_binned_precision_recall_curve_update( input, target, self.threshold ) self.num_tp += num_tp self.num_fp += num_fp self.num_fn += num_fn return self @torch.inference_mode() def compute( self: TBinaryBinnedPrecisionRecallCurve, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: Tuple: - precision (Tensor): Tensor of precision result. Its shape is (n_thresholds + 1, ) - recall (Tensor): Tensor of recall result. Its shape is (n_thresholds + 1, ) - thresholds (Tensor): Tensor of threshold. Its shape is (n_thresholds, ) """ return _binary_binned_precision_recall_curve_compute( self.num_tp, self.num_fp, self.num_fn, self.threshold ) @torch.inference_mode() def merge_state( self: TBinaryBinnedPrecisionRecallCurve, metrics: Iterable[TBinaryBinnedPrecisionRecallCurve], ) -> TBinaryBinnedPrecisionRecallCurve: for metric in metrics: self.num_tp += metric.num_tp.to(self.device) self.num_fp += metric.num_fp.to(self.device) self.num_fn += metric.num_fn.to(self.device) return self
[docs]class MulticlassBinnedPrecisionRecallCurve( Metric[Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]] ): """ Compute precision recall curve with given thresholds. Its functional version is :func:`torcheval.metrics.functional.multiclass_binned_precision_recall_curve`. See also :class:`BinaryBinnedPrecisionRecallCurve <BinaryBinnedPrecisionRecallCurve>` Args: num_classes (int): Number of classes. threshold (Union[int, List[float], torch.Tensor], Optional): a integer representing number of bins, a list of thresholds, or a tensor of thresholds. optimization (str): Choose the optimization to use. Accepted values: "vectorized" and "memory". The "vectorized" optimization makes more use of vectorization but uses more memory; the "memory" optimization uses less memory but takes more steps. Here are the tradeoffs between these two options: - "vectorized": consumes more memory but is faster on some hardware, e.g. modern GPUs. - "memory": consumes less memory but can be significantly slower on some hardware, e.g. modern GPUs Generally, on GPUs, the "vectorized" optimization requires more memory but is faster; the "memory" optimization requires less memory but is slower. On CPUs, the "memory" optimization is recommended in all cases; it uses less memory and is faster. Examples:: >>> import torch >>> from torcheval.metrics import MulticlassBinnedPrecisionRecallCurve >>> metric = MulticlassBinnedPrecisionRecallCurve(num_classes=4) >>> input = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.5, 0.5, 0.5, 0.5], [0.7, 0.7, 0.7, 0.7], [0.8, 0.8, 0.8, 0.8]]) >>> target = torch.tensor([0, 1, 2, 3]) >>> threshold = 10 >>> metric.update(input, target) >>> metric.compute() ([tensor([0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000]), tensor([0.2500, 0.3333, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000]), tensor([0.2500, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.5000, 0.0000, 1.0000, 1.0000, 1.0000]), tensor([0.2500, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])], [tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), tensor([1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]), tensor([1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]), tensor([1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.])], tensor([0.0000, 0.1111, 0.2222, 0.3333, 0.4444, 0.5556, 0.6667, 0.7778, 0.8889, 1.0000])) """
[docs] def __init__( self: TMulticlassBinnedPrecisionRecallCurve, *, num_classes: int, threshold: Union[int, List[float], torch.Tensor] = 100, optimization: str = "vectorized", device: Optional[torch.device] = None, ) -> None: super().__init__(device=device) _optimization_param_check(optimization) threshold = ( torch.linspace(0, 1.0, threshold, device=self.device) if isinstance(threshold, int) else torch.tensor(threshold, device=self.device) ) _binned_precision_recall_curve_param_check(threshold) self.num_classes = num_classes self.threshold = threshold self.optimization = optimization self._add_state( "num_tp", torch.zeros(len(threshold), self.num_classes, device=self.device), ) self._add_state( "num_fp", torch.zeros(len(threshold), self.num_classes, device=self.device), ) self._add_state( "num_fn", torch.zeros(len(threshold), self.num_classes, device=self.device), )
@torch.inference_mode() # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any def update( self: TMulticlassBinnedPrecisionRecallCurve, input: torch.Tensor, target: torch.Tensor, ) -> TMulticlassBinnedPrecisionRecallCurve: """ Update states with the ground truth labels and predictions. Args: input: Tensor of label predictions It should be probabilities or logits with shape of (n_sample, ). target: Tensor of ground truth labels with shape of (n_samples, ). """ input = input.to(self.device) target = target.to(self.device) num_tp, num_fp, num_fn = _multiclass_binned_precision_recall_curve_update( input, target, num_classes=self.num_classes, threshold=self.threshold, optimization=self.optimization, ) self.num_tp += num_tp self.num_fp += num_fp self.num_fn += num_fn return self @torch.inference_mode() def compute( self: TMulticlassBinnedPrecisionRecallCurve, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: """ Return: a tuple of (precision: torch.Tensor, recall: torch.Tensor, thresholds: torch.Tensor) precision: Tensor of precision result. Its shape is (n_thresholds + 1, ) recall: Tensor of recall result. Its shape is (n_thresholds + 1, ) thresholds: Tensor of threshold. Its shape is (n_thresholds, ) """ return _multiclass_binned_precision_recall_curve_compute( self.num_tp, self.num_fp, self.num_fn, num_classes=self.num_classes, threshold=self.threshold, ) @torch.inference_mode() def merge_state( self: TMulticlassBinnedPrecisionRecallCurve, metrics: Iterable[TMulticlassBinnedPrecisionRecallCurve], ) -> TMulticlassBinnedPrecisionRecallCurve: for metric in metrics: self.num_tp += metric.num_tp.to(self.device) self.num_fp += metric.num_fp.to(self.device) self.num_fn += metric.num_fn.to(self.device) return self
[docs]class MultilabelBinnedPrecisionRecallCurve( Metric[Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]] ): """ Compute precision recall curve with given thresholds. Its functional version is :func:`torcheval.metrics.functional.multilabel_binned_precision_recall_curve`. See also :class:`BinaryBinnedPrecisionRecallCurve <BinaryBinnedPrecisionRecallCurve>`, :class:`MultilabelBinnedPrecisionRecallCurve <MultilabelBinnedPrecisionRecallCurve>` Args: num_labels (int): Number of labels. threshold (Union[int, List[float], torch.Tensor], Optional): a integer representing number of bins, a list of thresholds, or a tensor of thresholds. optimization (str): Choose the optimization to use. Accepted values: "vectorized" and "memory". Here are the tradeoffs between these two options: - "vectorized": consumes more memory but is faster on some hardware, e.g. modern GPUs. - "memory": consumes less memory but can be significantly slower on some hardware, e.g. modern GPUs Generally, on GPUs, the "vectorized" optimization requires more memory but is faster; the "memory" optimization requires less memory but is slower. On CPUs, the "memory" optimization is recommended in all cases; it uses less memory and is faster. Examples:: >>> import torch >>> from torcheval.metrics import MultilabelBinnedPrecisionRecallCurve >>> input = torch.tensor([[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]) >>> target = torch.tensor([[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]) >>> metric = MultilabelBinnedPrecisionRecallCurve(num_labels=3, threshold=5) >>> metric.update(input, target) >>> metric.compute() ([torch.tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000]), torch.tensor([0.5000, 0.6667, 0.6667, 0.0000, 1.0000, 1.0000]), torch.tensor([0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])], [torch.tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000]), torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]), torch.tensor([1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000])], torch.tensor([0.0000, 0.2000, 0.5000, 0.8000, 1.0000])) >>> threshold = torch.tensor([0.0, 0.2, 0.5, 0.8, 1.0]) >>> metric = MultilabelBinnedPrecisionRecallCurve(num_labels=3, threshold=threshold) >>> metric.compute() ([torch.tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000]), torch.tensor([0.5000, 0.6667, 0.6667, 1.0000, 1.0000, 1.0000]), torch.tensor([0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])], [torch.tensor([1.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000]), torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]), torch.tensor([1.0000, 0.6667, 0.3333, 0.0000, 0.0000, 0.0000])], torch.tensor([0.0000, 0.2000, 0.5000, 0.8000, 1.0000])) """
[docs] def __init__( self: TMultilabelBinnedPrecisionRecallCurve, *, num_labels: int, threshold: Union[int, torch.Tensor] = 100, optimization: str = "vectorized", device: Optional[torch.device] = None, ) -> None: super().__init__(device=device) _optimization_param_check(optimization) threshold = ( torch.linspace(0, 1.0, threshold, device=self.device) if isinstance(threshold, int) else torch.tensor(threshold, device=self.device) ) _binned_precision_recall_curve_param_check(threshold) self.num_labels = num_labels self.threshold = threshold self.optimization = optimization self._add_state( "num_tp", torch.zeros(len(threshold), self.num_labels, device=self.device), ) self._add_state( "num_fp", torch.zeros(len(threshold), self.num_labels, device=self.device), ) self._add_state( "num_fn", torch.zeros(len(threshold), self.num_labels, device=self.device), )
@torch.inference_mode() # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any def update( self: TMultilabelBinnedPrecisionRecallCurve, input: torch.Tensor, target: torch.Tensor, ) -> TMultilabelBinnedPrecisionRecallCurve: """ Update states with the ground truth labels and predictions. Args: input: Tensor of label predictions It should be probabilities or logits with shape of (n_sample, ). target: Tensor of ground truth labels with shape of (n_samples, ). """ input = input.to(self.device) target = target.to(self.device) num_tp, num_fp, num_fn = _multilabel_binned_precision_recall_curve_update( input, target, num_labels=self.num_labels, threshold=self.threshold, optimization=self.optimization, ) self.num_tp += num_tp self.num_fp += num_fp self.num_fn += num_fn return self @torch.inference_mode() def compute( self: TMultilabelBinnedPrecisionRecallCurve, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: """ Return: a tuple of (precision: torch.Tensor, recall: torch.Tensor, thresholds: torch.Tensor) precision: Tensor of precision result. Its shape is (n_thresholds + 1, ) recall: Tensor of recall result. Its shape is (n_thresholds + 1, ) thresholds: Tensor of threshold. Its shape is (n_thresholds, ) """ return _multilabel_binned_precision_recall_curve_compute( self.num_tp, self.num_fp, self.num_fn, num_labels=self.num_labels, threshold=self.threshold, ) @torch.inference_mode() def merge_state( self: TMultilabelBinnedPrecisionRecallCurve, metrics: Iterable[TMultilabelBinnedPrecisionRecallCurve], ) -> TMultilabelBinnedPrecisionRecallCurve: for metric in metrics: self.num_tp += metric.num_tp.to(self.device) self.num_fp += metric.num_fp.to(self.device) self.num_fn += metric.num_fn.to(self.device) return self

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources