Source code for ignite.metrics.epoch_metric

import warnings

import torch

from ignite.metrics.metric import Metric

class EpochMetric(Metric):
    """Class for metrics that should be computed on the entire output history of a model.
    Model's output and targets are restricted to be of shape `(batch_size, n_classes)` or `(batch_size, )`.
    Output datatype should be `float32`. Target datatype should be `long`.

    .. warning::

        Current implementation stores all input data (output and target) as tensors before computing a metric.
        This can potentially lead to a memory error if the input data is larger than available RAM.

    .. warning::

        Current implementation does not work with distributed computations. Results are not gathered across all
        devices and computed results are valid for a single device only.

    - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.

    If target shape is `(batch_size, n_classes)` and `n_classes > 1` then it should be binary: e.g. `[[0, 1, 0, 1], ]`.

    Args:
        compute_fn (callable): a callable with the signature (`torch.tensor`, `torch.tensor`) that takes as the input
            `predictions` and `targets` and returns a scalar.
        output_transform (callable, optional): a callable that is used to transform the
            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.

    Raises:
        TypeError: if `compute_fn` is not callable.
    """

    def __init__(self, compute_fn, output_transform=lambda x: x):
        if not callable(compute_fn):
            raise TypeError("Argument compute_fn should be callable.")

        super(EpochMetric, self).__init__(output_transform=output_transform, device='cpu')
        self.compute_fn = compute_fn

    def reset(self):
        """Discard the accumulated history.

        Predictions are accumulated in an empty `float32` tensor and targets in an empty `long` tensor;
        these dtypes are what incoming batches are cast to in :meth:`update`.
        """
        self._predictions = torch.tensor([], dtype=torch.float32)
        self._targets = torch.tensor([], dtype=torch.long)

    def update(self, output):
        """Validate one batch and append it to the accumulated history.

        Args:
            output (tuple): a pair `(y_pred, y)` of tensors, each of shape `(batch_size, n_classes)` or
                `(batch_size, )`. A trailing dimension of size 1 is squeezed away.

        Raises:
            ValueError: if `y_pred` or `y` is neither 1-D nor 2-D, or if a 2-D `y` contains values other
                than 0 and 1.
        """
        y_pred, y = output

        if y_pred.ndimension() not in (1, 2):
            raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")

        if y.ndimension() not in (1, 2):
            raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).")

        if y.ndimension() == 2:
            # Only 0 and 1 satisfy v ** 2 == v, so this rejects non-binary multi-label targets.
            if not torch.equal(y ** 2, y):
                raise ValueError("Targets should be binary (0 or 1).")

        # Treat (batch_size, 1) as (batch_size, ) so single-column inputs concatenate as 1-D.
        if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred.squeeze(dim=-1)

        if y.ndimension() == 2 and y.shape[1] == 1:
            y = y.squeeze(dim=-1)

        # Cast each batch to the accumulator's dtype (float32 / long, set in `reset`) so that
        # `torch.cat` sees matching tensor types.
        y_pred = y_pred.type_as(self._predictions)
        y = y.type_as(self._targets)

        self._predictions = torch.cat([self._predictions, y_pred], dim=0)
        self._targets = torch.cat([self._targets, y], dim=0)

        # Check once the signature and execution of compute_fn: the accumulated shape equals the
        # current batch's shape only when this batch is the whole history (normally the first update
        # after `reset`). Failures are reported as a warning rather than raised.
        if self._predictions.shape == y_pred.shape:
            try:
                self.compute_fn(self._predictions, self._targets)
            except Exception as e:
                warnings.warn("Probably, there can be a problem with `compute_fn`:\n {}.".format(e), RuntimeWarning)

    def compute(self):
        """Apply `compute_fn` to every prediction and target accumulated since the last `reset`.

        If `update` was never called, `compute_fn` receives the empty tensors created in `reset`.
        """
        return self.compute_fn(self._predictions, self._targets)

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 01/11/2024, 12:19:46 PM.

Built with Sphinx using a theme provided by Read the Docs.