# Shortcuts

# Source code for ignite.metrics.mean_pairwise_distance

from __future__ import division

import torch
from torch.nn.functional import pairwise_distance

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

class MeanPairwiseDistance(Metric):
    """
    Calculates the mean pairwise distance: average of pairwise distances computed on provided batches.

    - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.

    The metric keeps a running sum of the per-example distances returned by
    ``torch.nn.functional.pairwise_distance`` together with a running count of
    examples, and ``compute`` returns their ratio.

    Args:
        p: forwarded as the ``p`` argument of ``pairwise_distance`` (default 2).
        eps: forwarded as the ``eps`` argument of ``pairwise_distance``
            (default 1e-6).
        output_transform: passed through unchanged to ``Metric.__init__``
            (default is the identity function).
        device: passed through unchanged to ``Metric.__init__`` (default None).
    """

    def __init__(self, p=2, eps=1e-6, output_transform=lambda x: x, device=None):
        super(MeanPairwiseDistance, self).__init__(output_transform, device=device)
        self._p = p
        self._eps = eps

    # NOTE(review): ``reinit__is_reduced`` comes from ignite.metrics.metric;
    # presumably it flags the accumulators as not yet synchronized across
    # processes -- confirm against the Metric implementation.
    @reinit__is_reduced
    def reset(self):
        """Clear the accumulated distance sum and example count."""
        self._sum_of_distances = 0.0
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output):
        """Accumulate the pairwise distances of one batch.

        Args:
            output: a 2-item sequence ``(y_pred, y)``; both items are handed
                to ``pairwise_distance`` as-is.
        """
        y_pred, y = output
        distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps)
        # ``.item()`` converts the summed tensor to a Python float so the
        # running total is a plain number rather than a tensor.
        self._sum_of_distances += torch.sum(distances).item()
        # The size of the first dimension of ``y`` is counted as the number of
        # examples in this batch.
        self._num_examples += y.shape[0]

    # NOTE(review): ``sync_all_reduce`` comes from ignite.metrics.metric;
    # presumably it reduces the named attributes across processes before the
    # body runs -- confirm against the Metric implementation.
    @sync_all_reduce("_sum_of_distances", "_num_examples")
    def compute(self):
        """Return the accumulated distance sum divided by the example count.

        Raises:
            NotComputableError: if no examples have been accumulated.
        """
        if self._num_examples == 0:
            # The message names this metric (it previously said
            # 'MeanAbsoluteError', a copy-paste slip from another metric).
            raise NotComputableError('MeanPairwiseDistance must have at least one example before it can be computed.')
        return self._sum_of_distances / self._num_examples


# © Copyright 2024, PyTorch-Ignite Contributors. Last updated on 07/01/2024, 6:46:53 PM. Built with Sphinx using a theme provided by Read the Docs.