
torcheval.metrics.functional.multilabel_auprc

torcheval.metrics.functional.multilabel_auprc(input: Tensor, target: Tensor, num_labels: int | None = None, *, average: str | None = 'macro')

Compute AUPRC, also called Average Precision, which is the area under the Precision-Recall Curve, for multilabel classification. Its class version is torcheval.metrics.MultilabelAUPRC.

Precision is defined as $$\frac{T_p}{T_p+F_p}$$; it is the probability that a positive prediction from the model is a true positive. Recall is defined as $$\frac{T_p}{T_p+F_n}$$; it is the probability that an actually positive example is predicted to be positive by the model.
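As a quick numeric check of the two definitions, the plain-Python sketch below computes precision and recall from hypothetical confusion counts. The helper name `precision_recall` and the counts are illustrative only and are not part of torcheval.

```python
def precision_recall(tp: int, fp: int, fn: int) -> tuple[float, float]:
    """Return (precision, recall) from true-positive, false-positive and false-negative counts."""
    # Precision: fraction of positive predictions that are correct, Tp / (Tp + Fp).
    precision = tp / (tp + fp)
    # Recall: fraction of actual positives the model predicts positive, Tp / (Tp + Fn).
    recall = tp / (tp + fn)
    return precision, recall


# Hypothetical counts: 3 correct positive predictions, 1 false alarm, 2 missed positives.
p, r = precision_recall(tp=3, fp=1, fn=2)
print(p, r)  # 0.75 0.6
```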

The precision-recall curve plots recall on the x-axis and precision on the y-axis, both of which are bounded between 0 and 1. This function returns the area under that curve. An area near one means the model admits a threshold that identifies a high percentage of the actual positives while also rejecting enough negative examples that most of its positive predictions are correct.
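The area can be computed as a step-wise sum instead of by interpolating between points. The plain-Python sketch below handles a single label and assumes the common definition $$\sum_n (R_n - R_{n-1}) P_n$$, evaluated at each distinct score threshold with tied scores grouped into one threshold. It is not torcheval's implementation, but it reproduces the per-label values shown in the examples below. The function name `average_precision` is hypothetical, and it assumes the label has at least one positive example.

```python
def average_precision(scores: list[float], labels: list[int]) -> float:
    """Area under the precision-recall curve for one label, as a step-wise sum.

    Assumes `labels` contains at least one positive (1) entry.
    """
    # Visit examples from the highest score to the lowest.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    total_pos = sum(labels)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    i = 0
    while i < len(order):
        threshold = scores[order[i]]
        # Every example tied at this score crosses the threshold together.
        while i < len(order) and scores[order[i]] == threshold:
            if labels[order[i]] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        precision = tp / (tp + fp)
        recall = tp / total_pos
        # Add one rectangle: recall gained at this threshold times current precision.
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


# Column 0 of the first example below: scores and 0/1 targets for one label.
print(round(average_precision([0.75, 0.45, 0.05, 0.05], [1, 0, 0, 1]), 4))  # 0.75
```

Applying the same function to each column of a multilabel input gives one score per label.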

In the multilabel version of AUPRC, the input and target tensors are 2-dimensional. The rows of each tensor are associated with a particular example and the columns are associated with a particular class.

For the target tensor, the entry in the r’th row and c’th column (r and c are 0-indexed) is 1 if the r’th example belongs to the c’th class, and 0 if not. For the input tensor, the entry in the same position is the output of the classification model predicting the inclusion of the r’th example in the c’th class. Note that in the multilabel setting, multiple labels may apply to a single sample. This stands in contrast to the multiclass setting, in which there may be more than 2 distinct classes but each sample must belong to exactly one class.
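To make the layout concrete, the plain-Python sketch below builds a target matrix from per-example label sets, one row per example and one column per label. The helper `to_multilabel_target` is a hypothetical illustration, not a torcheval function; the label sets reproduce the target tensor used in the first example below, where one example has no label and others carry several at once. The resulting nested list can be passed to `torch.tensor`.

```python
def to_multilabel_target(label_sets: list[set[int]], num_labels: int) -> list[list[int]]:
    """Row r, column c is 1 if example r has label c, else 0."""
    return [[1 if c in labels else 0 for c in range(num_labels)] for labels in label_sets]


# Example 1 has no label; examples 0, 2 and 3 carry several labels each,
# which a multiclass target could not express.
target = to_multilabel_target([{0, 2}, set(), {1, 2}, {0, 1, 2}], num_labels=3)
for row in target:
    print(row)
# [1, 0, 1]
# [0, 0, 0]
# [0, 1, 1]
# [1, 1, 1]
```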

Without averaging, multilabel AUPRC over N labels gives the same results as binary AUPRC with N tasks, provided that:

1. the input is transposed: in binary classification with multiple tasks, examples are associated with columns, whereas in multilabel classification they are associated with rows.

2. the target is transposed for the same reason.
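Both transpositions swap rows and columns, so that column c of the multilabel tensors (one label across all examples) becomes row c of the binary tensors (one task across all examples). The plain-Python sketch below applies this to the input from the connection example at the end of this section; the helper `transpose` is illustrative, and for a 2-D torch tensor `input.T` performs the same swap.

```python
def transpose(matrix: list[list[float]]) -> list[list[float]]:
    """Swap rows and columns: the entry at (r, c) moves to (c, r)."""
    return [list(column) for column in zip(*matrix)]


# Multilabel layout: one row per example (4 examples), one column per label (3 labels).
multilabel_input = [[0.1, 0, 0], [0, 1, 0], [0.1, 0.2, 0.7], [0, 0, 1]]

# Binary multi-task layout: one row per task/label (3 tasks), one column per example.
binary_input = transpose(multilabel_input)
print(binary_input)
# [[0.1, 0, 0.1, 0], [0, 1, 0.2, 0], [0, 0, 0.7, 1]]
```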

See examples below for more details on the connection between Multilabel and Binary AUPRC.

Parameters:
• input (Tensor) – Tensor of label predictions. It should be probabilities or logits with shape (n_samples, n_label).

• target (Tensor) – Tensor of ground truth labels with shape of (n_samples, n_label).

• num_labels (int, optional) – Number of labels. It may be omitted, as in several of the examples below.

• average (str, optional) –

• 'macro' [default]:

Calculate metrics for each class separately, and return their unweighted mean.

• None or 'none':

Calculate the metric for each class separately, and return the metric for every class.
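The two `average` options differ only in the final reduction over per-label scores. The plain-Python sketch below assumes the per-label AUPRC values have already been computed (here, the values printed by the first example below). `reduce_auprc` is a hypothetical helper, not a torcheval function, and its error handling is an assumption of the sketch.

```python
from typing import Optional, Union


def reduce_auprc(per_label: list[float], average: Optional[str] = "macro") -> Union[float, list[float]]:
    """Apply an `average` option to a list of per-label AUPRC values."""
    if average == "macro":
        # Unweighted mean: every label counts equally, however many positives it has.
        return sum(per_label) / len(per_label)
    if average is None or average == "none":
        # No reduction: one score per label.
        return list(per_label)
    raise ValueError(f"unsupported average: {average!r}")


scores = [0.7500, 0.5833, 0.9167]
print(reduce_auprc(scores, average=None))  # [0.75, 0.5833, 0.9167]
print(round(reduce_auprc(scores), 4))      # 0.75
```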

Examples:
>>> import torch
>>> from torcheval.metrics.functional import multilabel_auprc
>>> input = torch.tensor([[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]])
>>> multilabel_auprc(input, target, num_labels=3, average=None)
tensor([0.7500, 0.5833, 0.9167])
>>> multilabel_auprc(input, target, average=None)
tensor([0.7500, 0.5833, 0.9167])
>>> multilabel_auprc(input, target, num_labels=3, average='macro')
tensor(0.7500)
>>> multilabel_auprc(input, target, num_labels=3)
tensor(0.7500)
>>> multilabel_auprc(input, target, average='macro')
tensor(0.7500)
>>> multilabel_auprc(input, target)
tensor(0.7500)


Connection to BinaryAUPRC:
>>> input = torch.tensor([[0.1, 0, 0], [0, 1, 0], [0.1, 0.2, 0.7], [0, 0, 1]])
>>> target = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
>>> multilabel_auprc(input, target, average=None)
tensor([0.5000, 1.0000, 1.0000])

The above is equivalent to:
>>> from torcheval.metrics.functional import binary_auprc
>>> input = torch.tensor([[0.1, 0, 0.1, 0], [0, 1, 0.2, 0], [0, 0, 0.7, 1]])
>>> target = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]])
>>> binary_auprc(input, target, num_tasks=3)
tensor([0.5000, 1.0000, 1.0000])
