Source code for torcheval.metrics.classification.recall_at_fixed_precision
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors[16]: Undefined attribute of metric states.
from typing import Iterable, List, Optional, Tuple, TypeVar
import torch
from torcheval.metrics.functional.classification.recall_at_fixed_precision import (
_binary_recall_at_fixed_precision_compute,
_binary_recall_at_fixed_precision_update_input_check,
_multilabel_recall_at_fixed_precision_compute,
_multilabel_recall_at_fixed_precision_update_input_check,
)
from torcheval.metrics.metric import Metric
"""
This file contains BinaryRecallAtFixedPrecision and MultilabelRecallAtFixedPrecision classes.
"""
TBinaryRecallAtFixedPrecision = TypeVar("TBinaryRecallAtFixedPrecision")
TMultilabelRecallAtFixedPrecision = TypeVar("TMultilabelRecallAtFixedPrecision")
[docs]class BinaryRecallAtFixedPrecision(Metric[Tuple[torch.Tensor, torch.Tensor]]):
"""
Returns the highest possible recall value give the minimum precision
for binary classification tasks.
Its functional version is :func:`torcheval.metrics.functional.binary_recall_at_fixed_precision`.
See also :class:`MultilabelRecallAtFixedPrecision <MultilabelRecallAtFixedPrecision>`
Args:
min_precision (float): Minimum precision threshold
Examples::
>>> import torch
>>> from torcheval.metrics import BinaryRecallAtFixedPrecision
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
>>> input = torch.tensor([0.1, 0.4, 0.6, 0.6, 0.6, 0.35, 0.8])
>>> target = torch.tensor([0, 0, 1, 1, 1, 1, 1])
>>> metric.update(input, target)
>>> metric.compute()
(torch.tensor(1.0), torch.tensor(0.35))
"""
[docs] def __init__(
self: TBinaryRecallAtFixedPrecision,
*,
min_precision: float,
device: Optional[torch.device] = None,
) -> None:
super().__init__(device=device)
self.min_precision = min_precision
self._add_state("inputs", [])
self._add_state("targets", [])
@torch.inference_mode()
# pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
def update(
self: TBinaryRecallAtFixedPrecision,
input: torch.Tensor,
target: torch.Tensor,
) -> TBinaryRecallAtFixedPrecision:
input = input.to(self.device)
target = target.to(self.device)
_binary_recall_at_fixed_precision_update_input_check(
input, target, self.min_precision
)
self.inputs.append(input)
self.targets.append(target)
return self
@torch.inference_mode()
def compute(
self: TBinaryRecallAtFixedPrecision,
) -> Tuple[torch.Tensor, torch.Tensor]:
return _binary_recall_at_fixed_precision_compute(
torch.cat(self.inputs), torch.cat(self.targets), self.min_precision
)
@torch.inference_mode()
def merge_state(
self: TBinaryRecallAtFixedPrecision,
metrics: Iterable[TBinaryRecallAtFixedPrecision],
) -> TBinaryRecallAtFixedPrecision:
for metric in metrics:
if metric.inputs:
metric_inputs = torch.cat(metric.inputs).to(self.device)
metric_targets = torch.cat(metric.targets).to(self.device)
self.inputs.append(metric_inputs)
self.targets.append(metric_targets)
return self
@torch.inference_mode()
def _prepare_for_merge_state(self: TBinaryRecallAtFixedPrecision) -> None:
if self.inputs and self.targets:
self.inputs = [torch.cat(self.inputs)]
self.targets = [torch.cat(self.targets)]
[docs]class MultilabelRecallAtFixedPrecision(
Metric[Tuple[List[torch.Tensor], List[torch.Tensor]]]
):
"""
Returns the highest possible recall value given the minimum precision
for each label and their corresponding thresholds for multi-label
classification tasks. The maximum recall computation for each label is
equivalent to _binary_recall_at_fixed_precision_compute in BinaryRecallAtFixedPrecision.
Its functional version is :func:`torcheval.metrics.functional.multilabel_recall_at_fixed_precision`.
See also :class:`BinaryRecallAtFixedPrecision <BinaryRecallAtFixedPrecision>`
Args:
num_labels (int): Number of labels
min_precision (float): Minimum precision threshold
Examples::
>>> import torch
>>> from torcheval.metrics import MultilabelRecallAtFixedPrecision
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
>>> input = torch.tensor([[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]])
>>> metric.update(input, target)
>>> metric.compute()
([torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)],
[torch.tensor(0.05), torch.tensor(0.55), torch.tensor(0.05)])
"""
[docs] def __init__(
self: TMultilabelRecallAtFixedPrecision,
*,
num_labels: int,
min_precision: float,
device: Optional[torch.device] = None,
) -> None:
super().__init__(device=device)
self.num_labels = num_labels
self.min_precision = min_precision
self._add_state("inputs", [])
self._add_state("targets", [])
@torch.inference_mode()
# pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
def update(
self: TMultilabelRecallAtFixedPrecision,
input: torch.Tensor,
target: torch.Tensor,
) -> TMultilabelRecallAtFixedPrecision:
input = input.to(self.device)
target = target.to(self.device)
_multilabel_recall_at_fixed_precision_update_input_check(
input, target, self.num_labels, self.min_precision
)
self.inputs.append(input)
self.targets.append(target)
return self
@torch.inference_mode()
def compute(
self: TMultilabelRecallAtFixedPrecision,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
return _multilabel_recall_at_fixed_precision_compute(
torch.cat(self.inputs),
torch.cat(self.targets),
self.num_labels,
self.min_precision,
)
@torch.inference_mode()
def merge_state(
self: TMultilabelRecallAtFixedPrecision,
metrics: Iterable[TMultilabelRecallAtFixedPrecision],
) -> TMultilabelRecallAtFixedPrecision:
for metric in metrics:
if metric.inputs:
metric_inputs = torch.cat(metric.inputs).to(self.device)
metric_targets = torch.cat(metric.targets).to(self.device)
self.inputs.append(metric_inputs)
self.targets.append(metric_targets)
return self
@torch.inference_mode()
def _prepare_for_merge_state(self: TMultilabelRecallAtFixedPrecision) -> None:
if self.inputs and self.targets:
self.inputs = [torch.cat(self.inputs)]
self.targets = [torch.cat(self.targets)]