Precision

class ignite.metrics.precision.Precision(output_transform=<function Precision.<lambda>>, average=False, is_multilabel=False, device=device(type='cpu'))

Calculates precision for binary, multiclass, and multilabel data.

\text{Precision} = \frac{TP}{TP + FP}

where TP is the number of true positives and FP is the number of false positives.
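As a quick check of the formula, the following minimal sketch counts true and false positives by hand for a small set of binary labels. It uses plain Python only; the sample values and variable names are illustrative and are not part of the Ignite API.

```python
# Illustrative binary labels (1 = positive, 0 = negative); values are made up for this sketch.
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1]

# True positives: predicted positive and actually positive.
tp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)
# False positives: predicted positive but actually negative.
fp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 0)

precision = tp / (tp + fp)
print(tp, fp, precision)  # 3 1 0.75
```

Note that false negatives (positives the model missed) do not enter the formula; they affect recall, not precision.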

  • update must receive output of the form (y_pred, y) or {'y_pred': y_pred, 'y': y}.

  • y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).

  • y must be in the following shape (batch_size, …).

Parameters
  • output_transform (Callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • average (bool) – if True, precision is returned as the unweighted average (across all classes in the multiclass case); otherwise, a tensor of precision values is returned (one per class in the multiclass case).

  • is_multilabel (bool) – flag to use in the multilabel case. Default: False. If True, the average parameter should also be True, and the average is then computed across samples instead of classes.

  • device (Union[str, device]) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.

Examples

For more information on how a metric works with Engine, see the Attach Engine API.

Default setup used by the examples below:
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests

param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup:`

def get_default_trainer():

    def train_step(engine, batch):
        return batch

    return Engine(train_step)

# create default model for doctests

default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)

Binary case

metric = Precision(average=False)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.Tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
0.75

Multiclass case

metric = Precision(average=False)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long()
y_pred = torch.Tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
tensor([0.5000, 1.0000, 0.3333], dtype=torch.float64)

Precision can be computed as the unweighted average across all classes:

metric = Precision(average=True)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([2, 0, 2, 1, 0, 1]).long()
y_pred = torch.Tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
0.6111...
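The per-class and averaged values above can be reproduced by hand with plain PyTorch. This is a sketch of the arithmetic only, assuming the predicted class is the argmax over the category axis; it is not a description of Ignite's internal implementation.

```python
import torch
import torch.nn.functional as F

y_true = torch.tensor([2, 0, 2, 1, 0, 1])
y_pred = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])
num_classes = y_pred.shape[1]

# Assumption: the predicted class is the highest-scoring category.
pred_labels = y_pred.argmax(dim=1)

pred_onehot = F.one_hot(pred_labels, num_classes)
true_onehot = F.one_hot(y_true, num_classes)

# Per class: correct predictions divided by all predictions of that class.
tp = (pred_onehot * true_onehot).sum(dim=0).double()
predicted = pred_onehot.sum(dim=0).double()
per_class = tp / predicted

# Unweighted average across classes, matching average=True.
macro = per_class.mean().item()
print(per_class, macro)
```

Every class is predicted at least once in this data, so no division by zero occurs; real data may need a small epsilon in the denominator.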

Multilabel case, the shapes must be (batch_size, num_categories, …)

metric = Precision(is_multilabel=True)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([
    [0, 0, 1],
    [0, 0, 0],
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 1],
]).unsqueeze(0)
y_pred = torch.Tensor([
    [1, 1, 0],
    [1, 0, 1],
    [1, 0, 0],
    [1, 0, 1],
    [1, 1, 0],
]).unsqueeze(0)
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
tensor([0.2000, 0.5000, 0.0000], dtype=torch.float64)
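The three values printed above can be reproduced by hand. The sketch below assumes that the category axis (size 5 here) is moved last and the remaining axes are flattened into samples, so each of the 3 columns is treated as one sample with 5 labels. That assumption reproduces the printed output, but it is offered as an illustration of the arithmetic rather than as Ignite's actual code path.

```python
import torch

y_true = torch.tensor([
    [0, 0, 1],
    [0, 0, 0],
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 1],
]).unsqueeze(0)  # shape (1, 5, 3): batch_size=1, num_categories=5, 3 trailing positions
y_pred = torch.tensor([
    [1, 1, 0],
    [1, 0, 1],
    [1, 0, 0],
    [1, 0, 1],
    [1, 1, 0],
]).unsqueeze(0)

num_categories = y_pred.shape[1]

# Move the category axis last, then flatten everything else into samples: (3, 5).
pred_samples = y_pred.transpose(1, -1).reshape(-1, num_categories)
true_samples = y_true.transpose(1, -1).reshape(-1, num_categories)

# Per sample: correctly predicted labels divided by all predicted labels.
tp = (pred_samples * true_samples).sum(dim=1).double()
predicted = pred_samples.sum(dim=1).double()
per_sample = tp / predicted
print(per_sample)
```

This makes the "across samples, instead of classes" behaviour visible: the result has one entry per sample, not one per category.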

In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

metric = Precision(average=False, output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "precision")
y_true = torch.Tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.Tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
0.75

In multilabel cases, the average parameter should be True. However, if the user wants to compute the F1 metric, for example, average should be False. This can be done as shown below:

precision = Precision(average=False)
recall = Recall(average=False)
F1 = precision * recall * 2 / (precision + recall + 1e-20)
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)

Warning

In multilabel cases, if average is False, the current implementation stores all input data (output and target) as tensors before computing the metric. This can potentially lead to a memory error if the input data is larger than the available RAM.

Methods

update(output)

Updates the metric’s state using the passed batch output.

By default, this is called once for each batch.

Parameters

output (Sequence[Tensor]) – this is the output from the engine’s process function.

Return type

None