
Accuracy

class ignite.metrics.Accuracy(output_transform=<function Accuracy.<lambda>>, is_multilabel=False, device=device(type='cpu'))

Calculates the accuracy for binary, multiclass and multilabel data.

\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}

where TP is true positives, TN is true negatives, FP is false positives and FN is false negatives.
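As a worked check of the formula, the sketch below counts TP, TN, FP and FN for a small set of binary labels in plain Python and plugs them into the ratio. It has no Ignite or PyTorch dependency, and the helper name `confusion_counts` is hypothetical, used only for illustration.

```python
def confusion_counts(y_true, y_pred):
    """Count TP, TN, FP, FN for binary labels in {0, 1} (hypothetical helper)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp, tn, fp, fn


y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1]

tp, tn, fp, fn = confusion_counts(y_true, y_pred)
# Accuracy = (TP + TN) / (TP + TN + FP + FN)
accuracy = (tp + tn) / (tp + tn + fp + fn)
print(tp, tn, fp, fn)  # 3 1 1 1
print(accuracy)        # 0.6666666666666666
```

TP + TN is simply the number of positions where the prediction equals the target, which is why binary accuracy is also the fraction of correct predictions.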

  • update must receive output of the form (y_pred, y).

  • y_pred must have shape (batch_size, num_categories, …) or (batch_size, …).

  • y must have shape (batch_size, …).

  • For multilabel cases, y and y_pred must both have shape (batch_size, num_categories, …), and num_categories must be greater than 1.
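The shape rules above decide which case applies. The sketch below encodes them as a hypothetical `infer_case` helper that inspects only shape tuples. It is a simplified reading of the listed rules, not Ignite's internal validation code, and it does not check values (for example, that binary inputs contain only 0 and 1).

```python
def infer_case(y_shape, y_pred_shape, is_multilabel=False):
    """Return "binary", "multiclass" or "multilabel" from shape tuples (hypothetical helper)."""
    y_shape, y_pred_shape = tuple(y_shape), tuple(y_pred_shape)
    if len(y_pred_shape) == len(y_shape) + 1:
        # y_pred carries an extra num_categories axis at position 1.
        if (y_pred_shape[0],) + y_pred_shape[2:] != y_shape:
            raise ValueError("y and y_pred must have compatible shapes")
        return "multiclass"
    if len(y_pred_shape) == len(y_shape):
        if y_pred_shape != y_shape:
            raise ValueError("y and y_pred must have compatible shapes")
        if is_multilabel:
            # Multilabel needs a num_categories axis holding more than one category.
            if len(y_shape) < 2 or y_shape[1] < 2:
                raise ValueError("multilabel inputs need num_categories > 1")
            return "multilabel"
        return "binary"
    raise ValueError("y_pred must have the same number of dimensions as y, or one more")


print(infer_case((6,), (6,)))                          # binary
print(infer_case((6,), (6, 3)))                        # multiclass
print(infer_case((5, 5), (5, 5), is_multilabel=True))  # multilabel
```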

Parameters
  • output_transform (Callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • is_multilabel (bool) – set to True for multilabel data. By default, False.

  • device (Union[str, device]) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
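To illustrate the multi-output use of output_transform, the sketch below assumes a hypothetical process_function that returns a 3-tuple (main-head predictions, auxiliary-head predictions, targets). The hypothetical `main_head_transform` keeps only the main head in the (y_pred, y) form the metric expects. Plain lists stand in for tensors.

```python
# Hypothetical process_function output of a two-head model:
# (main-head predictions, auxiliary-head predictions, targets).
engine_output = ([1, 0, 1], [0, 0, 1], [1, 1, 1])


def main_head_transform(output):
    """Drop the auxiliary head and return the (y_pred, y) pair Accuracy expects."""
    y_pred_main, _y_pred_aux, y = output
    return y_pred_main, y


y_pred, y = main_head_transform(engine_output)
print(y_pred, y)  # [1, 0, 1] [1, 1, 1]
```

With real tensors, such a function would be passed as `Accuracy(output_transform=main_head_transform)`.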

Examples

For more information on how the metric works with Engine, see Attach Engine API.

from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests

param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup::`

def get_default_trainer():

    def train_step(engine, batch):
        return batch

    return Engine(train_step)

# create default model for doctests

default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)

Binary case

metric = Accuracy()
metric.attach(default_evaluator, "accuracy")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["accuracy"])
0.6666...

Multiclass case

metric = Accuracy()
metric.attach(default_evaluator, "accuracy")
y_true = torch.tensor([2, 0, 2, 1, 0, 1])
y_pred = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["accuracy"])
0.5
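In the multiclass case, each row of y_pred holds one score per category, and the predicted class is the index of the largest score (in PyTorch terms, `torch.argmax(y_pred, dim=1)`), so the scores do not need to be normalized probabilities. The pure-Python sketch below reproduces the 0.5 above, with nested lists standing in for tensors.

```python
# Same scores and targets as the multiclass example, as plain nested lists.
y_true = [2, 0, 2, 1, 0, 1]
y_pred = [
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
]

# Predicted class = index of the largest score in each row (argmax over num_categories).
predicted = [max(range(len(row)), key=row.__getitem__) for row in y_pred]
correct = sum(p == t for p, t in zip(predicted, y_true))
print(predicted)              # [2, 2, 0, 2, 0, 1]
print(correct / len(y_true))  # 0.5
```

Only rows 0, 4 and 5 match their targets, giving 3 correct out of 6.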

Multilabel case

metric = Accuracy(is_multilabel=True)
metric.attach(default_evaluator, "accuracy")
y_true = torch.tensor([
    [0, 0, 1, 0, 1],
    [1, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [0, 1, 1, 0, 1],
])
y_pred = torch.tensor([
    [1, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [1, 0, 0, 0, 0],
    [1, 0, 1, 1, 1],
    [1, 1, 0, 0, 1],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["accuracy"])
0.2
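The 0.2 above shows that, in the multilabel case, a sample counts as correct only when its whole label vector matches the target (exact-match accuracy): here only the second row matches. The pure-Python sketch below reproduces that value and, for contrast only, also computes per-label agreement, which is not what Accuracy reports.

```python
y_true = [
    [0, 0, 1, 0, 1],
    [1, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [0, 1, 1, 0, 1],
]
y_pred = [
    [1, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [1, 0, 0, 0, 0],
    [1, 0, 1, 1, 1],
    [1, 1, 0, 0, 1],
]

# Exact match: a sample is correct only if every one of its labels matches.
exact = [t == p for t, p in zip(y_true, y_pred)]
print(exact)                    # [False, True, False, False, False]
print(sum(exact) / len(exact))  # 0.2

# Contrast: per-label agreement over all 25 individual labels.
labels = [tv == pv for t, p in zip(y_true, y_pred) for tv, pv in zip(t, p)]
print(sum(labels) / len(labels))  # 0.6
```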

In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

metric = Accuracy(output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "accuracy")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["accuracy"])
0.6666...
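torch.round effectively thresholds at 0.5, and because it rounds half to even, a score of exactly 0.5 becomes 0. To use a different cut-off, compare against it explicitly. With tensors, that would be `(y_pred >= threshold).long()`. The sketch below uses a hypothetical `make_threshold_transform` factory, with plain lists in place of tensors, to show how the threshold changes the result for the scores above.

```python
def make_threshold_transform(threshold=0.5):
    """Return an output_transform that binarizes scores at `threshold` (hypothetical helper)."""
    def transform(output):
        y_pred, y = output
        # Scores at or above the threshold become 1, the rest become 0.
        return [1 if p >= threshold else 0 for p in y_pred], y
    return transform


y_true = [1, 0, 1, 1, 0, 1]
scores = [0.6, 0.2, 0.9, 0.4, 0.7, 0.65]

for threshold in (0.5, 0.8):
    y_pred, y = make_threshold_transform(threshold)((scores, y_true))
    acc = sum(p == t for p, t in zip(y_pred, y)) / len(y)
    print(threshold, y_pred, acc)
# 0.5 [1, 0, 1, 0, 1, 1] 0.6666666666666666
# 0.8 [0, 0, 1, 0, 0, 0] 0.5
```

Raising the threshold to 0.8 turns three correct positives into false negatives and fixes one false positive, so accuracy drops from 4/6 to 3/6.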

Methods

compute

Computes the metric based on its accumulated state.

reset

Resets the metric to its initial state.

update

Updates the metric's state using the passed batch output.

compute()

Computes the metric based on its accumulated state.

By default, this is called at the end of each epoch.

Returns

the actual quantity of interest. However, if a Mapping is returned, it will be (shallow) flattened into engine.state.metrics when completed() is called.

Return type

Any

Raises

NotComputableError – raised when the metric cannot be computed.

reset()

Resets the metric to its initial state.

By default, this is called at the start of each epoch.

Return type

None

update(output)

Updates the metric’s state using the passed batch output.

By default, this is called once for each batch.

Parameters

output (Sequence[Tensor]) – the (y_pred, y) output of the engine’s process function, after output_transform has been applied.

Return type

None
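The three methods work together over an epoch: reset clears the state, update accumulates counts batch by batch, and compute divides once at the end. The plain-Python sketch below mirrors that cycle with a hypothetical `RunningAccuracy` class and a local stand-in for `NotComputableError`. It is not Ignite's implementation and assumes predictions are already hard labels. It shows that the epoch result is total correct over total samples, not the mean of per-batch accuracies.

```python
class NotComputableError(RuntimeError):
    """Local stand-in for Ignite's NotComputableError in this sketch."""


class RunningAccuracy:
    """Hypothetical accumulate-then-compute sketch of the reset/update/compute cycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Start of an epoch: clear the accumulated counts.
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        # Once per batch: add this batch's correct predictions and sample count.
        y_pred, y = output
        self._num_correct += sum(p == t for p, t in zip(y_pred, y))
        self._num_examples += len(y)

    def compute(self):
        # End of the epoch: a single ratio over every sample seen since reset().
        if self._num_examples == 0:
            raise NotComputableError("Accuracy needs at least one example")
        return self._num_correct / self._num_examples


metric = RunningAccuracy()
batches = [
    ([1, 0], [1, 0]),              # batch accuracy 1.0
    ([1, 1, 0, 0], [1, 0, 1, 1]),  # batch accuracy 0.25
]
for batch in batches:
    metric.update(batch)
print(metric.compute())  # 0.5 (3 / 6), not the mean of batch accuracies (0.625)
```

Weighting by sample count keeps the epoch result independent of how the data is split into batches, including a smaller final batch.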