Shortcuts

Source code for ignite.metrics.metric_group

from typing import Any, Callable, Dict, Sequence, Tuple

import torch

from ignite.metrics import Metric


[docs]class MetricGroup(Metric): """ A class for grouping metrics so that user could manage them easier. Args: metrics: a dictionary of names to metric instances. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. `output_transform` of each metric in the group is also called upon its update. skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle this. Examples: We construct a group of metrics, attach them to the engine at once and retrieve their result. .. code-block:: python import torch metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())}) metric_group.attach(default_evaluator, "eval_metrics") y_true = torch.tensor([1, 0, 1, 1, 0, 1]) y_pred = torch.tensor([1, 0, 1, 0, 1, 1]) state = default_evaluator.run([[y_pred, y_true]]) # Metrics individually available in `state.metrics` state.metrics["acc"], state.metrics["precision"], state.metrics["loss"] # And also altogether state.metrics["eval_metrics"] .. versionchanged:: 0.5.2 ``skip_unrolling`` argument is added. """ _state_dict_all_req_keys: Tuple[str, ...] = ("metrics",) def __init__( self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False ): self.metrics = metrics super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)
[docs] def reset(self) -> None: for m in self.metrics.values(): m.reset()
[docs] def update(self, output: Sequence[torch.Tensor]) -> None: for m in self.metrics.values(): m.update(m._output_transform(output))
[docs] def compute(self) -> Dict[str, Any]: return {k: m.compute() for k, m in self.metrics.items()}

© Copyright 2025, PyTorch-Ignite Contributors. Last updated on 04/28/2025, 10:00:46 AM.

Built with Sphinx using a theme provided by Read the Docs.