Shortcuts

Source code for ignite.metrics.metric

import numbers
from abc import ABCMeta, abstractmethod
from functools import wraps
from collections.abc import Mapping
import warnings

import torch
import torch.distributed as dist

from ignite.engine import Events


class Metric(metaclass=ABCMeta):
    """
    Base class for all Metrics.

    Args:
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
        device (str or torch.device, optional): device specification in case of distributed computation usage.
            In most of the cases, it can be defined as "cuda:local_rank" or "cuda"
            if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is
            initialized and available, device is set to `cuda`.

    """

    # Keys pulled (in this order) from a mapping-shaped transformed output in `iteration_completed`.
    _required_output_keys = ("y_pred", "y")

    def __init__(self, output_transform=lambda x: x, device=None):
        self._output_transform = output_transform

        # Check device if distributed is initialized:
        if dist.is_available() and dist.is_initialized():

            # check if reset and update methods are decorated. Compute may not be decorated
            if not (hasattr(self.reset, "_decorated") and hasattr(self.update, "_decorated")):
                warnings.warn("{} class does not support distributed setting. Computed result is not collected "
                              "across all computing devices".format(self.__class__.__name__),
                              RuntimeWarning)
            if device is None:
                device = "cuda"
            device = torch.device(device)
        # NOTE: outside of a distributed setting `device` is stored as given (possibly None or a str).
        self._device = device
        self._is_reduced = False
        self.reset()

    @abstractmethod
    def reset(self):
        """
        Resets the metric to its initial state.

        This is called at the start of each epoch.
        """
        pass

    @abstractmethod
    def update(self, output):
        """
        Updates the metric's state using the passed batch output.

        This is called once for each batch.

        Args:
            output: this is the output from the engine's process function.
        """
        pass

    @abstractmethod
    def compute(self):
        """
        Computes the metric based on its accumulated state.

        This is called at the end of each epoch.

        Returns:
            Any: the actual quantity of interest.

        Raises:
            NotComputableError: raised when the metric cannot be computed.
        """
        pass

    def _sync_all_reduce(self, tensor):
        """
        All-reduce ``tensor`` across processes with ``dist.all_reduce`` (its default reduction op).

        Outside an initialized distributed setting the input is returned unchanged. A Python number
        is wrapped into a tensor on ``self._device`` for the reduction and converted back with
        ``.item()``; a tensor is moved to ``self._device`` first if it lives elsewhere.

        Raises:
            TypeError: if ``tensor`` is neither a number nor a ``torch.Tensor``.
        """
        if not (dist.is_available() and dist.is_initialized()):
            # Nothing to reduce
            return tensor

        tensor_to_number = False
        if isinstance(tensor, numbers.Number):
            tensor = torch.tensor(tensor, device=self._device)
            tensor_to_number = True

        if isinstance(tensor, torch.Tensor):
            # check if the tensor is at specified device
            if tensor.device != self._device:
                tensor = tensor.to(self._device)
        else:
            raise TypeError("Unhandled input type {}".format(type(tensor)))

        # synchronize and reduce
        dist.barrier()
        dist.all_reduce(tensor)

        if tensor_to_number:
            return tensor.item()
        return tensor

    def started(self, engine):
        """Event handler: reset the metric state (attached to ``EPOCH_STARTED``)."""
        self.reset()

    @torch.no_grad()
    def iteration_completed(self, engine):
        """
        Event handler: transform ``engine.state.output`` and feed it to :meth:`update`.

        A mapping output is converted to a tuple of the values under ``_required_output_keys``.

        Raises:
            TypeError: if the output is a mapping but ``_required_output_keys`` is ``None``.
            ValueError: if the output is a mapping missing one of ``_required_output_keys``.
        """
        output = self._output_transform(engine.state.output)
        if isinstance(output, Mapping):
            if self._required_output_keys is None:
                raise TypeError("Transformed engine output for {} metric should be a tuple/list, but given {}"
                                .format(self.__class__.__name__, type(output)))
            if not all(k in output for k in self._required_output_keys):
                raise ValueError("When transformed engine's output is a mapping, "
                                 "it should contain {} keys, but given {}".format(self._required_output_keys,
                                                                                  list(output.keys())))
            output = tuple(output[k] for k in self._required_output_keys)
        self.update(output)

    def completed(self, engine, name):
        """
        Event handler: compute the metric and store it in ``engine.state.metrics[name]``.

        A 0-dim tensor result is converted to a Python number with ``.item()``.
        """
        result = self.compute()
        if torch.is_tensor(result) and len(result.shape) == 0:
            result = result.item()
        engine.state.metrics[name] = result

    def attach(self, engine, name):
        """
        Register this metric's handlers on ``engine`` under the metric name ``name``.

        ``completed`` is added on ``EPOCH_COMPLETED``; ``started`` and ``iteration_completed`` are added
        on ``EPOCH_STARTED`` / ``ITERATION_COMPLETED`` only if not already registered.
        """
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
        if not engine.has_event_handler(self.started, Events.EPOCH_STARTED):
            engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)

    # Arithmetic on metrics builds a lazy MetricsLambda over the computed values.

    def __add__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x + y, self, other)

    def __radd__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x + y, other, self)

    def __sub__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x - y, self, other)

    def __rsub__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x - y, other, self)

    def __mul__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x * y, self, other)

    def __rmul__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x * y, other, self)

    def __pow__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x ** y, self, other)

    def __rpow__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x ** y, other, self)

    def __mod__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x % y, self, other)

    # The division operators use `/` rather than calling `x.__truediv__(y)` / `x.__div__(y)` directly:
    # a direct dunder call returns NotImplemented for mixed operand types (e.g. `(1).__truediv__(0.5)`)
    # instead of falling back to the reflected method, and Python 3 numbers have no `__div__` at all.

    def __div__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x / y, self, other)

    def __rdiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x / y, other, self)

    def __truediv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x / y, self, other)

    def __rtruediv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x / y, other, self)

    def __floordiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x // y, self, other)

    def __getattr__(self, attr):
        """
        Forward an unknown method call to the computed value, lazily.

        ``metric.foo(*args, **kwargs)`` returns a ``MetricsLambda`` that calls ``value.foo(*args, **kwargs)``
        on the computed metric value.

        Raises:
            AttributeError: for dunder names. Protocol machinery (``copy.deepcopy``, ``pickle``, ...)
                probes hooks such as ``__deepcopy__`` / ``__setstate__`` on the instance; answering those
                with a MetricsLambda factory breaks copying and pickling of metrics.
        """
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

        from ignite.metrics import MetricsLambda

        def fn(x, *args, **kwargs):
            return getattr(x, attr)(*args, **kwargs)

        def wrapper(*args, **kwargs):
            return MetricsLambda(fn, self, *args, **kwargs)

        return wrapper

    def __getitem__(self, index):
        """Return a ``MetricsLambda`` that indexes the computed value with ``index``."""
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x: x[index], self)
def sync_all_reduce(*attrs):
    """
    Decorator factory for :class:`Metric` methods (typically ``compute``) that reduces state across processes.

    Before the decorated method runs, each attribute named in ``attrs`` that exists and is not ``None``
    is passed through ``self._sync_all_reduce`` and written back. This happens at most once while
    ``self._is_reduced`` stays ``True`` (it is reset by methods decorated with :func:`reinit__is_reduced`).

    Args:
        *attrs (str): names of instance attributes to reduce.

    Returns:
        callable: a decorator for ``Metric`` methods.

    Raises:
        RuntimeError: (from the decorated method) if it is called on something that is not a ``Metric``.
    """

    def wrapper(func):

        @wraps(func)
        def another_wrapper(self, *args, **kwargs):
            if not isinstance(self, Metric):
                raise RuntimeError("Decorator sync_all_reduce should be used on "
                                   "ignite.metric.Metric class methods only")

            if len(attrs) > 0 and not self._is_reduced:
                for attr in attrs:
                    # Bypass Metric.__getattr__: it answers missing attribute names with a MetricsLambda
                    # factory, so `getattr(self, attr, None)` would never yield None and a missing
                    # attribute would be "reduced" and overwritten on the instance with that function.
                    try:
                        t = object.__getattribute__(self, attr)
                    except AttributeError:
                        t = None
                    if t is not None:
                        t = self._sync_all_reduce(t)
                        self._is_reduced = True
                        setattr(self, attr, t)

            return func(self, *args, **kwargs)

        return another_wrapper

    wrapper._decorated = True
    return wrapper


def reinit__is_reduced(func):
    """
    Decorator for :class:`Metric` methods that modify the metric state (e.g. ``reset`` / ``update``).

    After the wrapped method returns, ``self._is_reduced`` is set back to ``False`` so the next
    :func:`sync_all_reduce`-decorated call reduces the new state again. The wrapped method's return
    value is passed through. The returned wrapper is tagged ``_decorated = True``, which
    ``Metric.__init__`` checks on ``reset``/``update`` in a distributed setting.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        self._is_reduced = False
        return result

    wrapper._decorated = True
    return wrapper

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 04/17/2024, 8:17:28 PM.

Built with Sphinx using a theme provided by Read the Docs.