• Docs >
  • Module code >
  • torcheval.metrics.functional.classification.precision_recall_curve
Shortcuts

Source code for torcheval.metrics.functional.classification.precision_recall_curve

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

"""
This file contains binary_precision_recall_curve, multiclass_precision_recall_curve and multilabel_precision_recall_curve functions.
"""


@torch.inference_mode()
def binary_precision_recall_curve(
    input: torch.Tensor,
    target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the precision-recall curve, together with the thresholds that
    produce it, for a binary classification task.

    If a class is missing from the target tensor, its recall values are set to 1.0.

    Its class version is ``torcheval.metrics.BinaryPrecisionRecallCurve``.
    See also :func:`multiclass_precision_recall_curve <torcheval.metrics.functional.multiclass_precision_recall_curve>`,
    :func:`multilabel_precision_recall_curve <torcheval.metrics.functional.multilabel_precision_recall_curve>`

    Args:
        input (Tensor): Tensor of label predictions.
            It should be probabilities or logits with shape of (n_samples, ).
        target (Tensor): Tensor of ground truth labels with shape of (n_samples, ).

    Returns:
        Tuple:
            - precision (Tensor): Tensor of precision result. Its shape is (n_thresholds + 1, )
            - recall (Tensor): Tensor of recall result. Its shape is (n_thresholds + 1, )
            - thresholds (Tensor): Tensor of threshold. Its shape is (n_thresholds, )

    Raises:
        ValueError: If ``input`` or ``target`` is not one-dimensional, or if
            their shapes differ.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import binary_precision_recall_curve
        >>> input = torch.tensor([0.1, 0.5, 0.7, 0.8])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> binary_precision_recall_curve(input, target)
        (tensor([0.5000, 0.6667, 1.0000, 1.0000, 1.0000]),
        tensor([1.0000, 1.0000, 1.0000, 0.5000, 0.0000]),
        tensor([0.1000, 0.5000, 0.7000, 0.8000]))
    """
    # Validation only; the update step has no state to accumulate here.
    _binary_precision_recall_curve_update(input, target)
    precision, recall, thresholds = _binary_precision_recall_curve_compute(
        input, target
    )
    return precision, recall, thresholds
def _binary_precision_recall_curve_update(
    input: torch.Tensor,
    target: torch.Tensor,
) -> None:
    """Validate one batch of binary predictions and labels.

    Raises:
        ValueError: If the tensors fail the binary input check.
    """
    _binary_precision_recall_curve_update_input_check(input, target)


def _binary_precision_recall_curve_compute(
    input: torch.Tensor,
    target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(precision, recall, threshold)`` with label ``1`` as the positive class."""
    return _compute_for_each_class(input, target, 1)


def _binary_precision_recall_curve_update_input_check(
    input: torch.Tensor,
    target: torch.Tensor,
) -> None:
    """Ensure ``input`` and ``target`` are one-dimensional tensors of equal shape.

    Raises:
        ValueError: If either tensor is not one-dimensional, or if the two
            shapes differ.
    """
    input_shape = input.shape
    target_shape = target.shape

    # Rank checks come first so the shape comparison below only ever reports
    # a length mismatch between two vectors.
    if input.dim() != 1:
        raise ValueError(
            "input should be a one-dimensional tensor, "
            f"got shape {input_shape}."
        )
    if target.dim() != 1:
        raise ValueError(
            "target should be a one-dimensional tensor, "
            f"got shape {target_shape}."
        )
    if input_shape != target_shape:
        raise ValueError(
            "The `input` and `target` should have the same shape, "
            f"got shapes {input_shape} and {target_shape}."
        )
@torch.inference_mode()
def multiclass_precision_recall_curve(
    input: torch.Tensor,
    target: torch.Tensor,
    *,
    num_classes: Optional[int] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """
    Compute one precision-recall curve per class, together with the
    thresholds that produce it, for a multi-class classification task.

    If a class is missing from the target tensor, its recall values are set to 1.0.

    Its class version is ``torcheval.metrics.MulticlassPrecisionRecallCurve``.
    See also :func:`binary_precision_recall_curve <torcheval.metrics.functional.binary_precision_recall_curve>`,
    :func:`multilabel_precision_recall_curve <torcheval.metrics.functional.multilabel_precision_recall_curve>`

    Args:
        input (Tensor): Tensor of label predictions.
            It should be probabilities or logits with shape of (n_samples, n_class).
        target (Tensor): Tensor of ground truth labels with shape of (n_samples, ).
        num_classes (Optional): Number of classes. Set to the second dimension
            of the input if num_classes is None.

    Returns:
        a tuple of (precision: List[torch.Tensor], recall: List[torch.Tensor], thresholds: List[torch.Tensor])
            precision: List of precision result. Each index indicates the result of a class.
            recall: List of recall result. Each index indicates the result of a class.
            thresholds: List of threshold. Each index indicates the result of a class.

    Raises:
        ValueError: If ``input`` and ``target`` fail the multi-class input check.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import multiclass_precision_recall_curve
        >>> input = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.5, 0.5, 0.5, 0.5], [0.7, 0.7, 0.7, 0.7], [0.8, 0.8, 0.8, 0.8]])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> multiclass_precision_recall_curve(input, target, num_classes=4)
        ([tensor([0.2500, 0.0000, 0.0000, 0.0000, 1.0000]),
        tensor([0.2500, 0.3333, 0.0000, 0.0000, 1.0000]),
        tensor([0.2500, 0.3333, 0.5000, 0.0000, 1.0000]),
        tensor([0.2500, 0.3333, 0.5000, 1.0000, 1.0000])],
        [tensor([1., 0., 0., 0., 0.]),
        tensor([1., 1., 0., 0., 0.]),
        tensor([1., 1., 1., 0., 0.]),
        tensor([1., 1., 1., 1., 0.])],
        [tensor([0.1000, 0.5000, 0.7000, 0.8000]),
        tensor([0.1000, 0.5000, 0.7000, 0.8000]),
        tensor([0.1000, 0.5000, 0.7000, 0.8000]),
        tensor([0.1000, 0.5000, 0.7000, 0.8000])])
    """
    # Infer the class count only when the input is a matrix; any other rank
    # is left for the input check to reject.
    if input.dim() == 2 and num_classes is None:
        num_classes = input.size(1)

    _multiclass_precision_recall_curve_update(input, target, num_classes)
    precision, recall, thresholds = _multiclass_precision_recall_curve_compute(
        input, target, num_classes
    )
    return precision, recall, thresholds
def _multiclass_precision_recall_curve_update(
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int],
) -> None:
    """Validate one batch of multi-class predictions and labels.

    Raises:
        ValueError: If the tensors fail the multi-class input check.
    """
    _multiclass_precision_recall_curve_update_input_check(input, target, num_classes)


@torch.jit.script
def _multiclass_precision_recall_curve_compute(
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int],
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    # Computes the precision-recall curve of every class in one vectorized
    # pass and returns (precisions, recalls, thresholds), each a list with one
    # tensor per class. `num_classes` falls back to `input.shape[1]`.
    if num_classes is None:
        num_classes = input.shape[1]
    # One row per class, scores ranked from highest to lowest.
    thresholds, indices = input.T.sort(dim=1, descending=True)
    # Mark the last position of every run of equal scores (the final position
    # is always kept), then flip so the mask lines up with ascending scores.
    mask = F.pad(thresholds.diff(dim=1) != 0, [0, 1], value=1.0).flip(1)
    sizes: List[int] = mask.sum(1).tolist()
    # Distinct thresholds per class, in ascending order.
    thresholds = thresholds.flip(1)[mask].split(sizes)

    arange = torch.arange(num_classes, device=target.device)
    # cmp[c, k] is True when the k-th highest-scored sample of row c has label c.
    cmp = target[indices] == arange[:, None]
    # Running true/false positive counts, flipped to the ascending layout.
    num_tp = cmp.cumsum(1).flip(1)
    num_fp = (~cmp).cumsum(1).flip(1)

    # The last precision and recall values are 1.0 and 0.0, respectively
    precision = F.pad(num_tp / (num_tp + num_fp), [0, 1], value=1.0)
    # A class with no positive samples divides 0 by 0; those NaNs become 1.0.
    recall = F.pad((num_tp / num_tp[:, :1]).nan_to_num_(1.0), [0, 1], value=0.0)

    # Extend the mask so the appended final point survives the selection.
    mask = F.pad(mask, [0, 1], value=1.0)
    sizes: List[int] = mask.sum(1).tolist()
    precision = precision[mask].split(sizes)
    recall = recall[mask].split(sizes)
    return list(precision), list(recall), list(thresholds)


def _multiclass_precision_recall_curve_update_input_check(
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int],
) -> None:
    """Ensure ``input`` is ``(num_sample, num_classes)`` and ``target`` is ``(num_sample,)``.

    Args:
        input: Tensor of predictions.
        target: Tensor of ground truth labels.
        num_classes: Expected size of the second dimension of ``input``;
            ``None`` skips that comparison.

    Raises:
        ValueError: If the first dimensions differ, if ``target`` is not
            one-dimensional, or if ``input`` is not two-dimensional with
            ``num_classes`` columns.
    """
    # `size(0)` raises IndexError on a 0-dim tensor, so the first-dimension
    # comparison is skipped for scalars; the rank checks below then reject
    # them with the intended ValueError.
    if input.ndim > 0 and target.ndim > 0 and input.size(0) != target.size(0):
        raise ValueError(
            "The `input` and `target` should have the same first dimension, "
            f"got shapes {input.shape} and {target.shape}."
        )

    if target.ndim != 1:
        raise ValueError(
            "target should be a one-dimensional tensor, " f"got shape {target.shape}."
        )

    if not (input.ndim == 2 and (num_classes is None or input.shape[1] == num_classes)):
        raise ValueError(
            f"input should have shape of (num_sample, num_classes), "
            f"got {input.shape} and num_classes={num_classes}."
        )


@torch.jit.script
def _compute_for_each_class(
    input: torch.Tensor,
    target: torch.Tensor,
    pos_label: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Precision-recall curve for a single score vector, treating samples whose
    # target equals `pos_label` as positives. Returns (precision, recall,
    # threshold) with thresholds in ascending order.
    # Rank samples from the highest score to the lowest.
    threshold, indices = input.sort(descending=True)
    # Keep only the last position of every run of equal scores (the final
    # position is always kept), so each distinct score yields one threshold.
    mask = F.pad(threshold.diff(dim=0) != 0, [0, 1], value=1.0)
    # Running counts of positives / non-positives among the samples ranked at
    # or above each kept position.
    num_tp = (target[indices] == pos_label).cumsum(0)[mask]
    num_fp = (1 - (target[indices] == pos_label).long()).cumsum(0)[mask]
    # Flip so the results follow ascending thresholds.
    precision = (num_tp / (num_tp + num_fp)).flip(0)
    # num_tp[-1] is the total number of positives; 0/0 yields NaN when none exist.
    recall = (num_tp / num_tp[-1]).flip(0)
    threshold = threshold[mask].flip(0)

    # The last precision and recall values are 1.0 and 0.0 without a corresponding threshold.
    # This ensures that the graph starts on the y-axis.
    precision = torch.cat([precision, precision.new_ones(1)])
    recall = torch.cat([recall, recall.new_zeros(1)])

    # If recalls are NaNs, set NaNs to 1.0s.
    if torch.isnan(recall[0]):
        recall = torch.nan_to_num(recall, 1.0)
    return precision, recall, threshold
@torch.inference_mode()
def multilabel_precision_recall_curve(
    input: torch.Tensor,
    target: torch.Tensor,
    *,
    num_labels: Optional[int] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """
    Compute one precision-recall curve per label, together with the
    thresholds that produce it, for a multi-label classification task.

    If there are no samples for a label in the target tensor, its recall values are set to 1.0.

    Its class version is ``torcheval.metrics.MultilabelPrecisionRecallCurve``.
    See also :func:`binary_precision_recall_curve <torcheval.metrics.functional.binary_precision_recall_curve>`,
    :func:`multiclass_precision_recall_curve <torcheval.metrics.functional.multiclass_precision_recall_curve>`

    Args:
        input (Tensor): Tensor of label predictions.
            It should be probabilities or logits with shape of (n_samples, n_label).
        target (Tensor): Tensor of ground truth labels with shape of (n_samples, n_label).
        num_labels (Optional): Number of labels. Set to the second dimension
            of the input if num_labels is None.

    Returns:
        a tuple of (precision: List[torch.Tensor], recall: List[torch.Tensor], thresholds: List[torch.Tensor])
            precision: List of precision result. Each index indicates the result of a label.
            recall: List of recall result. Each index indicates the result of a label.
            thresholds: List of threshold. Each index indicates the result of a label.

    Raises:
        ValueError: If ``input`` is not two-dimensional, or if ``input`` and
            ``target`` fail the multi-label input check.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import multilabel_precision_recall_curve
        >>> input = torch.tensor([[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]])
        >>> target = torch.tensor([[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]])
        >>> multilabel_precision_recall_curve(input, target, num_labels=3)
        ([tensor([0.5, 0.5, 1.0, 1.0]),
        tensor([0.5, 0.66666667, 0.5, 0.0, 1.0]),
        tensor([0.75, 1.0, 1.0, 1.0])],
        [tensor([1.0, 0.5, 0.5, 0.0]),
        tensor([1.0, 1.0, 0.5, 0.0, 0.0]),
        tensor([1.0, 0.66666667, 0.33333333, 0.0])],
        [tensor([0.05, 0.45, 0.75]),
        tensor([0.05, 0.55, 0.65, 0.75]),
        tensor([0.05, 0.35, 0.75])])
    """
    # The rank must be known before the label count can be read off the input.
    if input.dim() != 2:
        raise ValueError(
            f"input should be a two-dimensional tensor, got shape {input.shape}."
        )

    label_count = input.size(1) if num_labels is None else num_labels
    _multilabel_precision_recall_curve_update(input, target, label_count)
    precision, recall, thresholds = _multilabel_precision_recall_curve_compute(
        input, target, label_count
    )
    return precision, recall, thresholds
def _multilabel_precision_recall_curve_update(
    input: torch.Tensor,
    target: torch.Tensor,
    num_labels: int,
) -> None:
    """Validate one batch of multi-label predictions and labels.

    Raises:
        ValueError: If the tensors fail the multi-label input check.
    """
    _multilabel_precision_recall_curve_update_input_check(input, target, num_labels)


@torch.jit.script
def _multilabel_precision_recall_curve_compute(
    input: torch.Tensor,
    target: torch.Tensor,
    num_labels: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    # Builds one binary precision-recall curve per label column, treating a
    # target value of 1 as positive, and returns the per-label results as
    # three parallel lists.
    label_precisions: List[torch.Tensor] = []
    label_recalls: List[torch.Tensor] = []
    label_thresholds: List[torch.Tensor] = []
    for label_idx in range(num_labels):
        curve_precision, curve_recall, curve_threshold = _compute_for_each_class(
            input[:, label_idx], target[:, label_idx], 1
        )
        label_precisions.append(curve_precision)
        label_recalls.append(curve_recall)
        label_thresholds.append(curve_threshold)
    return label_precisions, label_recalls, label_thresholds


def _multilabel_precision_recall_curve_update_input_check(
    input: torch.Tensor,
    target: torch.Tensor,
    num_labels: int,
) -> None:
    """Ensure ``input`` and ``target`` share a ``(num_sample, num_labels)`` shape.

    Raises:
        ValueError: If the shapes differ, if ``input`` is not two-dimensional,
            or if its second dimension is not ``num_labels``.
    """
    input_shape = input.shape
    target_shape = target.shape

    if input_shape != target_shape:
        raise ValueError(
            "Expected both input.shape and target.shape to have the same shape"
            f" but got {input_shape} and {target_shape}."
        )

    if input.dim() != 2:
        raise ValueError(
            f"input should be a two-dimensional tensor, got shape {input_shape}."
        )

    # Safe to read the second dimension now that the rank is known to be 2.
    if input.size(1) != num_labels:
        raise ValueError(
            f"input should have shape of (num_sample, num_labels), "
            f"got {input_shape} and num_labels={num_labels}."
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources