
ignite.metrics

Metrics provide a way to compute various quantities of interest in an online fashion without having to store the entire output history of a model.

In practice, a user needs to attach the metric instance to an engine. The metric value is then computed using the output of the engine’s process_function:

from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy()
metric.attach(engine, "accuracy")

If the engine’s output is not in the format (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y, …}, the user can use the output_transform argument to transform it:

from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return {'y_pred': y_pred, 'y_true': y}  # possibly with other entries

engine = Engine(process_function)

def output_transform(output):
    # `output` is the dict returned by `process_function` above
    y_pred = output['y_pred']
    y = output['y_true']
    return y_pred, y  # format expected by `Accuracy` (see its docs)

metric = Accuracy(output_transform=output_transform)
metric.attach(engine, "accuracy")

Note

Most of the implemented metrics support distributed computations: they reduce their internal states across the GPUs before computing the metric value. This can be helpful when running the evaluation on multiple nodes/GPU instances with a distributed data sampler. The following code snippet shows in detail how to adapt metrics:

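import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# Assumes the process group is initialized and local_rank, model, test_dataset,
# batch_size and num_workers are defined.
device = "cuda:{}".format(local_rank)
model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model,
                                                  device_ids=[local_rank],
                                                  output_device=local_rank)
test_sampler = DistributedSampler(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler,
                         num_workers=num_workers, pin_memory=True)

evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(device=device)}, device=device)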

Metric arithmetics

Metrics can be combined to form new metrics. This can be done through arithmetic, such as metric1 + metric2, through PyTorch operators, such as (metric1 + metric2).pow(2).mean(), or through a lambda function, such as MetricsLambda(lambda a, b: torch.mean(a + b), metric1, metric2).

For example:

from ignite.metrics import Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
F1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()

Note

This example computes the mean of F1 across classes. To combine precision and recall into F1 or other F metrics, we have to make sure that average=False, i.e. use the unaveraged precision and recall; otherwise we will not be computing F-beta metrics.

Metrics also support indexing (if the metric’s result is a vector/matrix/tensor). For example, this can be useful to compute a mean metric (e.g. precision, recall or IoU) while ignoring the background:

from ignite.metrics import ConfusionMatrix, IoU

cm = ConfusionMatrix(num_classes=10)
iou_metric = IoU(cm)
iou_no_bg_metric = iou_metric[:9]  # We assume that the background index is 9
mean_iou_no_bg_metric = iou_no_bg_metric.mean()
# mean_iou_no_bg_metric.compute() -> tensor(0.12345)

How to create a custom metric

To create a custom metric, one needs to create a new class inheriting from Metric and override three methods:

  • reset(): resets internal variables and accumulators

  • update(output): updates internal variables and accumulators with the provided batch output (y_pred, y)

  • compute(): computes the custom metric and returns the result

For illustration purposes, suppose we would like to implement a multi-class accuracy metric with a specific condition (e.g. ignoring a user-defined class):

import torch

from ignite.metrics import Metric
from ignite.exceptions import NotComputableError

# These decorators help with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced


class CustomAccuracy(Metric):

    def __init__(self, ignored_class, output_transform=lambda x: x, device=None):
        self.ignored_class = ignored_class
        self._num_correct = None
        self._num_examples = None
        super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        self._num_correct = 0
        self._num_examples = 0
        super(CustomAccuracy, self).reset()

    @reinit__is_reduced
    def update(self, output):
        y_pred, y = output

        indices = torch.argmax(y_pred, dim=1)

        mask = (y != self.ignored_class)
        mask &= (indices != self.ignored_class)
        y = y[mask]
        indices = indices[mask]
        correct = torch.eq(indices, y).view(-1)

        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    @sync_all_reduce("_num_examples", "_num_correct")
    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
        return self._num_correct / self._num_examples

We imported the necessary classes, Metric and NotComputableError, as well as the decorators that adapt the metric to the distributed setting. In the reset method, we reset the internal variables _num_correct and _num_examples, which are used to compute the custom metric. In the update method, we define how to update the internal variables. Finally, in the compute method, we compute the metric value.

We can check this implementation in a simple case:

import torch
torch.manual_seed(8)

m = CustomAccuracy(ignored_class=3)

batch_size = 4
num_classes = 5

y_pred = torch.rand(batch_size, num_classes)
y = torch.randint(0, num_classes, size=(batch_size, ))

m.update((y_pred, y))
res = m.compute()

print(y, torch.argmax(y_pred, dim=1))
# Out: tensor([2, 2, 2, 3]) tensor([2, 1, 0, 0])

print(m._num_correct, m._num_examples, res)
# Out: 1 3 0.3333333333333333

Metrics and distributed computations

In the above example, the CustomAccuracy constructor has a device argument, and the reset, update and compute methods are decorated with reinit__is_reduced and sync_all_reduce. These features adapt metrics to distributed computations on CUDA devices, assuming that the backend supports the “all_reduce” operation. The user can specify the device (by default, cuda) at the metric’s initialization. This device can be used to store the internal variables and to collect the results from all participating devices. More precisely, in the above example we added @sync_all_reduce(“_num_examples”, “_num_correct”) to the compute method. This means that when compute is called, the metric’s internal variables self._num_examples and self._num_correct are summed up over all participating devices. Once collected, these internal variables can be used to compute the final metric value.
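Why the counters, and not per-device accuracies, are what get summed can be seen with plain Python. The sketch below uses two hypothetical ranks with different numbers of examples; the sum over the list only imitates the all_reduce that sync_all_reduce performs and does not call torch.distributed.

```python
# Hypothetical accumulator values of CustomAccuracy on two ranks after an epoch.
per_rank = [
    {"_num_correct": 3, "_num_examples": 4},  # rank 0
    {"_num_correct": 1, "_num_examples": 6},  # rank 1
]

# Conceptually what sync_all_reduce does: sum each listed attribute over ranks.
reduced = {
    key: sum(state[key] for state in per_rank)
    for key in ("_num_correct", "_num_examples")
}
global_accuracy = reduced["_num_correct"] / reduced["_num_examples"]  # 4 / 10

# Averaging per-rank accuracies instead weights ranks equally, which generally
# differs from the global value when ranks see different numbers of examples.
mean_of_ratios = sum(s["_num_correct"] / s["_num_examples"] for s in per_rank) / len(per_rank)

print(global_accuracy, mean_of_ratios)  # 0.4 vs. (0.75 + 1/6) / 2, about 0.458
```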

Complete list of metrics

class ignite.metrics.Accuracy(output_transform=<function Accuracy.<lambda>>, is_multilabel=False, device=None)

Calculates the accuracy for binary, multiclass and multilabel data.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

  • y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).

  • y must be in the following shape (batch_size, …).

  • In multilabel cases, y and y_pred must both have the shape (batch_size, num_categories, …).

In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:

import torch
from ignite.metrics import Accuracy

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

binary_accuracy = Accuracy(thresholded_output_transform)

Parameters
  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • is_multilabel (bool, optional) – flag to use in multilabel case. By default, False.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

class ignite.metrics.Average(output_transform=<function Average.<lambda>>, device=None)

Helper class to compute arithmetic average of a single variable.

  • update must receive output of the form x.

  • x can be a number or torch.Tensor.

Note

Number of samples is updated following the rule:

  • +1 if input is a number

  • +1 if input is a 1D torch.Tensor

  • +batch_size if input is an ND torch.Tensor. Batch size is the first dimension (shape[0]).

For input x being an ND torch.Tensor with N > 1, the first dimension is seen as the number of samples and is summed up and added to the accumulator: accumulator += x.sum(dim=0)

Examples:

from ignite.metrics import Average

evaluator = ...

custom_var_mean = Average(output_transform=lambda output: output['custom_var'])
custom_var_mean.attach(evaluator, 'mean_custom_var')

state = evaluator.run(dataset)
# state.metrics['mean_custom_var'] -> average of output['custom_var']

Parameters
  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • device (str or torch.device) – device specification in case of distributed computation usage. In most cases, it should be defined as “cuda:local_rank”.

class ignite.metrics.ConfusionMatrix(num_classes, average=None, output_transform=<function ConfusionMatrix.<lambda>>, device=None)

Calculates confusion matrix for multi-class data.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

  • y_pred must contain logits and have the following shape (batch_size, num_categories, …)

  • y should have the following shape (batch_size, …) and contain ground-truth class indices with or without the background class. During the computation, the argmax of y_pred is taken to determine the predicted classes.

Parameters
  • num_classes (int) – number of classes. See notes for more details.

  • average (str, optional) – confusion matrix values averaging schema: None, “samples”, “recall”, “precision”. Default is None. If average=“samples”, then confusion matrix values are normalized by the number of seen samples. If average=“recall”, then confusion matrix values are normalized such that diagonal values represent class recalls. If average=“precision”, then confusion matrix values are normalized such that diagonal values represent class precisions.

  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

Note

When the targets y are in (batch_size, …) format, only target indices between 0 and num_classes - 1 contribute to the confusion matrix; all others are ignored. For example, if num_classes=20 and a target index equal to 255 is encountered, it is filtered out.

ignite.metrics.DiceCoefficient(cm, ignore_index=None)

Calculates Dice Coefficient for a given ConfusionMatrix metric.

Parameters
  • cm (ConfusionMatrix) – instance of confusion matrix metric

  • ignore_index (int, optional) – index to ignore, e.g. background index

class ignite.metrics.EpochMetric(compute_fn, output_transform=<function EpochMetric.<lambda>>)

Class for metrics that should be computed on the entire output history of a model. Model’s output and targets are restricted to be of shape (batch_size, n_classes). Output datatype should be float32. Target datatype should be long.

Warning

The current implementation stores all input data (output and target) as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger than the available RAM.

Warning

The current implementation does not work with distributed computations. Results are not gathered across all devices, and computed results are valid for a single device only.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

If the target shape is (batch_size, n_classes) and n_classes > 1, then it should be binary, e.g. [[0, 1, 0, 1], ].

Parameters
  • compute_fn (callable) – a callable with the signature (torch.tensor, torch.tensor) that takes predictions and targets as input and returns a scalar.

  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

ignite.metrics.Fbeta(beta, average=True, precision=None, recall=None, output_transform=None, device=None)

Calculates the F-beta score.

Parameters
  • beta (float) – weight of recall relative to precision in the harmonic mean: recall is considered beta times as important as precision

  • average (bool, optional) – if True, F-beta score is computed as the unweighted average (across all classes in multiclass case), otherwise, returns a tensor with F-beta score for each class in multiclass case.

  • precision (Precision, optional) – precision object metric with average=False to compute F-beta score

  • recall (Recall, optional) – recall object metric with average=False to compute F-beta score

  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. It is used only if precision or recall are not provided.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

Returns

MetricsLambda, F-beta metric

class ignite.metrics.GeometricAverage(output_transform=<function GeometricAverage.<lambda>>, device=None)

Helper class to compute geometric average of a single variable.

  • update must receive output of the form x.

  • x can be a number or torch.Tensor.

Note

Number of samples is updated following the rule:

  • +1 if input is a number

  • +1 if input is a 1D torch.Tensor

  • +batch_size if input is an ND torch.Tensor. Batch size is the first dimension (shape[0]).

For input x being an ND torch.Tensor with N > 1, the first dimension is seen as the number of samples; x is reduced by a product over it and multiplied into the accumulator: accumulator *= prod(x, dim=0)

Parameters
  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • device (str or torch.device) – device specification in case of distributed computation usage. In most cases, it should be defined as “cuda:local_rank”.

ignite.metrics.IoU(cm, ignore_index=None)

Calculates Intersection over Union using ConfusionMatrix metric.

Parameters
  • cm (ConfusionMatrix) – instance of confusion matrix metric

  • ignore_index (int, optional) – index to ignore, e.g. background index

Returns

MetricsLambda

Examples:

from ignite.metrics import ConfusionMatrix, IoU

train_evaluator = ...

cm = ConfusionMatrix(num_classes=num_classes)
IoU(cm, ignore_index=0).attach(train_evaluator, 'IoU')

state = train_evaluator.run(train_dataset)
# state.metrics['IoU'] -> tensor of shape (num_classes - 1, )

ignite.metrics.mIoU(cm, ignore_index=None)

Calculates mean Intersection over Union using ConfusionMatrix metric.

Parameters
  • cm (ConfusionMatrix) – instance of confusion matrix metric

  • ignore_index (int, optional) – index to ignore, e.g. background index

Returns

MetricsLambda

Examples:

from ignite.metrics import ConfusionMatrix, mIoU

train_evaluator = ...

cm = ConfusionMatrix(num_classes=num_classes)
mIoU(cm, ignore_index=0).attach(train_evaluator, 'mean IoU')

state = train_evaluator.run(train_dataset)
# state.metrics['mean IoU'] -> scalar

class ignite.metrics.Loss(loss_fn, output_transform=<function Loss.<lambda>>, batch_size=<function Loss.<lambda>>, device=None)

Calculates the average loss according to the passed loss_fn.

Parameters
  • loss_fn (callable) – a callable taking a prediction tensor, a target tensor and, optionally, other arguments, and returning the average loss over all observations in the batch.

  • output_transform (callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. The output is expected to be a tuple (prediction, target) or (prediction, target, kwargs) where kwargs is a dictionary of extra keyword arguments. If extra keyword arguments are provided, they are passed to loss_fn.

  • batch_size (callable) – a callable taking a target tensor that returns the first dimension size (usually the batch size).

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

class ignite.metrics.MeanAbsoluteError(output_transform=<function Metric.<lambda>>, device=None)

Calculates the mean absolute error.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

class ignite.metrics.MeanPairwiseDistance(p=2, eps=1e-06, output_transform=<function MeanPairwiseDistance.<lambda>>, device=None)

Calculates the mean pairwise distance: average of pairwise distances computed on provided batches.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

class ignite.metrics.MeanSquaredError(output_transform=<function Metric.<lambda>>, device=None)

Calculates the mean squared error.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

class ignite.metrics.Metric(output_transform=<function Metric.<lambda>>, device=None)

Base class for all Metrics.

Parameters
  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

abstract compute()

Computes the metric based on its accumulated state.

This is called at the end of each epoch.

Returns

the actual quantity of interest.

Return type

Any

Raises

NotComputableError – raised when the metric cannot be computed.

abstract reset()

Resets the metric to its initial state.

This is called at the start of each epoch.

abstract update(output)

Updates the metric’s state using the passed batch output.

This is called once for each batch.

Parameters

output – the output from the engine’s process function.

class ignite.metrics.MetricsLambda(f, *args, **kwargs)

Apply a function to other metrics to obtain a new metric. The result of the new metric is defined to be the result of applying the function to the result of argument metrics.

When updated, this metric does not recursively update the metrics it depends on. When reset, all its dependency metrics are reset. When attached, all its dependencies are automatically attached.

Parameters
  • f (callable) – the function that defines the computation

  • args (sequence) – Sequence of other metrics or something else that will be fed to f as arguments.

Example:

import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)

def Fbeta(r, p, beta):
    return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

F1 = MetricsLambda(Fbeta, recall, precision, 1)
F2 = MetricsLambda(Fbeta, recall, precision, 2)
F3 = MetricsLambda(Fbeta, recall, precision, 3)
F4 = MetricsLambda(Fbeta, recall, precision, 4)

class ignite.metrics.Precision(output_transform=<function Precision.<lambda>>, average=False, is_multilabel=False, device=None)

Calculates precision for binary and multiclass data.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

  • y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).

  • y must be in the following shape (batch_size, …).

In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:

import torch
from ignite.metrics import Precision

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

precision = Precision(output_transform=thresholded_output_transform)

In multilabel cases, the average parameter should be True. However, if the user would like to compute the F1 metric, for example, the average parameter should be False. This can be done as shown below:

import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False, is_multilabel=True)
recall = Recall(average=False, is_multilabel=True)
F1 = precision * recall * 2 / (precision + recall + 1e-20)
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)

Warning

In multilabel cases, if average is False, the current implementation stores all input data (output and target) as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger than the available RAM.

Warning

In multilabel cases, if average is False, current implementation does not work with distributed computations. Results are not reduced across the GPUs. Computed result corresponds to the local rank’s (single GPU) result.

Parameters
  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • average (bool, optional) – if True, precision is computed as the unweighted average (across all classes in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case).

  • is_multilabel (bool, optional) – flag to use in multilabel cases. If True, the average is computed across samples instead of classes.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

class ignite.metrics.Recall(output_transform=<function Recall.<lambda>>, average=False, is_multilabel=False, device=None)

Calculates recall for binary and multiclass data.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

  • y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).

  • y must be in the following shape (batch_size, …).

In binary and multilabel cases, the elements of y and y_pred should have 0 or 1 values. Thresholding of predictions can be done as below:

import torch
from ignite.metrics import Recall

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

recall = Recall(output_transform=thresholded_output_transform)

In multilabel cases, the average parameter should be True. However, if the user would like to compute the F1 metric, for example, the average parameter should be False. This can be done as shown below:

import torch
from ignite.metrics import MetricsLambda, Precision, Recall

precision = Precision(average=False, is_multilabel=True)
recall = Recall(average=False, is_multilabel=True)
F1 = precision * recall * 2 / (precision + recall + 1e-20)
F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)

Warning

In multilabel cases, if average is False, the current implementation stores all input data (output and target) as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger than the available RAM.

Warning

In multilabel cases, if average is False, current implementation does not work with distributed computations. Results are not reduced across the GPUs. Computed result corresponds to the local rank’s (single GPU) result.

Parameters
  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • average (bool, optional) – if True, recall is computed as the unweighted average (across all classes in multiclass case), otherwise, returns a tensor with the recall (for each class in multiclass case).

  • is_multilabel (bool, optional) – flag to use in multilabel cases. If True, the average is computed across samples instead of classes.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

class ignite.metrics.RootMeanSquaredError(output_transform=<function Metric.<lambda>>, device=None)

Calculates the root mean squared error.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

class ignite.metrics.RunningAverage(src=None, alpha=0.98, output_transform=None, epoch_bound=True, device=None)

Computes the running average of a metric or of the output of the process function.

Parameters
  • src (Metric or None) – input source: an instance of Metric or None. The latter corresponds to engine.state.output which holds the output of process function.

  • alpha (float, optional) – running average decay factor, default 0.98

  • output_transform (callable, optional) – a function used to transform the output of the process function when src is None. Otherwise it should be None.

  • epoch_bound (boolean, optional) – whether the running average should be reset after each epoch (defaults to True).

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. This is necessary when the running average is computed on the output of the process function. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.

Examples:

from ignite.engine import Events
from ignite.metrics import Accuracy, RunningAverage

# `trainer` is an existing Engine whose process_function returns (loss, y_pred, y)
alpha = 0.98
acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
acc_metric.attach(trainer, 'running_avg_accuracy')

avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
avg_output.attach(trainer, 'running_avg_loss')

@trainer.on(Events.ITERATION_COMPLETED)
def log_running_avg_metrics(engine):
    print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
    print("running avg loss:", engine.state.metrics['running_avg_loss'])

class ignite.metrics.TopKCategoricalAccuracy(k=5, output_transform=<function TopKCategoricalAccuracy.<lambda>>, device=None)

Calculates the top-k categorical accuracy.

  • update must receive output of the form (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y}.

class ignite.metrics.VariableAccumulation(op, output_transform=<function VariableAccumulation.<lambda>>, device=None)

Single variable accumulator helper to compute (arithmetic, geometric, harmonic) average of a single variable.

  • update must receive output of the form x.

  • x can be a number or torch.Tensor.

Note

The class stores input into two public variables: accumulator and num_examples. Number of samples is updated following the rule:

  • +1 if input is a number

  • +1 if input is a 1D torch.Tensor

  • +batch_size if input is an ND torch.Tensor. Batch size is the first dimension (shape[0]).

Parameters
  • op (callable) – a callable to update accumulator. Method’s signature is (accumulator, output). For example, to compute arithmetic mean value, op = lambda a, x: a + x.

  • output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.

  • device (str or torch.device, optional) – device specification in case of distributed computation usage. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.