Shortcuts

Source code for torcheval.metrics.functional.text.word_information_lost

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple, Union

import torch

from torcheval.metrics.functional.text.helper import _get_errors_and_totals


def _wil_update(
    input: Union[str, List[str]],
    target: Union[str, List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Update the wil score with the current set of references and predictions.
    Args:
        input: Transcription(s) to score as a string or list of strings
        target: Reference(s) for each speech input as a string or list of strings
    Returns:
        Number of correct words
        Number of words overall references
        Number of words overall predictions
    """
    if isinstance(input, str):
        input = [input]
    if isinstance(target, str):
        target = [target]
    assert len(input) == len(
        target
        # pyre-fixme[6]: For 1st argument expected `Sized` but got `(object=...) -> str`.
    ), f"Arguments must contain the same number of strings, but got len(input)={len(input)} and len(target)={len(target)}"
    errors, max_total, target_total, input_total = _get_errors_and_totals(input, target)
    return errors - max_total, target_total, input_total


def _wil_compute(
    correct_total: torch.Tensor, target_total: torch.Tensor, preds_total: torch.Tensor
) -> torch.Tensor:
    """Compute the Word Information Lost.
    Args:
        correct_total: Number of correct words
        target_total: Number of words overall references
        preds_total: Number of words overall prediction
    Returns:
        Word Information Lost score
    """
    return 1 - ((correct_total / target_total) * (correct_total / preds_total))


[docs]@torch.inference_mode() def word_information_lost( input: Union[str, List[str]], target: Union[str, List[str]], ) -> torch.Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score. Its class version is ``torcheval.metrics.WordInformationLost``. Args: input: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings Returns: Word Information Lost rate Examples: >>> from torcheval.metrics.functional import word_information_lost >>> input = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> word_information_lost(input, target) tensor(0.6528) """ correct_total, target_total, preds_total = _wil_update(input, target) return _wil_compute(correct_total, target_total, preds_total)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources