Shortcuts

Source code for ignite.metrics.running_average

import warnings
from typing import Any, Callable, cast, Optional, Union

import torch

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, RunningBatchWise, SingleEpochRunningBatchWise

__all__ = ["RunningAverage"]


[docs]class RunningAverage(Metric): """Compute running average of a metric or the output of process function. Args: src: input source: an instance of :class:`~ignite.metrics.metric.Metric` or None. The latter corresponds to `engine.state.output` which holds the output of process function. alpha: running average decay factor, default 0.98 output_transform: a function to use to transform the output if `src` is None and corresponds the output of process function. Otherwise it should be None. epoch_bound: whether the running average should be reset after each epoch. It is depracated in favor of ``usage`` argument in :meth:`attach` method. Setting ``epoch_bound`` to ``False`` is equivalent to ``usage=SingleEpochRunningBatchWise()`` and setting it to ``True`` is equivalent to ``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None. device: specifies which device updates are accumulated on. Should be None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value from the metric is a tensor. Examples: For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. .. include:: defaults.rst :start-after: :orphan: .. testcode:: 1 default_trainer = get_default_trainer() accuracy = Accuracy() metric = RunningAverage(accuracy) metric.attach(default_trainer, 'running_avg_accuracy') @default_trainer.on(Events.ITERATION_COMPLETED) def log_running_avg_metrics(): print(default_trainer.state.metrics['running_avg_accuracy']) y_true = [torch.tensor(y) for y in [[0], [1], [0], [1], [0], [1]]] y_pred = [torch.tensor(y) for y in [[0], [0], [0], [1], [1], [1]]] state = default_trainer.run(zip(y_pred, y_true)) .. testoutput:: 1 1.0 0.98 0.98039... 0.98079... 0.96117... 0.96195... .. testcode:: 2 default_trainer = get_default_trainer() metric = RunningAverage(output_transform=lambda x: x.item()) metric.attach(default_trainer, 'running_avg_accuracy') @default_trainer.on(Events.ITERATION_COMPLETED) def log_running_avg_metrics(): print(default_trainer.state.metrics['running_avg_accuracy']) y = [torch.tensor(y) for y in [[0], [1], [0], [1], [0], [1]]] state = default_trainer.run(y) .. testoutput:: 2 0.0 0.020000... 0.019600... 0.039208... 0.038423... 0.057655... """ required_output_keys = None _state_dict_all_req_keys = ("_value", "src") def __init__( self, src: Optional[Metric] = None, alpha: float = 0.98, output_transform: Optional[Callable] = None, epoch_bound: Optional[bool] = None, device: Optional[Union[str, torch.device]] = None, ): if not (isinstance(src, Metric) or src is None): raise TypeError("Argument src should be a Metric or None.") if not (0.0 < alpha <= 1.0): raise ValueError("Argument alpha should be a float between 0.0 and 1.0.") if isinstance(src, Metric): if output_transform is not None: raise ValueError("Argument output_transform should be None if src is a Metric.") def output_transform(x: Any) -> Any: return x if device is not None: raise ValueError("Argument device should be None if src is a Metric.") self.src: Union[Metric, None] = src device = src._device else: if output_transform is None: raise ValueError( "Argument output_transform should not be None if src corresponds " "to the output of process function." ) self.src = None if device is None: device = torch.device("cpu") if epoch_bound is not None: warnings.warn( "`epoch_bound` is deprecated and will be removed in the future. Consider using `usage` argument of" "`attach` method instead. `epoch_bound=True` is equivalent with `usage=SingleEpochRunningBatchWise()`" " and `epoch_bound=False` is equivalent with `usage=RunningBatchWise()`." ) self.epoch_bound = epoch_bound self.alpha = alpha super(RunningAverage, self).__init__(output_transform=output_transform, device=device)
[docs] @reinit__is_reduced def reset(self) -> None: self._value: Optional[Union[float, torch.Tensor]] = None if isinstance(self.src, Metric): self.src.reset()
[docs] @reinit__is_reduced def update(self, output: Union[torch.Tensor, float]) -> None: if self.src is None: output = output.detach().to(self._device, copy=True) if isinstance(output, torch.Tensor) else output value = idist.all_reduce(output) / idist.get_world_size() else: value = self.src.compute() self.src.reset() if self._value is None: self._value = value else: self._value = self._value * self.alpha + (1.0 - self.alpha) * value
[docs] def compute(self) -> Union[torch.Tensor, float]: return cast(Union[torch.Tensor, float], self._value)
[docs] def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None: r""" Attach the metric to the ``engine`` using the events determined by the ``usage``. Args: engine: the engine to get attached to. name: by which, the metric is inserted into ``engine.state.metrics`` dictionary. usage: the usage determining on which events the metric is reset, updated and computed. It should be an instance of the :class:`~ignite.metrics.metric.MetricUsage`\ s in the following table. ======================================================= =========================================== ``usage`` **class** **Description** ======================================================= =========================================== :class:`~.metrics.metric.RunningBatchWise` Running average of the ``src`` metric or ``engine.state.output`` is computed across batches. In the former case, on each batch, ``src`` is reset, updated and computed then its value is retrieved. Default. :class:`~.metrics.metric.SingleEpochRunningBatchWise` Same as above but the running average is computed across batches in an epoch so it is reset at the end of the epoch. :class:`~.metrics.metric.RunningEpochWise` Running average of the ``src`` metric or ``engine.state.output`` is computed across epochs. In the former case, ``src`` works as if it was attached in a :class:`~ignite.metrics.metric.EpochWise` manner and its computed value is retrieved at the end of the epoch. The latter case doesn't make much sense for this usage as the ``engine.state.output`` of the last batch is retrieved then. ======================================================= =========================================== ``RunningAverage`` retrieves ``engine.state.output`` at ``usage.ITERATION_COMPLETED`` if the ``src`` is not given and it's computed and updated using ``src``, by manually calling its ``compute`` method, or ``engine.state.output`` at ``usage.COMPLETED`` event. Also if ``src`` is given, it is updated at ``usage.ITERATION_COMPLETED``, but its reset event is determined by ``usage`` type. If ``isinstance(usage, BatchWise)`` holds true, ``src`` is reset on ``BatchWise().STARTED``, otherwise on ``EpochWise().STARTED`` if ``isinstance(usage, EpochWise)``. .. versionchanged:: 0.5.1 Added `usage` argument """ usage = self._check_usage(usage) if self.epoch_bound is not None: usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise() if isinstance(self.src, Metric) and not engine.has_event_handler( self.src.iteration_completed, Events.ITERATION_COMPLETED ): engine.add_event_handler(Events.ITERATION_COMPLETED, self.src.iteration_completed) super().attach(engine, name, usage)
[docs] def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None: usage = self._check_usage(usage) if self.epoch_bound is not None: usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise() if isinstance(self.src, Metric) and engine.has_event_handler( self.src.iteration_completed, Events.ITERATION_COMPLETED ): engine.remove_event_handler(self.src.iteration_completed, Events.ITERATION_COMPLETED) super().detach(engine, usage)

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 04/17/2024, 8:16:17 PM.

Built with Sphinx using a theme provided by Read the Docs.