torcheval.metrics.functional.multiclass_recall
torcheval.metrics.functional.multiclass_recall(input: Tensor, target: Tensor, *, num_classes: int | None = None, average: str | None = 'micro') → Tensor
Compute the recall score, the ratio of the number of true positives (TP) to the total number of actual positives (TP + FN). Its class version is torcheval.metrics.MultiClassRecall.
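For instance, in the first example below there are 2 true positives and 2 false negatives across all classes, giving a micro-averaged recall of 2 / (2 + 2) = 0.5.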
Parameters:
- input (Tensor) – Tensor of label predictions. It could be the predicted labels, with shape of (n_sample,), or probabilities/logits with shape of (n_sample, n_class), in which case torch.argmax will be used to convert the input into predicted labels (see the sketch after this parameter list).
- target (Tensor) – Tensor of ground truth labels with shape of (n_sample,).
- num_classes (int) – Number of classes.
- average –
  - 'micro' [default]: Calculate the metric globally, using the total true positives and false negatives across all classes.
  - 'macro': Calculate the metric for each class separately, and return their unweighted mean. Classes with 0 true and predicted instances are ignored.
  - 'weighted': Calculate the metric for each class separately, and return their average weighted by the number of instances for each class in the target tensor. Classes with 0 true and predicted instances are ignored.
  - None: Calculate the metric for each class separately, and return the metric for every class.
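When input holds probabilities or logits, the label-conversion step mentioned above is just an argmax over the class dimension. As a minimal sketch (using the same probability tensor as the last example below):

>>> import torch
>>> probs = torch.tensor([[0.9, 0.1, 0.0, 0.0],
...                       [0.1, 0.2, 0.4, 0.3],
...                       [0.0, 1.0, 0.0, 0.0],
...                       [0.0, 0.0, 0.2, 0.8]])
>>> torch.argmax(probs, dim=1)  # the labels the metric actually compares
tensor([0, 2, 1, 3])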
Examples:
>>> import torch
>>> from torcheval.metrics.functional.classification import multiclass_recall
>>> input = torch.tensor([0, 2, 1, 3])
>>> target = torch.tensor([0, 1, 2, 3])
>>> multiclass_recall(input, target)
tensor(0.5000)
>>> multiclass_recall(input, target, average=None, num_classes=4)
tensor([1., 0., 0., 1.])
>>> multiclass_recall(input, target, average="macro", num_classes=4)
tensor(0.5000)
>>> input = torch.tensor([[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]])
>>> multiclass_recall(input, target)
tensor(0.5000)
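To make the averaging modes concrete, the results above can be recomputed by hand with basic torch ops. This is only a sketch of the definitions, not the library's implementation; in particular it skips the handling of classes with 0 true and predicted instances noted under 'macro' and 'weighted'. Continuing the session above with fresh label tensors:

>>> pred = torch.tensor([0, 2, 1, 3])
>>> tgt = torch.tensor([0, 1, 2, 3])
>>> tp = torch.stack([((pred == c) & (tgt == c)).sum() for c in range(4)]).float()
>>> support = torch.stack([(tgt == c).sum() for c in range(4)]).float()
>>> tp / support  # per-class recall, i.e. average=None
tensor([1., 0., 0., 1.])
>>> tp.sum() / support.sum()  # average='micro'
tensor(0.5000)
>>> (tp / support).mean()  # average='macro'
tensor(0.5000)
>>> ((tp / support) * support / support.sum()).sum()  # average='weighted'
tensor(0.5000)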