Shortcuts

Source code for ignite.metrics.metric

from abc import ABCMeta, abstractmethod
from ignite._six import with_metaclass
from ignite.engine import Events
import torch


class Metric(with_metaclass(ABCMeta, object)):
    """
    Base class for all Metrics.

    Args:
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of
            the outputs.

    Metrics can be combined with arithmetic operators (``+ - * / // % **``), indexed
    (``metric[i]``) or have methods of their computed value called on them
    (``metric.mean()``); each of these builds a lazily-evaluated ``MetricsLambda``.
    """

    def __init__(self, output_transform=lambda x: x):
        self._output_transform = output_transform
        self.reset()

    @abstractmethod
    def reset(self):
        """
        Resets the metric to its initial state.

        This is called at the start of each epoch.
        """
        pass

    @abstractmethod
    def update(self, output):
        """
        Updates the metric's state using the passed batch output.

        This is called once for each batch.

        Args:
            output: this is the output from the engine's process function. When the
                metric is driven by an attached engine, ``output_transform`` has
                already been applied to it.
        """
        pass

    @abstractmethod
    def compute(self):
        """
        Computes the metric based on its accumulated state.

        This is called at the end of each epoch.

        Returns:
            Any: the actual quantity of interest.

        Raises:
            NotComputableError: raised when the metric cannot be computed.
        """
        pass

    def started(self, engine):
        # EPOCH_STARTED handler: clear accumulated state before each epoch.
        self.reset()

    @torch.no_grad()
    def iteration_completed(self, engine):
        # ITERATION_COMPLETED handler; no_grad keeps metric bookkeeping out of autograd.
        output = self._output_transform(engine.state.output)
        self.update(output)

    def completed(self, engine, name):
        # EPOCH_COMPLETED handler: publish the result under `name`.
        result = self.compute()
        # Unwrap 0-dim tensors into plain Python numbers for easier logging/comparison.
        if torch.is_tensor(result) and len(result.shape) == 0:
            result = result.item()
        engine.state.metrics[name] = result

    def attach(self, engine, name):
        """
        Registers this metric's handlers on ``engine`` and stores the computed value in
        ``engine.state.metrics[name]`` at the end of every epoch.

        The epoch-start and iteration handlers are only added if not already present,
        so attaching the same metric under several names does not reset or update it
        more than once per event.
        """
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
        if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
            engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)

    # --- Arithmetic composition -------------------------------------------------
    # Each operator returns a MetricsLambda whose function receives the operands in
    # the order given after the lambda. Binary operators (not direct dunder calls on
    # the left operand) are used so that Python's reflected dispatch applies, e.g.
    # ``int / float`` or ``float / tensor``.

    def __add__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x + y, self, other)

    def __radd__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x + y, other, self)

    def __sub__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x - y, self, other)

    def __rsub__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x - y, other, self)

    def __mul__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x * y, self, other)

    def __rmul__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x * y, other, self)

    def __pow__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x ** y, self, other)

    def __rpow__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x ** y, other, self)

    def __mod__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x % y, self, other)

    def __rmod__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x % y, other, self)

    def __div__(self, other):
        # Python 2 classic-division hook; this module has no `from __future__ import
        # division`, so `/` here is classic division, matching `__div__` semantics.
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x / y, self, other)

    def __rdiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x / y, other, self)

    def __truediv__(self, other):
        from operator import truediv
        from ignite.metrics import MetricsLambda
        return MetricsLambda(truediv, self, other)

    def __rtruediv__(self, other):
        # Calling `other.__truediv__(value)` directly returns NotImplemented for e.g.
        # `(1).__truediv__(0.5)`; operator.truediv falls back to the reflected method.
        from operator import truediv
        from ignite.metrics import MetricsLambda
        return MetricsLambda(truediv, other, self)

    def __floordiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x // y, self, other)

    def __rfloordiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x // y, other, self)

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails: turns `metric.foo(*a, **kw)`
        # into a MetricsLambda that calls `foo` on the computed value.
        # Dunder names are excluded so protocol probes (`__deepcopy__`, `__getstate__`,
        # `__setstate__`, ...) raise AttributeError as expected; otherwise copy/pickle
        # would receive a wrapper and silently build a MetricsLambda instead.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
            )
        from ignite.metrics import MetricsLambda

        def fn(x, *args, **kwargs):
            return getattr(x, attr)(*args, **kwargs)

        def wrapper(*args, **kwargs):
            return MetricsLambda(fn, self, *args, **kwargs)

        return wrapper

    def __getitem__(self, index):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x: x[index], self)

© Copyright 2022, PyTorch-Ignite Contributors. Last updated on 08/16/2022, 6:34:00 AM.

Built with Sphinx using a theme provided by Read the Docs.