Recall
class ignite.metrics.recall.Recall(output_transform=lambda x: x, average=False, is_multilabel=False, device=torch.device("cpu"))
Calculates recall for binary and multiclass data.

$$\text{Recall} = \frac{TP}{TP + FN}$$

where $TP$ is the number of true positives and $FN$ is the number of false negatives.
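A minimal plain-Python sketch of the formula for 0/1 labels, with no Ignite or torch involved, can be handy for checking results by hand. The function name `binary_recall` and the choice to return 0.0 when there are no positive targets are assumptions for illustration; Ignite's own handling of that edge case may differ.

```python
def binary_recall(y_true, y_pred):
    """Recall = TP / (TP + FN) for sequences of 0/1 labels (illustrative sketch)."""
    # True positive: target is 1 and prediction is 1.
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    # False negative: target is 1 but prediction is 0.
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp + fn == 0:
        return 0.0  # assumption: no positive targets -> report 0.0
    return tp / (tp + fn)


print(binary_recall([1, 0, 1, 1, 0, 1], [1, 0, 1, 0, 1, 1]))  # 0.75
```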
update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.
y_pred must have shape (batch_size, num_categories, …) or (batch_size, …).
y must have shape (batch_size, …).
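The two accepted output forms and the shape rule above can be expressed as a small sketch. Both helpers, `unpack_output` and `shapes_compatible`, are hypothetical and simplified: they only encode the rule as stated here (y_pred either matches y, or has one extra num_categories dimension at index 1), while Ignite's own validation may check more, such as the values themselves.

```python
def unpack_output(output):
    """Hypothetical helper: return (y_pred, y) from either accepted output form."""
    if isinstance(output, dict):
        return output["y_pred"], output["y"]
    y_pred, y = output
    return y_pred, y


def shapes_compatible(y_pred_shape, y_shape):
    """Simplified shape rule: (batch_size, ...) like y, or
    (batch_size, num_categories, ...) with the extra dim at index 1."""
    y_pred_shape, y_shape = tuple(y_pred_shape), tuple(y_shape)
    if y_pred_shape == y_shape:
        return True  # binary or multilabel layout
    return (
        len(y_pred_shape) == len(y_shape) + 1
        and y_pred_shape[0] == y_shape[0]
        and y_pred_shape[2:] == y_shape[1:]
    )


print(unpack_output({"y_pred": "p", "y": "t"}))  # ('p', 't')
print(shapes_compatible((6, 3), (6,)))  # True: multiclass scores for 3 categories
```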
Parameters
- output_transform (Callable) – a callable that is used to transform the Engine's process_function's output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
- average (bool) – if True, recall is computed as the unweighted average (across all classes in the multiclass case); otherwise, a tensor with the recall (for each class in the multiclass case) is returned.
- is_multilabel (bool) – flag to use in the multilabel case. By default, the value is False. If True, the average parameter should be True and the average is computed across samples instead of classes.
- device (Union[str, torch.device]) – specifies which device updates are accumulated on. Setting the metric's device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
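As an illustration of output_transform, suppose (hypothetically) a two-head model whose process_function returns (y_pred_head_a, y_pred_head_b, y). A transform like the one below would let Recall see only the first head. The output layout and the name `select_head_a` are assumptions, not part of Ignite.

```python
def select_head_a(output):
    """Hypothetical transform: keep only the first head's predictions.

    Assumes process_function returns (y_pred_head_a, y_pred_head_b, y).
    """
    y_pred_a, _y_pred_b, y = output
    return y_pred_a, y


# Usage with Ignite would be: Recall(output_transform=select_head_a)
print(select_head_a(("pred_a", "pred_b", "target")))  # ('pred_a', 'target')
```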
Examples
Binary case
metric = Recall(average=False)
metric.attach(default_evaluator, "recall")
y_true = torch.Tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.Tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["recall"])
0.75
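Applying the formula by hand to the binary example: y_true has positives at indices 0, 2, 3 and 5; y_pred recovers indices 0, 2 and 5 but misses index 3. The false positive at index 4 does not enter recall.

$$TP = 3,\quad FN = 1,\quad \text{Recall} = \frac{3}{3 + 1} = 0.75$$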
Multiclass case
metric = Recall(average=False)
metric.attach(default_evaluator, "recall")
y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long()
y_pred = torch.Tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["recall"])
tensor([0.5000, 0.5000, 0.5000], dtype=torch.float64)
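The per-class values can be reproduced in plain Python, assuming (consistent with the printed result) that the predicted class is the index of the highest score in each row. The helpers below are illustrative, and returning 0.0 for a class with no targets is an assumption.

```python
def argmax(row):
    # Index of the largest score; ties resolve to the first index.
    return max(range(len(row)), key=row.__getitem__)


def per_class_recall(y_true, scores, num_classes):
    preds = [argmax(row) for row in scores]  # here: [2, 2, 0, 2, 0, 1]
    recalls = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, preds) if t == c and p == c)
        positives = sum(1 for t in y_true if t == c)
        recalls.append(tp / positives if positives else 0.0)  # 0.0 is an assumption
    return recalls


y_true = [2, 0, 2, 1, 0, 1]
scores = [
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
]
print(per_class_recall(y_true, scores, 3))  # [0.5, 0.5, 0.5]
```

Each class has two targets and exactly one of them is predicted correctly, hence 0.5 everywhere.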
Recall can be computed as the unweighted average across all classes:
metric = Recall(average=True)
metric.attach(default_evaluator, "recall")
y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long()
y_pred = torch.Tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["recall"])
0.5
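With average=True, the per-class recalls from the multiclass example are averaged without weighting over the $C = 3$ classes:

$$\text{Recall}_{\text{avg}} = \frac{1}{C}\sum_{c=1}^{C} \text{Recall}_c = \frac{0.5 + 0.5 + 0.5}{3} = 0.5$$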
Multilabel case; the shapes must be (batch_size, num_categories, …):
metric = Recall(is_multilabel=True)
metric.attach(default_evaluator, "recall")
y_true = torch.Tensor([
    [0, 0, 1],
    [0, 0, 0],
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 1],
]).unsqueeze(0)
y_pred = torch.Tensor([
    [1, 1, 0],
    [1, 0, 1],
    [1, 0, 0],
    [1, 0, 1],
    [1, 1, 0],
]).unsqueeze(0)
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["recall"])
tensor([1., 1., 0.], dtype=torch.float64)
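The printed tensor can be reproduced in plain Python by computing recall separately for each of the 3 positions along the last dimension, counting over the 5 rows of the (1, 5, 3) input. This sketch only mirrors the output above; it is an illustration, not Ignite's implementation, and returning 0.0 for a position with no positives is an assumption.

```python
y_true = [[0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1]]
y_pred = [[1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0]]


def column_recall(y_true, y_pred):
    n_cols = len(y_true[0])
    out = []
    for j in range(n_cols):
        # With 0/1 values, t * p is 1 only for a true positive.
        tp = sum(t[j] * p[j] for t, p in zip(y_true, y_pred))
        positives = sum(t[j] for t in y_true)
        out.append(tp / positives if positives else 0.0)
    return out


print(column_recall(y_true, y_pred))  # [1.0, 1.0, 0.0]
```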
In binary and multilabel cases, the elements of y and y_pred should be 0 or 1. Thresholding of predictions can be done as below:
def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

metric = Recall(average=False, output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "recall")
y_true = torch.Tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.Tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["recall"])
0.75
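Rounding acts like a 0.5 cutoff, except exactly at 0.5: torch.round rounds half to even, so 0.5 becomes 0, whereas a `>= 0.5` comparison gives 1. The plain-Python sketch below uses a hypothetical `threshold` helper to show how the cutoff changes recall on the same scores. Lowering the cutoff can only keep or raise recall, because it never turns a predicted 1 into a 0.

```python
def binary_recall(y_true, y_pred):
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    positives = sum(1 for t in y_true if t == 1)
    return tp / positives


def threshold(scores, cutoff=0.5):
    """Hypothetical helper: 1 where score >= cutoff, else 0."""
    return [1 if s >= cutoff else 0 for s in scores]


y_true = [1, 0, 1, 1, 0, 1]
scores = [0.6, 0.2, 0.9, 0.4, 0.7, 0.65]
print(threshold(scores))                             # [1, 0, 1, 0, 1, 1]
print(binary_recall(y_true, threshold(scores)))        # 0.75
print(binary_recall(y_true, threshold(scores, 0.35)))  # 1.0
```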
In multilabel cases, the average parameter should be True. However, if the user would like to compute the F1 metric, for example, the average parameter should be False. This can be done as shown below:
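import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
# Per-class F1; the 1e-20 term guards against division by zero
F1 = precision * recall * 2 / (precision + recall + 1e-20)
# Reduce the per-class F1 tensor to its mean as a Python float
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)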
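The arithmetic performed by the F1 expression can be checked in plain Python with a hypothetical `mean_f1` helper. For illustration it uses per-class values derived by hand from the multiclass example: the arg-max predictions [2, 2, 0, 2, 0, 1] give per-class precision [0.5, 1.0, 1/3], and that example's recall is 0.5 for every class.

```python
def mean_f1(precision, recall, eps=1e-20):
    """Hypothetical helper: mean of per-class F1 = 2PR / (P + R + eps)."""
    f1 = [p * r * 2 / (p + r + eps) for p, r in zip(precision, recall)]
    return sum(f1) / len(f1)


# Per-class F1 here is [0.5, 0.6667, 0.4], so the mean is about 0.5222.
print(round(mean_f1([0.5, 1.0, 1 / 3], [0.5, 0.5, 0.5]), 4))  # 0.5222
```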
Warning
In multilabel cases, if average is False, the current implementation stores all input data (output and target) as tensors before computing the metric. This can lead to a memory error if the input data is larger than the available RAM.
Methods
update(output) – Updates the metric's state using the passed batch output.
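The reset/update/compute pattern behind update can be sketched in plain Python. StreamingBinaryRecall is a hypothetical class, not Ignite's implementation: it keeps only two running counts, so its memory use does not grow with the number of batches, in contrast to the storing behaviour described in the warning above.

```python
class StreamingBinaryRecall:
    """Hypothetical recall accumulator for 0/1 labels (not Ignite's implementation)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Running counts are the only state, so memory stays constant across batches.
        self.true_positives = 0
        self.actual_positives = 0

    def update(self, output):
        # output is a (y_pred, y) pair of equal-length 0/1 sequences.
        y_pred, y = output
        self.true_positives += sum(1 for p, t in zip(y_pred, y) if p == 1 and t == 1)
        self.actual_positives += sum(1 for t in y if t == 1)

    def compute(self):
        if self.actual_positives == 0:
            raise ValueError("recall is undefined without positive targets")
        return self.true_positives / self.actual_positives


# The binary example split into two batches still gives TP = 3 and TP + FN = 4.
metric = StreamingBinaryRecall()
metric.update(([1, 0, 1], [1, 0, 1]))
metric.update(([0, 1, 1], [1, 0, 1]))
print(metric.compute())  # 0.75
```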