Shortcuts

Source code for torcheval.metrics.toolkit

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from copy import deepcopy
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    MutableMapping,
    Optional,
    overload,
    TypeVar,
    Union,
)

import torch
import torch.distributed as dist
from pyre_extensions import none_throws

from torcheval.metrics import Metric
from torcheval.metrics.metric import TComputeReturn, TState
from torcheval.metrics.synclib import metrics_traversal_order, sync_states

log: logging.Logger = logging.getLogger(__name__)

_TMetrics = TypeVar("_TMetrics", bound=Iterable[Metric])

_TMP: str = "tmp"


def _get_world_size(process_group: Optional[dist.ProcessGroup]) -> int:
    """
    Return the world size of ``process_group``.

    Falls back to ``1`` when ``torch.distributed`` is unavailable or has not
    been initialized.

    Args:
        process_group: The group to query; forwarded to
            ``dist.get_world_size`` as ``group``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(group=process_group)
    # no usable distributed setup: behave like a single process
    return 1


def _get_rank(process_group: Optional[dist.ProcessGroup]) -> int:
    """
    Return the rank of the current process within ``process_group``.

    Falls back to ``0`` when ``torch.distributed`` is unavailable or has not
    been initialized.

    Args:
        process_group: The group to query; forwarded to ``dist.get_rank``
            as ``group``.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    # without a usable distributed setup the only process is treated as rank 0
    return dist.get_rank(group=process_group) if distributed_ready else 0


def sync_and_compute(
    metric: Metric[TComputeReturn],
    process_group: Optional[dist.ProcessGroup] = None,
) -> TComputeReturn:
    """
    Sync the states of ``metric`` across the process group and return the
    ``compute()`` result of the synced metric. The same steps run on every
    rank that calls this function.

    Args:
        metric: The metric object to be synced and computed.
        process_group: The process group on which the metric states are
            gathered. default: ``None`` (the entire world)

    Examples::

        >>> # Assumes world_size of 3.
        >>> # Process group initialization omitted on each rank.
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torcheval.metrics import Max
        >>> max = Max()
        >>> max.update(torch.tensor(dist.get_rank())).compute()
        tensor(0.) # Rank 0
        tensor(1.) # Rank 1
        tensor(2.) # Rank 2
        >>> sync_and_compute(max)
        tensor(2.) # Rank 0
        tensor(2.) # Rank 1
        tensor(2.) # Rank 2
    """
    return get_synced_metric(metric, process_group).compute()
def sync_and_compute_collection(
    metrics: MutableMapping[str, Metric],
    process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Sync the states of every metric in ``metrics`` across the process group
    and return a dict holding each synced metric's ``compute()`` result,
    keyed like the input.

    Args:
        metrics: The dict of metric objects to be synced and computed.
        process_group: The process group on which the metric states are
            gathered. default: ``None`` (the entire world)

    Examples::

        >>> # Assumes world_size of 3.
        >>> # Process group initialization omitted on each rank.
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torcheval.metrics import Max, Min
        >>> metrics = {"max" : Max(), "min": Min()}
        >>> metrics["max"].update(torch.tensor(dist.get_rank())).compute()
        tensor(0.) # Rank 0
        tensor(1.) # Rank 1
        tensor(2.) # Rank 2
        >>> metrics["min"].update(torch.tensor(dist.get_rank())).compute()
        tensor(0.) # Rank 0
        tensor(1.) # Rank 1
        tensor(2.) # Rank 2
        >>> sync_and_compute_collection(metrics)
        {"max" : tensor(2.), "min": tensor(0.)} # Rank 0
        {"max" : tensor(2.), "min": tensor(0.)} # Rank 1
        {"max" : tensor(2.), "min": tensor(0.)} # Rank 2
    """
    synced = get_synced_metric_collection(metrics, process_group)
    results: Dict[str, Any] = {}
    for name, synced_metric in synced.items():
        results[name] = synced_metric.compute()
    return results
def get_synced_state_dict(
    metric: Metric,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Return the ``state_dict()`` of ``metric`` after syncing its states
    across the process group. Every calling rank receives the synced state
    dict.

    Args:
        metric: The metric object to sync and get ``state_dict()``
        process_group: The process group on which the metric states are
            gathered. default: ``None`` (the entire world)

    Returns:
        state dict of synced metric

    Examples::

        >>> # Assumes world_size of 3.
        >>> # Process group initialization omitted on each rank.
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torcheval.metrics import Max
        >>> max = Max()
        >>> max.update(torch.tensor(dist.get_rank()))
        >>> get_synced_state_dict(max)
        {"max": tensor(2.)} # Rank 0
        {"max": tensor(2.)} # Rank 1
        {"max": tensor(2.)} # Rank 2
    """
    # ``get_synced_metric`` always hands back a metric (the input itself when
    # world size is 1), so the state dict is returned unconditionally rather
    # than being replaced by ``{}`` based on the metric's truthiness.
    return get_synced_metric(metric, process_group).state_dict()
def get_synced_state_dict_collection(
    metric_collection: MutableMapping[str, Metric],
    process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Return the state dicts of a collection of metrics after syncing their
    states across the process group. The result maps each key of
    ``metric_collection`` to the ``state_dict()`` of the synced metric.

    Args:
        metric_collection (Dict[str, Metric]): The metric objects to sync
            and get ``state_dict()``
        process_group: The process group on which the metric states are
            gathered. default: ``None`` (the entire world)

    Returns:
        Bundle of state dicts for the synced metrics

    Examples::

        >>> # Assumes world_size of 3.
        >>> # Process group initialization omitted on each rank.
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torcheval.metrics import Max, Min
        >>> maximum = Max()
        >>> maximum.update(torch.tensor(dist.get_rank()))
        >>> minimum = Min()
        >>> minimum.update(torch.tensor(dist.get_rank()))
        >>> get_synced_state_dict_collection({"max rank": maximum, "min rank": minimum})
        {"max rank": {"max": tensor(2.)}, "min rank": {"min": tensor(0.)}} # Rank 0
        {"max rank": {"max": tensor(2.)}, "min rank": {"min": tensor(0.)}} # Rank 1
        {"max rank": {"max": tensor(2.)}, "min rank": {"min": tensor(0.)}} # Rank 2
    """
    synced = get_synced_metric_collection(metric_collection, process_group)
    state_dicts: Dict[str, Dict[str, Any]] = {}
    for name, synced_metric in synced.items():
        state_dicts[name] = synced_metric.state_dict()
    return state_dicts
def clone_metric(metric: Metric) -> Metric:
    """
    Create an independent copy of ``metric`` via ``copy.deepcopy``.

    Args:
        metric: The metric object to clone

    Returns:
        A new metric instance from cloning
    """
    cloned = deepcopy(metric)
    return cloned
def clone_metrics(metrics: _TMetrics) -> List[Metric]:
    """
    Clone every metric in ``metrics`` with ``clone_metric`` and collect the
    clones, in iteration order, into a new list.

    Args:
        metrics: The metric objects to clone

    Returns:
        A list of metric instances from cloning
    """
    return list(map(clone_metric, metrics))
def get_synced_metric(
    metric: Metric,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Metric:
    """
    Return a metric whose internal state variables are synced across the
    processes in ``process_group``.

    When the world size is 1 the input ``metric`` itself is returned.
    Otherwise the states of all ranks are gathered and the states of the
    other ranks are merged into a clone of ``metric``, so the input metric
    is not the object that ``merge_state`` is called on.

    Args:
        metric: The metric object to sync.
        process_group: The process group on which the metric states are
            gathered. default: ``None`` (the entire world)

    Examples::

        >>> # Assumes world_size of 3.
        >>> # Process group initialization omitted on each rank.
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torcheval.metrics import Max
        >>> max = Max()
        >>> max.update(torch.tensor(dist.get_rank())).compute()
        tensor(0.) # Rank 0
        tensor(1.) # Rank 1
        tensor(2.) # Rank 2
        >>> synced_metric = get_synced_metric(max)
        >>> synced_metric.compute()
        tensor(2.) # Rank 0
        tensor(2.) # Rank 1
        tensor(2.) # Rank 2
    """
    world_size = _get_world_size(process_group)
    _validate_rank_and_world_size(world_size)
    if world_size == 1:
        # nothing to sync with; hand the caller's metric straight back
        return metric

    per_rank_metrics = _sync_metric_object(
        metric,
        # pyre-fixme[6]: For 2nd param expected `ProcessGroup` but got `Union[None,
        #  dist.ProcessGroup, _distributed_c10d.ProcessGroup]`.
        process_group if process_group else dist.group.WORLD,
        world_size,
    )

    # collect what every rank except this one contributed
    this_rank = _get_rank(process_group)
    remote_metrics: List[Metric] = []
    for rank in range(world_size):
        if rank != this_rank:
            remote_metrics.append(per_rank_metrics[rank])

    # merge into a clone placed on the input metric's device
    merged = clone_metric(metric).to(metric.device)
    return merged.merge_state(remote_metrics)
def get_synced_metric_collection(
    metric_collection: MutableMapping[str, Metric],
    process_group: Optional[dist.ProcessGroup] = None,
) -> Union[Dict[str, Metric], MutableMapping[str, Metric]]:
    """
    Return a dict of metrics whose internal state variables are synced
    across the processes in ``process_group``. All metrics in the collection
    are handed to a single ``_sync_metric_object`` call.

    When the world size is 1 the input ``metric_collection`` itself is
    returned. Otherwise each returned metric is a clone of the local metric
    with the states of the other ranks merged in.

    Args:
        metric_collection (Dict[str, Metric]): The dict of metric objects to
            sync.
        process_group: The process group on which the metric states are
            gathered. default: ``None`` (the entire world)

    Examples::

        >>> # Assumes world_size of 3.
        >>> # Process group initialization omitted on each rank.
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torcheval.metrics import Max, Min
        >>> metrics = {"max" : Max(), "min": Min()}
        >>> metrics["max"].update(torch.tensor(dist.get_rank()))
        >>> metrics["min"].update(torch.tensor(dist.get_rank()))
        >>> synced_metrics = get_synced_metric_collection(metrics)
        >>> synced_metrics["max"].compute()
        tensor(2.) # Rank 0
        tensor(2.) # Rank 1
        tensor(2.) # Rank 2
        >>> synced_metrics["min"].compute()
        tensor(0.) # Rank 0
        tensor(0.) # Rank 1
        tensor(0.) # Rank 2
    """
    world_size = _get_world_size(process_group)
    _validate_rank_and_world_size(world_size)
    if world_size == 1:
        # nothing to sync with; hand the caller's collection straight back
        return metric_collection

    per_rank_collections = _sync_metric_object(
        metric_collection,
        # pyre-fixme[6]: For 2nd param expected `ProcessGroup` but got `Union[None,
        #  dist.ProcessGroup, _distributed_c10d.ProcessGroup]`.
        process_group if process_group else dist.group.WORLD,
        world_size,
    )

    # metric bundles are dicts.
    if isinstance(per_rank_collections[0], MutableMapping):
        this_rank = dist.get_rank(process_group)
        merged_metrics: Dict[str, Metric] = {}
        for name in metric_collection.keys():
            local_metric = metric_collection[name]
            # merge into a clone placed on the local metric's device
            merged = clone_metric(local_metric).to(local_metric.device)
            remote_metrics: List[Metric] = []
            for rank in range(world_size):
                if rank != this_rank:
                    remote_metrics.append(per_rank_collections[rank][name])
            merged_metrics[name] = merged.merge_state(remote_metrics)
        return merged_metrics
    # NOTE(review): there is no explicit return when the gathered entries are
    # not mappings; `_sync_metric_object` builds dicts for mapping input, so
    # this path is not expected to be taken.
def _validate_rank_and_world_size(
    world_size: int,
) -> None:
    """
    Sanity-check ``world_size`` before metrics are synced.

    Logs a warning when ``world_size`` is 1, since nothing will be synced.

    Raises:
        RuntimeError: When ``world_size`` is ``-1`` (the current process is
            not part of the process group) or any other value below 1.
    """
    if world_size == -1:
        raise RuntimeError("The current process is not part of the process group")
    if world_size < 1:
        raise RuntimeError(
            f"Unexpected world_size {world_size} is seen when syncing metrics!"
        )
    if world_size == 1:
        log.warning(
            "World size is 1, and metric(s) not synced. "
            "returning the input metric(s)."
        )


def _apply_device_to_tensor_states(
    state_dict: Dict[str, Any], device: torch.device
) -> None:
    """
    Move the states held in ``state_dict`` to ``device``, updating the dict
    in place.

    A tensor value is moved directly; for a list or dict value every
    element/value is moved with ``.to(device)``. Values of any other type
    are left untouched.
    """
    for state_name in list(state_dict):
        state_value = state_dict[state_name]
        if isinstance(state_value, torch.Tensor):
            moved = state_value.to(device)
        elif isinstance(state_value, list):
            moved = [element.to(device) for element in state_value]
        elif isinstance(state_value, dict):
            moved = {key: element.to(device) for key, element in state_value.items()}
        else:
            continue
        state_dict[state_name] = moved


@overload
def _sync_metric_object(
    local_metric_data: Metric,
    process_group: dist.ProcessGroup,
    world_size: int,
) -> List[Metric]:
    ...


@overload
def _sync_metric_object(
    local_metric_data: MutableMapping[str, Metric],
    process_group: dist.ProcessGroup,
    world_size: int,
) -> List[MutableMapping[str, Metric]]:
    ...


def _sync_metric_object(
    local_metric_data: Union[Metric, MutableMapping[str, Metric]],
    process_group: dist.ProcessGroup,
    world_size: int,
) -> Union[List[Metric], List[MutableMapping[str, Metric]]]:
    """
    Gather the state dicts of the local metric(s) through ``sync_states``
    and repack them as pseudo metric objects.

    Args:
        local_metric_data: A single metric or a mapping of named metrics.
        process_group: The process group passed on to ``dist.get_backend``
            and ``sync_states``.
        world_size: Not read inside this function.

    Returns:
        A list with one entry per element of the gathered result (callers
        index it by rank). Each entry is a pseudo metric when a single
        metric was passed in, otherwise a dict of pseudo metrics keyed like
        the input mapping.
    """
    # A lone metric is wrapped under the _TMP key so both input kinds share
    # one code path; the wrapper is removed again when the result is built.
    if isinstance(local_metric_data, Metric):
        named_metrics: MutableMapping[str, Metric] = {_TMP: local_metric_data}
        single_metric = True
    else:
        named_metrics = local_metric_data
        single_metric = False

    # Allow metrics to run some pre-processing before syncing states
    # for example, a common optimization is to concat tensors in a list into a single tensor
    for local_metric in named_metrics.values():
        local_metric._prepare_for_merge_state()

    # extract the state dict of every metric, remembering which device its
    # states live on
    state_dicts: Dict[str, Dict[str, TState]] = {}
    devices: Dict[str, torch.device] = {}
    backend = dist.get_backend(group=process_group)
    for name, local_metric in named_metrics.items():
        state_dicts[name] = local_metric.state_dict()
        if not (backend == "nccl" and local_metric.device.type == "cpu"):
            devices[name] = local_metric.device
            continue

        log.warning(
            "Metric tensor states are on CPU, but NCCL process group detected. "
            "These tensors will be moved to GPU prior to syncing. "
            "It is recommended for efficiency reasons to either use a gloo process group to sync metrics on CPU, or move the metrics to GPU and use the NCCL process group."
        )
        # move tensors to gpu since on nccl
        _apply_device_to_tensor_states(
            state_dicts[name],
            # pyre-ignore: Incompatible parameter type [6]
            torch.cuda.current_device(),
        )
        # pyre-ignore: Incompatible parameter type [6]
        devices[name] = torch.cuda.current_device()

    traversal_order = metrics_traversal_order(state_dicts)
    gathered_states = none_throws(
        sync_states(
            state_dicts,
            devices,
            traversal_order,
            process_group=process_group,
        )
    )

    # Repack states into Metrics or Dict[str, Metric]s
    if single_metric:
        # one metric came in: read it back from the "tmp" key of each entry
        return [
            _convert_to_psuedo_metric(rank_states[_TMP])
            for rank_states in gathered_states
        ]
    # a dict of metrics came in: rebuild one dict per gathered entry
    return [
        {
            name: _convert_to_psuedo_metric(rank_states[name])
            for name in named_metrics.keys()
        }
        for rank_states in gathered_states
    ]


# pyre-ignore: Missing return annotation [3]
def _convert_to_psuedo_metric(metric_state_dict: Dict[str, Any]) -> Any:
    """
    Build an anonymous class whose class attributes are the key-value pairs
    of ``metric_state_dict`` and return that class object.
    """
    pseudo_metric = type("", (), metric_state_dict)
    return pseudo_metric
def reset_metrics(metrics: _TMetrics) -> _TMetrics:
    """
    Call ``reset()`` on every metric in ``metrics`` and return the same
    ``metrics`` object that was passed in.

    Args:
        metrics: The metrics to be reset

    Examples::

        >>> from torcheval.metrics import Max, Min
        >>> max = Max()
        >>> min = Min()
        >>> max.update(torch.tensor(1)).compute()
        >>> min.update(torch.tensor(2)).compute()
        >>> max, min = reset_metrics((max, min))
    """
    for each_metric in metrics:
        each_metric.reset()
    return metrics
def to_device(
    metrics: _TMetrics, device: torch.device, *args: Any, **kwargs: Any
) -> _TMetrics:
    """
    Call ``to(device, *args, **kwargs)`` on every metric in ``metrics`` and
    return the same ``metrics`` object that was passed in.

    Args:
        metrics: The metrics to be moved to the device
        device: the device to move the metrics to
        *args: Variadic arguments forwarded to ``Metric.to``
        **kwargs: Named arguments forwarded to ``Metric.to``

    Examples::

        >>> from torcheval.metrics import Max, Min
        >>> max = Max()
        >>> min = Min()
        >>> max, min = to_device((max, min), torch.device("cuda"))
        >>> max.device
        torch.device("cuda")
        >>> min.device
        torch.device("cuda")
    """
    for each_metric in metrics:
        # the return value of ``to`` is not used; the input object is returned
        each_metric.to(device, *args, **kwargs)
    return metrics
def classwise_converter(
    input: torch.Tensor, name: str, labels: Optional[List[str]] = None
) -> Dict[str, torch.Tensor]:
    """
    Split an unaveraged metric result tensor along its first dimension into
    a dictionary. Each key is ``"<name>_<label>"`` when ``labels`` is given,
    or ``"<name>_<index>"`` otherwise, and each value is the matching slice
    of ``input``.

    Args:
        input (torch.Tensor): The tensor to be split along its first dimension.
        name (str): Name of the metric.
        labels (List[str], Optional): Optional list of strings indicating the different classes.

    Raises:
        ValueError: When the length of `labels` is not equal to the number of classes.
    """
    if labels is not None:
        if input.size(dim=0) != len(labels):
            raise ValueError(
                f"Number of labels {len(labels)} must be equal to the number of classes {input.size(dim=0)}!"
            )
        keys = (f"{name}_{label}" for label in labels)
        return dict(zip(keys, input))

    # no labels supplied: fall back to positional indices as class names
    return {f"{name}_{index}": row for index, row in enumerate(input)}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources