# Source code for torchtext.data.metrics
import math
import collections
import torch
from torchtext.data.utils import ngrams_iterator


def _compute_ngram_counter(tokens, max_n):
""" Create a Counter with a count of unique n-grams in the tokens list
Arguments:
tokens: a list of tokens (typically a string split on whitespaces)
max_n: the maximum order of n-gram wanted
Outputs:
output: a collections.Counter object with the unique n-grams and their
associated count
Examples:
>>> from torchtext.data.metrics import _compute_ngram_counter
>>> tokens = ['me', 'me', 'you']
>>> _compute_ngram_counter(tokens, 2)
Counter({('me',): 2,
('you',): 1,
('me', 'me'): 1,
('me', 'you'): 1,
('me', 'me', 'you'): 1})
"""
    assert max_n > 0
    # ngrams_iterator yields each n-gram as a space-joined string; split it
    # back into a tuple of tokens so it can serve as a hashable Counter key.
    ngrams_counter = collections.Counter(tuple(x.split(' '))
                                         for x in ngrams_iterator(tokens, max_n))
    return ngrams_counter
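

# A minimal sketch (not part of the module) of what ngrams_iterator appears to
# produce for the docstring example above -- individual tokens first, then
# space-joined n-grams, which is why _compute_ngram_counter splits each item
# on ' '. The exact yield order is an assumption about ngrams_iterator:
#
#     >>> list(ngrams_iterator(['me', 'me', 'you'], 2))
#     ['me', 'me', 'you', 'me me', 'me you']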


def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4):
"""Computes the BLEU score between a candidate translation corpus and a references
translation corpus. Based on https://www.aclweb.org/anthology/P02-1040.pdf
Arguments:
candidate_corpus: an iterable of candidate translations. Each translation is an
iterable of tokens
references_corpus: an iterable of iterables of reference translations. Each
translation is an iterable of tokens
max_n: the maximum n-gram we want to use. E.g. if max_n=3, we will use unigrams,
bigrams and trigrams
weights: a list of weights used for each n-gram category (uniform by default)
Examples:
>>> from torchtext.data.metrics import bleu_score
>>> candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
>>> references_corpus = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']], [['No', 'Match']]]
>>> bleu_score(candidate_corpus, references_corpus)
0.8408964276313782
"""
    assert max_n == len(weights), 'Length of the "weights" list has to be equal to max_n'
    assert len(candidate_corpus) == len(references_corpus), \
        'The length of candidate and reference corpus should be the same'
    clipped_counts = torch.zeros(max_n)
    total_counts = torch.zeros(max_n)
    weights = torch.tensor(weights)

    candidate_len = 0.0
    refs_len = 0.0
    for (candidate, refs) in zip(candidate_corpus, references_corpus):
        candidate_len += len(candidate)

        # Get the length of the reference that's closest in length to the candidate
        refs_len_list = [float(len(ref)) for ref in refs]
        refs_len += min(refs_len_list, key=lambda x: abs(len(candidate) - x))

        # Counter union keeps the maximum count per n-gram, so a candidate
        # n-gram is credited at most the highest count seen in any single
        # reference.
        reference_counters = _compute_ngram_counter(refs[0], max_n)
        for ref in refs[1:]:
            reference_counters = reference_counters | _compute_ngram_counter(ref, max_n)

        candidate_counter = _compute_ngram_counter(candidate, max_n)

        # Counter intersection keeps min(candidate count, reference count):
        # the clipped counts that form the numerator of each p_n.
        clipped_counter = candidate_counter & reference_counters

        for ngram in clipped_counter:
            clipped_counts[len(ngram) - 1] += clipped_counter[ngram]

        for ngram in candidate_counter:  # TODO: no need to loop through the whole counter
            total_counts[len(ngram) - 1] += candidate_counter[ngram]
    if min(clipped_counts) == 0:
        return 0.0
    else:
        pn = clipped_counts / total_counts
        log_pn = weights * torch.log(pn)
        score = torch.exp(sum(log_pn))

        # Brevity penalty: exp(1 - r/c) when the candidate is shorter than
        # the closest-length reference, 1 otherwise.
        bp = math.exp(min(1 - refs_len / candidate_len, 0))

        return bp * score.item()
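

# A minimal usage sketch (not part of the module): it reruns the docstring
# example and checks that a candidate identical to its reference scores 1.0.
# The corpora below are made-up illustrations, not torchtext fixtures.
if __name__ == '__main__':
    candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
    references_corpus = [[['My', 'full', 'pytorch', 'test'],
                          ['Completely', 'Different']], [['No', 'Match']]]
    print(bleu_score(candidate_corpus, references_corpus))  # ~0.8409

    # Exact match: every clipped precision is 1 and the brevity penalty is 1.
    perfect = [['A', 'perfectly', 'matching', 'sentence', 'here']]
    print(bleu_score(perfect, [[['A', 'perfectly', 'matching', 'sentence', 'here']]]))  # 1.0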