Shortcuts

Source code for torcheval.metrics.aggregation.max

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, TypeVar

import torch

from torcheval.metrics.metric import Metric


# Self-type for ``Max`` methods that return ``self`` (e.g. ``update`` and
# ``merge_state``), so calls can be chained as ``metric.update(x).compute()``.
# Bound to ``Max`` so type checkers know ``self`` is a ``Max`` instance.
TMax = TypeVar("TMax", bound="Max")


class Max(Metric[torch.Tensor]):
    """
    Calculate the maximum value of all elements in all the input tensors.
    Its functional version is ``torch.max(input)``.

    Empty input tensors are ignored: they leave the running maximum unchanged
    instead of raising, since the maximum over no elements is the identity
    value (``-inf``) the state starts from.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import Max
        >>> metric = Max()
        >>> metric.update(torch.tensor([[1, 2], [3, 4]]))
        >>> metric.compute()
        tensor(4.)

        >>> metric.update(torch.tensor(-1)).compute()
        tensor(4.)

        >>> metric.reset()
        >>> metric.update(torch.tensor(-1)).compute()
        tensor(-1.)
    """

    def __init__(
        self: TMax,
        *,
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Initialize the metric.

        Args:
            device: device on which the metric state is allocated; passed
                through to the ``Metric`` base class.
        """
        super().__init__(device=device)
        # Start from -inf, the identity of ``max``, so the first non-empty
        # update always replaces it.
        self._add_state("max", torch.tensor(float("-inf"), device=self.device))

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TMax, input: torch.Tensor) -> TMax:
        """
        Fold the maximum element of ``input`` into the running maximum.

        Args:
            input: tensor of any shape; all of its elements are considered.
                An empty tensor is a no-op.

        Returns:
            ``self``, so calls can be chained (e.g. ``update(x).compute()``).
        """
        # A full reduction with torch.max raises on an empty tensor. The max
        # over an empty set is the identity (-inf), so keep the state as-is.
        if input.numel() == 0:
            return self
        self.max = torch.max(self.max, torch.max(input))
        return self

    @torch.inference_mode()
    def compute(self: TMax) -> torch.Tensor:
        """
        Return the maximum seen so far.

        Returns:
            The stored state tensor; ``tensor(-inf)`` if no non-empty input
            has been seen since construction/reset.
        """
        return self.max

    @torch.inference_mode()
    def merge_state(self: TMax, metrics: Iterable[TMax]) -> TMax:
        """
        Merge the running maxima of other ``Max`` metrics into this one.

        Args:
            metrics: other metrics whose ``max`` state is combined with this
                one; each state is moved to ``self.device`` before comparison.

        Returns:
            ``self``, with ``max`` updated to the maximum over all states.
        """
        for metric in metrics:
            self.max = torch.max(self.max, metric.max.to(self.device))
        return self

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources