
Source code for torcheval.metrics.functional.ranking.num_collisions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

@torch.inference_mode()
def num_collisions(input: torch.Tensor) -> torch.Tensor:
    """
    Compute the number of collisions given a list of input(ids).

    For every element ``input[i]`` the result holds how many *other*
    elements of ``input`` carry the same id, i.e.
    ``(input == input[i]).sum() - 1``.

    Args:
        input (Tensor): a one-dimensional integer tensor of input ids with
            shape (num_samples, ).

    Returns:
        Tensor: an int64 tensor of shape (num_samples, ) holding the
        collision count of each element.

    Raises:
        ValueError: if ``input`` is not one-dimensional or does not have an
            accepted integer dtype.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import num_collisions
        >>> input = torch.tensor([3, 4, 2, 3])
        >>> num_collisions(input)
        tensor([1, 0, 0, 1])
        >>> input = torch.tensor([3, 4, 1, 3, 1, 1, 5])
        >>> num_collisions(input)
        tensor([1, 0, 2, 1, 2, 2, 0])
    """
    _num_collisions_input_check(input)
    # Count occurrences per distinct id instead of building the full
    # (num_samples, num_samples) equality matrix, which needs O(n^2) memory.
    # ``inverse`` maps every element back to the slot of its id in
    # ``counts``; subtracting 1 excludes the element itself.
    _, inverse, counts = torch.unique(
        input, return_inverse=True, return_counts=True
    )
    return counts[inverse] - 1
def _num_collisions_input_check(input: torch.Tensor) -> None:
    """
    Validate the ``input`` argument of ``num_collisions``.

    Args:
        input (Tensor): tensor of ids to validate.

    Raises:
        ValueError: if ``input`` is not one-dimensional, or if its dtype is
            not one of ``torch.int8``, ``torch.int16``, ``torch.int32`` or
            ``torch.int64``.
    """
    if input.ndim != 1:
        raise ValueError(
            f"input should be a one-dimensional tensor, got shape {input.shape}."
        )

    # NOTE(review): unsigned integer dtypes (e.g. torch.uint8) are rejected
    # here even though the error message says "integer tensor" — confirm
    # this is intended.
    if input.dtype not in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ):
        raise ValueError(f"input should be an integer tensor, got {input.dtype}.")


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources