Shortcuts

Source code for torcheval.metrics.aggregation.mean

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

import logging
from typing import Iterable, Optional, TypeVar, Union

import torch

from torcheval.metrics.functional.aggregation.mean import _mean_update
from torcheval.metrics.metric import Metric

TMean = TypeVar("TMean")


[docs]class Mean(Metric[torch.Tensor]): """ Calculate the weighted mean value of all elements in all the input tensors. When weight is not provided, it calculates the unweighted mean. Its functional version is ``torcheval.functional.mean()``. Examples:: >>> import torch >>> from torcheval.metrics import Mean >>> metric = Mean() >>> metric.update(1) >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() tensor(2.) >>> metric.update(torch.tensor(-1)).compute() tensor(1.25) >>> metric.reset() >>> metric.update(torch.tensor(-1)).compute() tensor(-1.) >>> metric = Mean() >>> metric.update(torch.tensor([2, 3]), torch.tensor([0.2, 0.8])).compute() tensor(2.8) >>> metric.update(torch.tensor([4, 5]), 0.5).compute() tensor(3.65) >>> metric.update(torch.tensor([6]), 2).compute() tensor(4.825) """
[docs] def __init__( self: TMean, *, device: Optional[torch.device] = None, ) -> None: super().__init__(device=device) self._add_state( "weighted_sum", torch.tensor(0.0, device=self.device, dtype=torch.float64) ) self._add_state( "weights", torch.tensor(0.0, device=self.device, dtype=torch.float64) )
@torch.inference_mode() # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any def update( self: TMean, input: torch.Tensor, *, weight: Union[float, int, torch.Tensor] = 1.0, ) -> TMean: """ Compute weighted mean. When weight is not provided, it calculates the unweighted mean. weighted_mean = sum(weight * input) / sum(weight) Args: input (Tensor): Tensor of input values. weight(optional): Float or Int or Tensor of input weights. It is default to 1.0. If weight is a Tensor, its size should match the input tensor size. Raises: ValueError: If value of weight is neither a ``float`` nor a ``int'' nor a ``torch.Tensor`` that matches the input tensor size. """ weighted_sum, weights = _mean_update(input, weight) self.weighted_sum += weighted_sum self.weights += weights return self @torch.inference_mode() def compute(self: TMean) -> torch.Tensor: """ If no calls to ``update()`` are made before ``compute()`` is called, the function throws a warning and returns 0.0. """ if not self.weighted_sum: logging.warning("No calls to update() have been made - returning 0.0") return torch.tensor(0.0, dtype=torch.float64) return self.weighted_sum / self.weights @torch.inference_mode() def merge_state(self: TMean, metrics: Iterable[TMean]) -> TMean: for metric in metrics: self.weighted_sum += metric.weighted_sum.to(self.device) self.weights += metric.weights.to(self.device) return self

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources