Source code for torcheval.metrics.functional.text.bleu
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import Counter as counter
from typing import Counter, Optional, Sequence, Tuple, Union
import torch
[docs]def bleu_score(
input: Union[str, Sequence[str]],
target: Sequence[Union[str, Sequence[str]]],
n_gram: int = 4,
weights: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Compute BLEU score given translations and references for each translation.
Its class version is ``torcheval.metrics.texBLEUScore``.
Args:
input: Translations to score.
target: List of references for each translation. Requires len(input) = len(target)
n_gram: Maximum n-gram to use when computing BLEU score. Can be 1, 2, 3, or 4.
weights: Optional weight distribution of n-grams. Requires len(weights) = n_gram. If unspecified,
will use uniform weights.
Examples:
>>> import torch
>>> from torcheval.metrics.functional.text import bleu
>>> candidates = ["the squirrel is eating the nut"]
>>> references = [["a squirrel is eating a nut", "the squirrel is eating a tasty nut"]]
>>> bleu_score(candidates, references, n_gram=4)
tensor(0.53728497)
>>> candidates = ["the squirrel is eating the nut", "the cat is on the mat"]
>>> references = [["a squirrel is eating a nut", "the squirrel is eating a tasty nut"], ["there is a cat on the mat", "a cat is on the mat"]]
>>> bleu_score(candidates, references, n_gram=4)
tensor(0.65341892)
"""
(
input_len,
target_len,
matches_by_order,
possible_matches_by_order,
) = _bleu_score_update(
input,
target,
n_gram,
device,
)
return _bleu_score_compute(
input_len,
target_len,
matches_by_order,
possible_matches_by_order,
n_gram,
weights,
)
def _bleu_score_update(
input: Union[str, Sequence[str]],
target: Sequence[Union[str, Sequence[str]]],
n_gram: int,
device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
input_ = [input] if isinstance(input, str) else input
target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target]
if len(input_) != len(target_):
raise ValueError(
f"Input and target corpus should have same sizes, but input corpus size = {len(input_)}, target corpus size = {len(target_)} "
)
input_len = torch.tensor(0, device=device)
target_len = torch.tensor(0, device=device)
matches_by_order = torch.zeros(n_gram, device=device)
possible_matches_by_order = torch.zeros(n_gram, device=device)
for (candidate, references) in zip(input_, target_):
candidate_tokenized = candidate.split()
references_tokenized = [ref.split() for ref in references]
len_candidate = len(candidate_tokenized)
len_reference = min([len(ref) for ref in references_tokenized])
input_len += len_candidate
target_len += len_reference
candidate_ngram_counter = _get_ngrams(candidate_tokenized, n_gram)
reference_ngram_counter = counter()
for ref in references_tokenized:
reference_ngram_counter |= _get_ngrams(ref, n_gram)
overlap = candidate_ngram_counter & reference_ngram_counter
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for i in range(n_gram):
if len_candidate - i > 0:
possible_matches_by_order[i] += len_candidate - i
if torch.min(possible_matches_by_order) == 0:
raise ValueError(
f"the input is too short to find all n-gram matches with n_gram={n_gram}"
)
return input_len, target_len, matches_by_order, possible_matches_by_order
def _bleu_score_compute(
input_len: torch.Tensor,
target_len: torch.Tensor,
matches_by_order: torch.Tensor,
possible_matches_by_order: torch.Tensor,
n_gram: int,
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if weights is not None and n_gram != weights.size(dim=0):
raise ValueError(
f"the length of weights should equal n_gram, got len(weights)={weights.size(dim=0)}, n_gram={n_gram}"
)
if weights is None:
weights = torch.tensor([1 / n_gram] * n_gram)
precisions = matches_by_order / possible_matches_by_order
geometric_mean = torch.exp(torch.sum(weights * torch.log(precisions)))
brevity_penalty = _calc_brevity_penalty(input_len, target_len)
return brevity_penalty * geometric_mean
def _calc_brevity_penalty(
input_len: torch.Tensor, target_len: torch.Tensor
) -> torch.Tensor:
if input_len > target_len:
return torch.tensor(1.0, device=input_len.device)
else:
return torch.exp(1 - target_len / input_len)
def _get_ngrams(sentence: Sequence[str], n_gram: int) -> Counter[str]:
"""
Args:
sentence: text from which we get n-grams
n_gram: length of n-gram
"""
if n_gram not in [1, 2, 3, 4]:
raise ValueError(f"n_gram should be 1, 2, 3, or 4, got {n_gram}.")
ngram_counts = counter()
for n_val in range(1, n_gram + 1):
for i in range(0, len(sentence) - n_val + 1):
ngram = tuple(sentence[i : i + n_val])
ngram_counts[ngram] += 1
return ngram_counts