Source code for ignite.metrics.running_average

from ignite.engine import Events
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

class RunningAverage(Metric):
    """Compute running average of a metric or the output of process function.

    Every time the value is computed, the new source value is blended into the previous
    average as ``avg = avg * alpha + (1 - alpha) * value``. The first value seen after a
    reset is taken as-is.

    Args:
        src (Metric or None): input source: an instance of :class:`~ignite.metrics.Metric` or None. The latter
            corresponds to `engine.state.output` which holds the output of process function.
        alpha (float, optional): running average decay factor, must be in ``(0.0, 1.0]``. Default 0.98.
        output_transform (callable, optional): a function to use to transform the output if `src` is None and
            corresponds to the output of process function. Otherwise it should be None.
        epoch_bound (boolean, optional): whether the running average should be reset at the start of each
            epoch (defaults to True). If False, the running average is reset only at the start of each run.
        device (str or torch.device, optional): device specification in case of distributed computation usage.
            This is necessary when running average is computed on the output of process function.
            In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
            if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
            initialized and available, device is set to `cuda`.

    Raises:
        TypeError: if `src` is neither a :class:`~ignite.metrics.Metric` nor None.
        ValueError: if `alpha` is not in ``(0.0, 1.0]``; if `src` is a Metric and `output_transform` or
            `device` is given; or if `src` is None and `output_transform` is None.

    Examples:

    .. code-block:: python

        alpha = 0.98
        acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
        acc_metric.attach(trainer, 'running_avg_accuracy')

        avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
        avg_output.attach(trainer, 'running_avg_loss')

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_running_avg_metrics(engine):
            print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
            print("running avg loss:", engine.state.metrics['running_avg_loss'])

    """

    _required_output_keys = None

    def __init__(self, src=None, alpha=0.98, output_transform=None, epoch_bound=True, device=None):
        if not (isinstance(src, Metric) or src is None):
            raise TypeError("Argument src should be a Metric or None.")
        if not (0.0 < alpha <= 1.0):
            raise ValueError("Argument alpha should be a float between 0.0 and 1.0.")

        if isinstance(src, Metric):
            if output_transform is not None:
                raise ValueError("Argument output_transform should be None if src is a Metric.")
            if device is not None:
                raise ValueError("Argument device should be None if src is a Metric.")
            self.src = src
            self._get_src_value = self._get_metric_value
            # Replace the per-iteration handler so that the wrapped metric is driven
            # directly (see _metric_iteration_completed) instead of via self.update.
            self.iteration_completed = self._metric_iteration_completed
        else:
            if output_transform is None:
                raise ValueError(
                    "Argument output_transform should not be None if src corresponds "
                    "to the output of process function."
                )
            self._get_src_value = self._get_output_value
            # Each update stores the transformed process-function output in self.src.
            self.update = self._output_update

        self.alpha = alpha
        self.epoch_bound = epoch_bound
        super(RunningAverage, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """Forget the current running average; the next computed value starts a fresh average."""
        self._value = None

    @reinit__is_reduced
    def update(self, output):
        """No-op implementation of the abstract ``Metric.update``.

        When `src` is None this method is shadowed on the instance by ``_output_update``
        in ``__init__``.
        """
        # Implement abstract method
        pass

    def compute(self):
        """Blend the current source value into the running average and return the new average."""
        if self._value is None:
            # First value after a reset seeds the average.
            self._value = self._get_src_value()
        else:
            self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value()
        return self._value

    def attach(self, engine, name):
        """Register the handlers that maintain this running average on `engine` under `name`.

        Args:
            engine: the engine whose events drive the average.
            name (str): the key under which the running value is published
                (readable as ``engine.state.metrics[name]``, as in the class example).
        """
        if self.epoch_bound:
            # restart average every epoch
            engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        else:
            # Restart the average at the beginning of every run. Without this, a second
            # ``engine.run`` (or a reused metric instance) would keep averaging stale values
            # carried over from the previous run.
            engine.add_event_handler(Events.STARTED, self.started)
        # compute metric
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
        # apply running average
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name)

    def _get_metric_value(self):
        """Return the current value of the wrapped metric."""
        return self.src.compute()

    @sync_all_reduce("src")
    def _get_output_value(self):
        """Return the last stored process-function output.

        The ``sync_all_reduce("src")`` decorator presumably reduces ``self.src`` across
        processes in distributed runs before it is read; see ignite's metric utilities.
        """
        return self.src

    def _metric_iteration_completed(self, engine):
        """Per-iteration handler used when `src` is a Metric.

        The wrapped metric's ``started`` handler is invoked before its
        ``iteration_completed``. This presumably resets it, so its computed value
        reflects only the current iteration's output.
        """
        self.src.started(engine)
        self.src.iteration_completed(engine)

    @reinit__is_reduced
    def _output_update(self, output):
        """Store the (already transformed) process-function output as the current source value."""
        self.src = output

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 01/11/2024, 12:19:46 PM.

Built with Sphinx using a theme provided by Read the Docs.