Shortcuts

Source code for ignite.metrics.epoch_metric

import warnings
from typing import Callable, Sequence

import torch

from ignite.metrics.metric import Metric

__all__ = ["EpochMetric"]


class EpochMetric(Metric):
    """Class for metrics that should be computed on the entire output history of a model.

    Model's output and targets are restricted to be of shape ``(batch_size, n_classes)``
    or ``(batch_size, )``. Output datatype should be `float32`. Target datatype should be `long`.

    .. warning::

        Current implementation stores all input data (output and target) as tensors before
        computing a metric. This can potentially lead to a memory error if the input data is
        larger than available RAM.

    .. warning::

        Current implementation does not work with distributed computations. Results are not
        gathered across all devices and computed results are valid for a single device only.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.

    If target shape is ``(batch_size, n_classes)`` and ``n_classes > 1`` then it should be binary:
    e.g. ``[[0, 1, 0, 1], ]``.

    Args:
        compute_fn (callable): a callable with the signature (`torch.tensor`, `torch.tensor`) that
            takes as the input `predictions` and `targets` and returns a scalar.
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of the outputs.
        check_compute_fn (bool): if True, ``compute_fn`` is run on the first batch of data to ensure
            there are no issues. If issues exist, user is warned that there might be an issue with
            the ``compute_fn``.

    Raises:
        TypeError: if ``compute_fn`` is not callable.

    Warnings:
        EpochMetricWarning: User is warned that there are issues with compute_fn on a batch of
            data processed.
    """

    def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True):
        if not callable(compute_fn):
            raise TypeError("Argument compute_fn should be callable.")

        self._predictions = None
        self._targets = None
        self.compute_fn = compute_fn
        self._check_compute_fn = check_compute_fn

        super().__init__(output_transform=output_transform, device="cpu")

    def reset(self) -> None:
        """Drop every stored batch of predictions and targets."""
        self._predictions = []
        self._targets = []

    def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
        """Validate the rank of ``y_pred`` and ``y`` and that 2-D targets are binary.

        Raises:
            ValueError: if either tensor is not 1-D or 2-D, or if a 2-D target
                contains values other than 0 and 1.
        """
        y_pred, y = output
        if y_pred.ndimension() not in (1, 2):
            raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")

        if y.ndimension() not in (1, 2):
            raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).")

        if y.ndimension() == 2:
            # x ** 2 == x holds only for 0 and 1, so this is an element-wise binary check.
            if not torch.equal(y ** 2, y):
                raise ValueError("Targets should be binary (0 or 1).")

    def _check_type(self, output: Sequence[torch.Tensor]) -> None:
        """Ensure a new batch has the same tensor types as the most recently stored batch.

        Raises:
            ValueError: if the type of ``y_pred`` or ``y`` differs from the stored one.
        """
        y_pred, y = output
        if not self._predictions:
            # Nothing stored yet: the first batch defines the reference types.
            return

        dtype_preds = self._predictions[-1].type()
        if dtype_preds != y_pred.type():
            raise ValueError(
                "Incoherent types between input y_pred and stored predictions: "
                "{} vs {}".format(dtype_preds, y_pred.type())
            )

        dtype_targets = self._targets[-1].type()
        if dtype_targets != y.type():
            raise ValueError(
                "Incoherent types between input y and stored targets: " "{} vs {}".format(dtype_targets, y.type())
            )

    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Validate one ``(y_pred, y)`` batch and append CPU copies of it to the history.

        Raises:
            ValueError: if shapes are invalid, 2-D targets are not binary, or tensor
                types differ from previously stored batches.
        """
        self._check_shape(output)
        y_pred, y = output

        # Detach before storing: otherwise every stored batch keeps its autograd graph
        # alive for the whole epoch, and ``compute_fn`` implementations that convert
        # to NumPy fail on tensors that require grad.
        y_pred = y_pred.detach()
        y = y.detach()

        # Collapse a trailing singleton class dimension: (batch_size, 1) -> (batch_size, ).
        if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred.squeeze(dim=-1)

        if y.ndimension() == 2 and y.shape[1] == 1:
            y = y.squeeze(dim=-1)

        y_pred = y_pred.cpu().clone()
        y = y.cpu().clone()

        self._check_type((y_pred, y))
        self._predictions.append(y_pred)
        self._targets.append(y)

        # Check once the signature and execution of compute_fn
        if len(self._predictions) == 1 and self._check_compute_fn:
            try:
                self.compute_fn(self._predictions[0], self._targets[0])
            except Exception as e:
                # Deliberately best-effort: a failure on a single batch only warns,
                # it does not abort the run.
                warnings.warn("Probably, there can be a problem with `compute_fn`:\n {}.".format(e), EpochMetricWarning)

    def compute(self) -> float:
        """Concatenate all stored batches along dim 0 and apply ``compute_fn`` to them.

        Returns:
            Whatever ``compute_fn`` returns for the concatenated predictions and targets.

        Raises:
            NotComputableError: if no batch has been stored since the last ``reset``.
        """
        if not self._predictions or not self._targets:
            # Local import keeps this block self-contained; NotComputableError derives
            # from RuntimeError, the same family torch.cat raises on an empty list.
            from ignite.exceptions import NotComputableError

            raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

        _prediction_tensor = torch.cat(self._predictions, dim=0)
        _target_tensor = torch.cat(self._targets, dim=0)
        return self.compute_fn(_prediction_tensor, _target_tensor)
class EpochMetricWarning(UserWarning):
    """Warning category emitted by ``EpochMetric`` when the trial call of ``compute_fn`` raises."""

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 10/12/2024, 1:58:31 PM.

Built with Sphinx using a theme provided by Read the Docs.