Recall

class ignite.metrics.recall.Recall(output_transform=<function _BasePrecisionRecall.<lambda>>, average=False, is_multilabel=False, device=device(type='cpu'))

Calculates recall for binary, multiclass and multilabel data.

$$\text{Recall} = \frac{TP}{TP + FN}$$

where $TP$ is the number of true positives and $FN$ is the number of false negatives.
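As a quick check of the formula, the sketch below counts TP and FN directly with plain PyTorch on the data from the binary example on this page. It is a minimal illustration, not Ignite's implementation; the helper name `binary_recall` is hypothetical, and returning 0.0 when there are no actual positives is an assumption.

```python
import torch


def binary_recall(y_pred: torch.Tensor, y: torch.Tensor) -> float:
    """Recall = TP / (TP + FN) for 0/1 tensors of the same shape (hypothetical helper)."""
    y_pred, y = y_pred.bool(), y.bool()
    tp = (y_pred & y).sum().item()   # predicted 1, actually 1
    fn = (~y_pred & y).sum().item()  # predicted 0, actually 1
    # Assumption: report 0.0 when there are no actual positives (TP + FN == 0)
    return tp / (tp + fn) if (tp + fn) > 0 else 0.0


y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
print(binary_recall(y_pred, y_true))  # 0.75  (TP = 3, FN = 1)
```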

  • update must receive output of the form (y_pred, y).

  • y_pred must have shape (batch_size, num_categories, …) or (batch_size, …).

  • y must have shape (batch_size, …).
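For multiclass input, a y_pred of shape (batch_size, num_categories) carries one score per class. The sketch below shows how such scores could be turned into per-class (one-vs-rest) recall, assuming the predicted class is the argmax over dimension 1; on the data of the multiclass example on this page it gives the same per-class values. `per_class_recall` is a hypothetical helper, not part of Ignite.

```python
import torch


def per_class_recall(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """One-vs-rest recall for each class, given scores of shape (batch_size, C)."""
    num_classes = y_pred.shape[1]
    labels = y_pred.argmax(dim=1)  # assumed: predicted class = highest score
    recalls = []
    for k in range(num_classes):
        actual_k = y == k                                # samples truly in class k
        tp = (actual_k & (labels == k)).sum().item()     # TP_k
        support = actual_k.sum().item()                  # TP_k + FN_k
        recalls.append(tp / support if support > 0 else 0.0)
    return torch.tensor(recalls, dtype=torch.float64)


y_true = torch.tensor([2, 0, 2, 1, 0])
y_pred = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
])
print(per_class_recall(y_pred, y_true))  # tensor([0.5000, 0.0000, 0.5000], dtype=torch.float64)
```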

Parameters
  • output_transform (Callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • average (Optional[Union[bool, str]]) –

    Available options are:

    False

    Default option. For binary inputs, the recall of the positive class is returned. For multiclass and multilabel inputs, a per-class or per-label metric is returned, respectively.

    None

    Like False, except that a per-class metric is returned for binary data as well. Provided for compatibility with the scikit-learn API.

    'micro'

    The metric is computed by pooling the counts of all classes/labels.

    $$\text{Micro Recall} = \frac{\sum_{k=1}^C TP_k}{\sum_{k=1}^C (TP_k + FN_k)}$$

    where $C$ is the number of classes/labels (2 in the binary case), and the subscript $k$ in $TP_k$ and $FN_k$ means that the counts are computed for class/label $k$ (in a one-vs-rest sense in the multiclass case).

    For binary and multiclass inputs, this is equivalent to accuracy, so use Accuracy instead.

    'samples'

    For multilabel input, recall is first computed per sample (across labels) and then averaged across samples.

    $$\text{Sample-averaged Recall} = \frac{\sum_{n=1}^N \frac{TP_n}{TP_n + FN_n}}{N}$$

    where $N$ is the number of samples, and the subscript $n$ in $TP_n$ and $FN_n$ means that the counts are computed for sample $n$, across labels.

    Incompatible with binary and multiclass inputs.

    'weighted'

    Like 'macro', but accounts for class/label imbalance. For binary and multiclass input, recall is computed for each class and the results are averaged, weighted by class support (the number of actual samples in each class). For multilabel input, recall is computed for each label and the results are averaged, weighted by label support (the number of actual positive samples for each label).

    $$Recall_k = \frac{TP_k}{TP_k + FN_k}$$
    $$\text{Weighted Recall} = \frac{\sum_{k=1}^C P_k \cdot Recall_k}{\sum_{k=1}^C P_k}$$

    where $C$ is the number of classes/labels (2 in the binary case) and $P_k$ is the support of class/label $k$: the number of samples belonging to class $k$ in the binary and multiclass case, or the number of positive samples for label $k$ in the multilabel case. In the binary and multiclass case, $\sum_{k=1}^C P_k$ is the total number of samples $N$.

    Note that for binary and multiclass data, weighted recall is equivalent to accuracy, so use Accuracy instead.

    'macro'

    Computes macro recall, the unweighted average of the per-class or per-label recall.

    $$\text{Macro Recall} = \frac{\sum_{k=1}^C Recall_k}{C}$$

    where $C$ is the number of classes/labels (2 in the binary case).

    True

    Like 'macro'. Kept for backward compatibility.

  • is_multilabel (bool) – set to True for multilabel data. Defaults to False.

  • device (Union[str, device]) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
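The averaging options can be reproduced by hand on the multilabel example data from this page. The sketch below uses plain PyTorch and only illustrates the formulas above; it is not Ignite's code. Treating 0/0 as 0 for samples with no positive labels is an assumption, chosen because it matches the 'samples' result of 0.3 in that example.

```python
import torch

# Multilabel data from the example on this page: shape (num_samples, num_labels)
y_true = torch.tensor([[0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1]])
y_pred = torch.tensor([[1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0]])


def safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
    """Element-wise num / den, with 0/0 treated as 0 (assumption)."""
    out = torch.zeros_like(num)
    mask = den > 0
    out[mask] = num[mask] / den[mask]
    return out


tp = (y_pred.bool() & y_true.bool()).double()  # 1 where a positive label was found
pos = y_true.double()                          # 1 where a label is actually positive

# Per-label counts summed over samples: TP_k and P_k = TP_k + FN_k
tp_k, pos_k = tp.sum(dim=0), pos.sum(dim=0)

recall_k = safe_div(tp_k, pos_k)                          # average=False / None
micro = tp_k.sum() / pos_k.sum()                          # 'micro': pool counts first
macro = recall_k.mean()                                   # 'macro' / True
weighted = (pos_k * recall_k).sum() / pos_k.sum()         # 'weighted': support-weighted
samples = safe_div(tp.sum(dim=1), pos.sum(dim=1)).mean()  # 'samples': per row, then mean

print(recall_k)         # tensor([1., 1., 0.], dtype=torch.float64)
print(micro.item())     # 0.5
print(macro.item())     # 0.6666666666666666
print(weighted.item())  # 0.5
print(samples.item())   # 0.3
```

With support weights $P_k = TP_k + FN_k$, the weighted sum simplifies to $\sum_k TP_k / \sum_k P_k$, which is why the weighted value coincides with the micro value in this sketch.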

Examples

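For more information on how the metric works with Engine, see Attach Engine API. Each example below assumes that the following setup code has just been run, so it starts with a fresh default_evaluator.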

from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.metrics.regression import *
from ignite.utils import *

# create default evaluator for doctests

def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests

param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup::`

def get_default_trainer():

    def train_step(engine, batch):
        return batch

    return Engine(train_step)

# create default model for doctests

default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)

Binary case. In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values.

metric = Recall()
two_class_metric = Recall(average=None) # Returns recall for both classes
metric.attach(default_evaluator, "recall")
two_class_metric.attach(default_evaluator, "both classes recall")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Recall: {state.metrics['recall']}")
print(f"Recall for class 0 and class 1: {state.metrics['both classes recall']}")
# Output:
# Recall: 0.75
# Recall for class 0 and class 1: tensor([0.5000, 0.7500], dtype=torch.float64)

Multiclass case

metric = Recall()
macro_metric = Recall(average=True)

metric.attach(default_evaluator, "recall")
macro_metric.attach(default_evaluator, "macro recall")

y_true = torch.tensor([2, 0, 2, 1, 0])
y_pred = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288]
])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Recall: {state.metrics['recall']}")
print(f"Macro Recall: {state.metrics['macro recall']}")
# Output:
# Recall: tensor([0.5000, 0.0000, 0.5000], dtype=torch.float64)
# Macro Recall: 0.3333333333333333

Multilabel case. The shapes must be (batch_size, num_categories, …).

metric = Recall(is_multilabel=True)
micro_metric = Recall(is_multilabel=True, average='micro')
macro_metric = Recall(is_multilabel=True, average=True)
samples_metric = Recall(is_multilabel=True, average='samples')

metric.attach(default_evaluator, "recall")
micro_metric.attach(default_evaluator, "micro recall")
macro_metric.attach(default_evaluator, "macro recall")
samples_metric.attach(default_evaluator, "samples recall")

y_true = torch.tensor([
    [0, 0, 1],
    [0, 0, 0],
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 1],
])
y_pred = torch.tensor([
    [1, 1, 0],
    [1, 0, 1],
    [1, 0, 0],
    [1, 0, 1],
    [1, 1, 0],
])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Recall: {state.metrics['recall']}")
print(f"Micro Recall: {state.metrics['micro recall']}")
print(f"Macro Recall: {state.metrics['macro recall']}")
print(f"Samples Recall: {state.metrics['samples recall']}")
# Output:
# Recall: tensor([1., 1., 0.], dtype=torch.float64)
# Micro Recall: 0.5
# Macro Recall: 0.6666666666666666
# Samples Recall: 0.3

Predictions can be thresholded with an output_transform, as below:

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

metric = Recall(output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "recall")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['recall'])
# Output: 0.75
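torch.round rounds halves to even, so in thresholded_output_transform above the cut-off is fixed at 0.5 and a score of exactly 0.5 is mapped to 0. If a different or inclusive threshold is wanted, a transform can compare against the threshold explicitly. This is a minimal sketch; `make_threshold_transform` is a hypothetical helper, and its return value would be passed as output_transform to Recall.

```python
import torch


def make_threshold_transform(threshold: float = 0.5):
    """Hypothetical helper: build an output_transform that binarizes y_pred."""

    def transform(output):
        y_pred, y = output
        # Scores >= threshold become 1, all other scores become 0
        return (y_pred >= threshold).long(), y

    return transform


y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])

binarized, _ = make_threshold_transform(0.5)((y_pred, y_true))
print(binarized)                                   # tensor([1, 0, 1, 0, 1, 1])
print(torch.round(torch.tensor([0.5, 1.5, 2.5])))  # tensor([0., 2., 2.])
```

For example, Recall(output_transform=make_threshold_transform(0.8)) would count only scores of at least 0.8 as positive predictions.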

Changed in version 0.4.10: Some new options were added to average parameter.

Methods

update(output)

Updates the metric’s state using the passed batch output.

By default, this is called once for each batch.

Parameters

output (Sequence[Tensor]) – the output of the engine's process function, in the form (y_pred, y).

Return type

None
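Because update is called once per batch, an epoch-level recall has to be built from counts accumulated over all batches rather than from an average of per-batch recall values. The sketch below illustrates the difference for the binary example data split into two batches of unequal size. `RecallAccumulator` is a hypothetical class written for this illustration, not Ignite's implementation.

```python
import torch


class RecallAccumulator:
    """Hypothetical batch-wise recall (binary case).

    TP and actual-positive counts are summed over all update() calls;
    recall is computed once, from the pooled counts, in compute().
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.tp = 0
        self.actual_pos = 0  # TP + FN

    def update(self, output) -> None:
        y_pred, y = output
        y_pred, y = y_pred.bool(), y.bool()
        self.tp += (y_pred & y).sum().item()
        self.actual_pos += y.sum().item()

    def compute(self) -> float:
        return self.tp / self.actual_pos if self.actual_pos > 0 else 0.0


y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
batches = [(y_pred[:2], y_true[:2]), (y_pred[2:], y_true[2:])]

acc = RecallAccumulator()
for batch in batches:
    acc.update(batch)
print(acc.compute())  # 0.75, same as processing all data in one batch

per_batch = []
for batch in batches:
    m = RecallAccumulator()
    m.update(batch)
    per_batch.append(m.compute())
print(sum(per_batch) / len(per_batch))  # about 0.833, not the recall of the full data
```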