# Source code for ignite.metrics.epoch_metric
import warnings
from typing import Any, Callable, Sequence

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric
# Public names exported by ``from <this module> import *``.
__all__ = ["EpochMetric"]
class EpochMetric(Metric):
    """Class for metrics that should be computed on the entire output history of a model.

    Model's output and targets are restricted to be of shape ``(batch_size, n_classes)`` or ``(batch_size, )``.
    Output datatype should be `float32`. Target datatype should be `long`.

    .. warning::

        Current implementation stores all input data (output and target) as tensors before computing a metric.
        This can potentially lead to a memory error if the input data is larger than available RAM.

    .. warning::

        Current implementation does not work with distributed computations. Results are not gathered across all
        devices and computed results are valid for a single device only.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.

    If target shape is ``(batch_size, n_classes)`` and ``n_classes > 1`` then it should be binary:
    e.g. ``[[0, 1, 0, 1], ]``.

    Args:
        compute_fn (callable): a callable with the signature (`torch.tensor`, `torch.tensor`) takes as the input
            `predictions` and `targets` and returns a scalar. Both tensors are detached CPU tensors holding the
            whole history concatenated along the first dimension.
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        check_compute_fn (bool): if True, ``compute_fn`` is run on the first batch of data to ensure there are no
            issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``.

    Warnings:
        EpochMetricWarning: User is warned that there are issues with compute_fn on a batch of data processed.
    """

    def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True):
        if not callable(compute_fn):
            raise TypeError("Argument compute_fn should be callable.")

        # Replaced by empty lists in ``reset`` (presumably invoked by the base-class constructor).
        self._predictions = None
        self._targets = None
        self.compute_fn = compute_fn
        self._check_compute_fn = check_compute_fn
        # ``update`` always stores data on CPU, so the base metric is pinned to CPU as well.
        super(EpochMetric, self).__init__(output_transform=output_transform, device="cpu")

    def reset(self) -> None:
        """Discard every accumulated prediction and target."""
        self._predictions = []
        self._targets = []

    def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
        """Validate the ranks of ``(y_pred, y)`` and the binariness of 2-D targets.

        Raises:
            ValueError: if either tensor is not 1-D or 2-D, or if 2-D targets contain values other than 0 and 1.
        """
        y_pred, y = output
        if y_pred.ndimension() not in (1, 2):
            raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")

        if y.ndimension() not in (1, 2):
            raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).")

        # x ** 2 == x holds only for 0 and 1, so this rejects non-binary multi-label targets.
        if y.ndimension() == 2 and not torch.equal(y ** 2, y):
            raise ValueError("Targets should be binary (0 or 1).")

    def _check_type(self, output: Sequence[torch.Tensor]) -> None:
        """Ensure a new batch has the same tensor types as the batches already stored.

        Raises:
            ValueError: if ``y_pred`` or ``y`` differs in tensor type from the last stored batch.
        """
        y_pred, y = output
        if not self._predictions:
            # Nothing stored yet: the first batch defines the reference types.
            return

        dtype_preds = self._predictions[-1].type()
        if dtype_preds != y_pred.type():
            raise ValueError(
                "Incoherent types between input y_pred and stored predictions: "
                "{} vs {}".format(dtype_preds, y_pred.type())
            )

        dtype_targets = self._targets[-1].type()
        if dtype_targets != y.type():
            raise ValueError(
                "Incoherent types between input y and stored targets: " "{} vs {}".format(dtype_targets, y.type())
            )

    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Validate a ``(y_pred, y)`` batch and append a detached CPU copy of it to the history.

        Trailing singleton class dimensions (``(batch_size, 1)``) are squeezed to ``(batch_size, )``.

        Raises:
            ValueError: on invalid shapes, non-binary 2-D targets, or types inconsistent with earlier batches.
        """
        self._check_shape(output)
        y_pred, y = output
        # Detach so stored tensors do not keep the autograd graph alive for the whole epoch, and so
        # compute functions that call ``.numpy()`` (which fails on tensors requiring grad) keep working.
        y_pred = y_pred.detach()
        y = y.detach()

        if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred.squeeze(dim=-1)

        if y.ndimension() == 2 and y.shape[1] == 1:
            y = y.squeeze(dim=-1)

        # clone() guarantees the history owns its storage even when the input is already on CPU.
        y_pred = y_pred.cpu().clone()
        y = y.cpu().clone()

        self._check_type((y_pred, y))
        self._predictions.append(y_pred)
        self._targets.append(y)

        # Check once the signature and execution of compute_fn
        if len(self._predictions) == 1 and self._check_compute_fn:
            try:
                self.compute_fn(self._predictions[0], self._targets[0])
            except Exception as e:
                # Deliberately best-effort: the user is warned rather than the run aborted.
                warnings.warn("Probably, there can be a problem with `compute_fn`:\n {}.".format(e), EpochMetricWarning)

    def compute(self) -> Any:
        """Concatenate the stored history and apply ``compute_fn`` to it.

        Returns:
            Whatever ``compute_fn`` returns for the concatenated ``(predictions, targets)``.

        Raises:
            NotComputableError: if no batch has been stored since the last ``reset``.
        """
        if not self._predictions or not self._targets:
            raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

        _prediction_tensor = torch.cat(self._predictions, dim=0)
        _target_tensor = torch.cat(self._targets, dim=0)
        return self.compute_fn(_prediction_tensor, _target_tensor)
class EpochMetricWarning(UserWarning):
    """Warning category used by :class:`EpochMetric` when ``compute_fn`` raises on the first stored batch."""