Precision
class ignite.metrics.precision.Precision(output_transform=<function _BasePrecisionRecall.<lambda>>, average=False, is_multilabel=False, device=device(type='cpu'))
Calculates precision for binary, multiclass and multilabel data.
$$\text{Precision} = \frac{TP}{TP + FP}$$
where $TP$ is true positives and $FP$ is false positives.
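The sketch below makes the formula concrete. It is a minimal pure-Python version, not Ignite's implementation, that counts TP and FP for 0/1 labels and applies them to the data from the binary example below. The function name `binary_precision` is hypothetical. Returning 0.0 when nothing is predicted positive is an assumption made here, because the formula is undefined in that case.

```python
def binary_precision(y_pred, y_true):
    """Precision = TP / (TP + FP) for 0/1 labels (hypothetical helper)."""
    tp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)  # true positives
    fp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 0)  # false positives
    # Assumption: report 0.0 when TP + FP == 0, where precision is undefined.
    return tp / (tp + fp) if tp + fp > 0 else 0.0


# Same data as the binary example further down.
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1]
print(binary_precision(y_pred, y_true))  # 0.75
```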
update must receive output of the form (y_pred, y). y_pred must have the shape (batch_size, num_categories, …) or (batch_size, …). y must have the shape (batch_size, …).
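The three accepted layouts can be told apart from the shapes alone. The sketch below is a hypothetical helper, `infer_input_type`, written only to illustrate the shape rules stated above. It works on plain shape tuples. Ignite's own internal checks may differ in detail and may perform additional validation.

```python
def infer_input_type(y_pred_shape, y_shape, is_multilabel=False):
    """Hypothetical helper: classify (y_pred, y) by shape, following the rules above."""
    if is_multilabel:
        # Multilabel: both tensors are (batch_size, num_labels, ...).
        if y_pred_shape != y_shape or len(y_shape) < 2:
            raise ValueError("multilabel inputs need matching (batch_size, num_labels, ...) shapes")
        return "multilabel"
    if len(y_pred_shape) == len(y_shape) + 1:
        # Multiclass: y_pred has an extra num_categories axis at position 1.
        if y_pred_shape[0] != y_shape[0] or y_pred_shape[2:] != y_shape[1:]:
            raise ValueError("y_pred must be (batch_size, num_categories, ...) when y is (batch_size, ...)")
        return "multiclass"
    if y_pred_shape == y_shape:
        # Binary: y_pred and y share the shape (batch_size, ...) and hold 0/1 values.
        return "binary"
    raise ValueError(f"incompatible shapes {y_pred_shape} and {y_shape}")


print(infer_input_type((6,), (6,)))                          # binary
print(infer_input_type((5, 3), (5,)))                        # multiclass
print(infer_input_type((5, 3), (5, 3), is_multilabel=True))  # multilabel
```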
- Parameters
output_transform (Callable) – a callable that is used to transform the Engine's process_function's output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
average (Optional[Union[bool, str]]) – available options are:
- False
default option. For multiclass and multilabel inputs, per-class and per-label metrics are returned, respectively.
- None
like the False option, except that per-class metrics are returned for binary data as well. For compatibility with the Scikit-Learn API.
- 'micro'
The metric is computed by pooling the counts of all classes/labels together.
$$\text{Micro Precision} = \frac{\sum_{k=1}^{C} TP_k}{\sum_{k=1}^{C} (TP_k + FP_k)}$$
where $C$ is the number of classes/labels (2 in the binary case). The subscript $k$ in $TP_k$ and $FP_k$ means that the measures are computed for class/label $k$ (in a one-vs-rest sense in the multiclass case).
For binary and multiclass inputs, this is equivalent to accuracy, so use Accuracy instead.
- 'samples'
for multilabel input, precision is first computed on a per-sample basis, and then the average across samples is returned.
$$\text{Sample-averaged Precision} = \frac{\sum_{n=1}^{N} \frac{TP_n}{TP_n + FP_n}}{N}$$
where $N$ is the number of samples. The subscript $n$ in $TP_n$ and $FP_n$ means that the measures are computed for sample $n$, across labels.
Incompatible with binary and multiclass inputs.
- 'weighted'
like macro precision, but accounts for class/label imbalance. For binary and multiclass input, it computes the metric for each class and returns their average weighted by class support (the number of actual samples in each class). For multilabel input, it computes precision for each label and returns their average weighted by label support (the number of actual positive samples for each label).
$$\text{Weighted Precision} = \frac{\sum_{k=1}^{C} P_k \frac{TP_k}{TP_k + FP_k}}{\sum_{k=1}^{C} P_k}$$
where $C$ is the number of classes (2 in the binary case). $P_k$ is the number of samples belonging to class $k$ in the binary and multiclass cases, and the number of positive samples belonging to label $k$ in the multilabel case.
- 'macro'
computes macro precision, which is the unweighted average of the metric computed across classes/labels.
$$\text{Macro Precision} = \frac{\sum_{k=1}^{C} \frac{TP_k}{TP_k + FP_k}}{C}$$
where $C$ is the number of classes (2 in the binary case).
- True
like the 'macro' option. Kept for backward compatibility.
is_multilabel (bool) – flag to use in multilabel case. By default, value is False.
device (Union[str, device]) – specifies which device updates are accumulated on. Setting the metric's device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
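The averaging formulas above can be checked with a short pure-Python sketch that starts from per-class (or per-label) counts. The function `precision_summary` is hypothetical. Treating a class with $TP_k + FP_k = 0$ as having precision 0.0 is an assumption made here. It is consistent with the 0.0000 reported for class 1 in the multiclass example below. The counts used are taken from the multilabel example below.

```python
def precision_summary(tp, fp, support):
    """Per-class, micro, macro and weighted precision from per-class/label counts.

    tp[k], fp[k]: true and false positives for class or label k.
    support[k]: number of actual (positive) samples of class or label k, i.e. P_k.
    """
    # Assumption: a class with TP_k + FP_k == 0 is given precision 0.0.
    per_class = [t / (t + f) if t + f > 0 else 0.0 for t, f in zip(tp, fp)]
    pooled = sum(tp) + sum(fp)
    micro = sum(tp) / pooled if pooled > 0 else 0.0  # pool counts first, then divide
    macro = sum(per_class) / len(per_class)          # unweighted mean over the C classes
    # Weighted: support-weighted mean, divided by the total support sum(P_k).
    weighted = sum(p * s for p, s in zip(per_class, support)) / sum(support)
    return per_class, micro, macro, weighted


# Per-label counts (labels 0, 1, 2) read off the multilabel example below.
per_label, micro, macro, weighted = precision_summary(tp=[1, 1, 0], fp=[4, 1, 2], support=[1, 1, 2])
print(per_label)  # [0.2, 0.5, 0.0]
print(micro, macro, weighted)
```

Dividing the weighted sum by the total support $\sum_k P_k$ (here 4) reproduces the 0.175 reported in the multilabel example. Dividing by the number of samples would not.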
Examples
For more information on how the metric works with Engine, visit Attach Engine API. The examples below rely on the following doctest setup, which defines default_evaluator:

```python
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests

param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup:`

def get_default_trainer():

    def train_step(engine, batch):
        return batch

    return Engine(train_step)

# create default model for doctests

default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)
```
Binary case. In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values.
```python
metric = Precision()
weighted_metric = Precision(average='weighted')
two_class_metric = Precision(average=None)  # Returns precision for both classes
metric.attach(default_evaluator, "precision")
weighted_metric.attach(default_evaluator, "weighted precision")
two_class_metric.attach(default_evaluator, "both classes precision")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Precision: {state.metrics['precision']}")
print(f"Weighted Precision: {state.metrics['weighted precision']}")
print(f"Precision for class 0 and class 1: {state.metrics['both classes precision']}")
```
```text
Precision: 0.75
Weighted Precision: 0.6666666666666666
Precision for class 0 and class 1: tensor([0.5000, 0.7500], dtype=torch.float64)
```
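With average=None, binary data is treated as two classes. The pure-Python sketch below is an illustration, not Ignite's code. It computes each class's precision by treating that class as the positive one, then weights the two values by class support. This reproduces the three numbers printed above. The helper name `per_class_precision_binary` is hypothetical.

```python
def per_class_precision_binary(y_pred, y_true):
    """Precision of class 0 and class 1 for 0/1 labels (hypothetical helper)."""
    result = []
    for cls in (0, 1):
        # True labels at the positions where `cls` was predicted.
        hits = [t for p, t in zip(y_pred, y_true) if p == cls]
        result.append(hits.count(cls) / len(hits) if hits else 0.0)
    return result


y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1]
per_class = per_class_precision_binary(y_pred, y_true)
support = [y_true.count(0), y_true.count(1)]  # actual samples per class: [2, 4]
weighted = sum(p * s for p, s in zip(per_class, support)) / sum(support)
print(per_class)  # [0.5, 0.75]
print(weighted)   # 0.6666666666666666
```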
Multiclass case
```python
metric = Precision()
macro_metric = Precision(average=True)
weighted_metric = Precision(average='weighted')
metric.attach(default_evaluator, "precision")
macro_metric.attach(default_evaluator, "macro precision")
weighted_metric.attach(default_evaluator, "weighted precision")
y_true = torch.tensor([2, 0, 2, 1, 0])
y_pred = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288]
])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Precision: {state.metrics['precision']}")
print(f"Macro Precision: {state.metrics['macro precision']}")
print(f"Weighted Precision: {state.metrics['weighted precision']}")
```
```text
Precision: tensor([0.5000, 0.0000, 0.3333], dtype=torch.float64)
Macro Precision: 0.27777777777777773
Weighted Precision: 0.3333333333333333
```
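For multiclass input, y_pred holds one score per class. The sketch below assumes that each row's predicted class is its highest-scoring column, which is consistent with the numbers printed above. Under that assumption, a few lines of pure Python reproduce the per-class values. The sketch also illustrates the note under average='micro': pooling TP and FP over classes gives the same number as accuracy. The `argmax` helper is hypothetical.

```python
def argmax(row):
    """Index of the largest score in a row (first one wins on ties)."""
    return max(range(len(row)), key=lambda k: row[k])


y_true = [2, 0, 2, 1, 0]
scores = [
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
]
labels = [argmax(row) for row in scores]
print(labels)  # [2, 2, 0, 2, 0]

# One-vs-rest counts per class.
num_classes = len(scores[0])
tp = [sum(1 for p, t in zip(labels, y_true) if p == k and t == k) for k in range(num_classes)]
fp = [sum(1 for p, t in zip(labels, y_true) if p == k and t != k) for k in range(num_classes)]
per_class = [t / (t + f) if t + f > 0 else 0.0 for t, f in zip(tp, fp)]
print(per_class)  # [0.5, 0.0, 0.3333333333333333]

# Micro precision pools TP and FP over classes; for multiclass input it equals accuracy.
micro = sum(tp) / (sum(tp) + sum(fp))
accuracy = sum(p == t for p, t in zip(labels, y_true)) / len(y_true)
print(micro, accuracy)  # 0.4 0.4
```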
Multilabel case. The shapes must be (batch_size, num_labels, …).
```python
metric = Precision(is_multilabel=True)
micro_metric = Precision(is_multilabel=True, average='micro')
macro_metric = Precision(is_multilabel=True, average=True)
weighted_metric = Precision(is_multilabel=True, average='weighted')
samples_metric = Precision(is_multilabel=True, average='samples')
metric.attach(default_evaluator, "precision")
micro_metric.attach(default_evaluator, "micro precision")
macro_metric.attach(default_evaluator, "macro precision")
weighted_metric.attach(default_evaluator, "weighted precision")
samples_metric.attach(default_evaluator, "samples precision")
y_true = torch.tensor([
    [0, 0, 1],
    [0, 0, 0],
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 1],
])
y_pred = torch.tensor([
    [1, 1, 0],
    [1, 0, 1],
    [1, 0, 0],
    [1, 0, 1],
    [1, 1, 0],
])
state = default_evaluator.run([[y_pred, y_true]])
print(f"Precision: {state.metrics['precision']}")
print(f"Micro Precision: {state.metrics['micro precision']}")
print(f"Macro Precision: {state.metrics['macro precision']}")
print(f"Weighted Precision: {state.metrics['weighted precision']}")
print(f"Samples Precision: {state.metrics['samples precision']}")
```
```text
Precision: tensor([0.2000, 0.5000, 0.0000], dtype=torch.float64)
Micro Precision: 0.2222222222222222
Macro Precision: 0.2333333333333333
Weighted Precision: 0.175
Samples Precision: 0.2
```
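The 'samples' average is the only option that works row by row. The pure-Python sketch below uses the y_true and y_pred from the multilabel example above. It computes precision for each sample across labels, then averages the results. The helper name `sample_precision` is hypothetical. Scoring a sample with no predicted labels as 0.0 is an assumption, and no such row occurs in this data.

```python
y_true = [[0, 0, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 1]]
y_pred = [[1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0]]


def sample_precision(pred_row, true_row):
    """TP_n / (TP_n + FP_n) for one sample, computed across its labels."""
    tp = sum(p & t for p, t in zip(pred_row, true_row))
    predicted_pos = sum(pred_row)  # TP_n + FP_n, since values are 0/1
    # Assumption: a sample with no predicted labels scores 0.0.
    return tp / predicted_pos if predicted_pos > 0 else 0.0


per_sample = [sample_precision(p, t) for p, t in zip(y_pred, y_true)]
print(per_sample)                         # [0.0, 0.0, 0.0, 0.5, 0.5]
print(sum(per_sample) / len(per_sample))  # 0.2
```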
Thresholding of predictions can be done as below:
```python
def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)  # maps probabilities to 0/1
    return y_pred, y

metric = Precision(output_transform=thresholded_output_transform)
metric.attach(default_evaluator, "precision")
y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([0.6, 0.2, 0.9, 0.4, 0.7, 0.65])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["precision"])
```
```text
0.75
```
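torch.round effectively thresholds at 0.5. Because it rounds halves to even, a score of exactly 0.5 becomes 0. To use a different cut-off, a transform could compare against a threshold instead, for example with (y_pred >= t).long(). The pure-Python sketch below uses a hypothetical `precision_at_threshold` helper to show how the threshold choice changes the result on the same scores.

```python
def precision_at_threshold(scores, y_true, threshold):
    """Binarize with `score >= threshold`, then compute TP / (TP + FP) (hypothetical helper)."""
    y_pred = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 0)
    return tp / (tp + fp) if tp + fp > 0 else 0.0


y_true = [1, 0, 1, 1, 0, 1]
scores = [0.6, 0.2, 0.9, 0.4, 0.7, 0.65]
for threshold in (0.5, 0.7):
    print(threshold, precision_at_threshold(scores, y_true, threshold))
# 0.5 0.75
# 0.7 0.5
```

Raising the threshold to 0.7 drops two correct positives (0.6 and 0.65) but keeps the false positive at 0.7, so precision falls from 0.75 to 0.5 on this data.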
Changed in version 0.4.10: Some new options were added to the average parameter.
Methods
update – Updates the metric's state using the passed batch output.
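Because update is called once per batch, the metric has to combine batches somehow. The sketch below assumes, as is typical for Ignite's epoch-level metrics, that counts accumulate across update calls and precision is computed once at the end, instead of averaging per-batch precisions. `RunningBinaryPrecision` is a hypothetical class that only mimics a reset/update/compute cycle. The sketch splits the binary example's data into two batches.

```python
class RunningBinaryPrecision:
    """Hypothetical accumulator mimicking a reset/update/compute metric cycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.tp = 0
        self.predicted_pos = 0  # TP + FP seen so far

    def update(self, output):
        y_pred, y_true = output  # same (y_pred, y) ordering as described above
        self.tp += sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)
        self.predicted_pos += sum(1 for p in y_pred if p == 1)

    def compute(self):
        return self.tp / self.predicted_pos if self.predicted_pos > 0 else 0.0


# The binary example's data, split into two batches of (y_pred, y_true).
batches = [([1, 0], [1, 0]), ([1, 0, 1, 1], [1, 1, 0, 1])]

metric = RunningBinaryPrecision()
for batch in batches:
    metric.update(batch)
print(metric.compute())  # 0.75, pooled over both batches

# For contrast: averaging per-batch precisions instead.
naive = []
for batch in batches:
    m = RunningBinaryPrecision()
    m.update(batch)
    naive.append(m.compute())
print(sum(naive) / len(naive))  # about 0.833
```

Pooling counts gives the same 0.75 as the single-batch binary example. Averaging the two per-batch scores (1.0 and 2/3) would overstate precision.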