# Source code for torchaudio.models.decoder._ctc_decoder
import itertools as it
from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Union

import torch
from torchaudio._torchaudio_decoder import (
    _create_word_dict,
    _CriterionType,
    _Dictionary,
    _KenLM,
    _LexiconDecoder,
    _LexiconDecoderOptions,
    _LexiconFreeDecoder,
    _LexiconFreeDecoderOptions,
    _LM,
    _load_words,
    _SmearingMode,
    _Trie,
    _ZeroLM,
)
from torchaudio.utils import download_asset

__all__ = ["CTCHypothesis", "CTCDecoder", "ctc_decoder", "download_pretrained_files"]

_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])


class CTCHypothesis(NamedTuple):
    r"""Represents a hypothesis generated by the CTC beam search decoder :py:class:`CTCDecoder`.

    :ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
        `L` is the length of the output sequence
    :ivar List[str] words: List of predicted words
    :ivar float score: Score corresponding to hypothesis
    :ivar torch.IntTensor timesteps: Timesteps corresponding to the tokens. Shape `(L, )`,
        where `L` is the length of the output sequence
    """

    tokens: torch.LongTensor
    words: List[str]
    score: float
    timesteps: torch.IntTensor
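

# Illustrative sketch, not part of the original module: format a hypothesis as a
# transcript with an approximate start time. The 0.02 s frame stride is an
# assumption; the true value depends on the acoustic model's frame rate.
def _example_format_hypothesis(hypo: CTCHypothesis, frame_stride_sec: float = 0.02) -> str:
    transcript = " ".join(hypo.words)
    if len(hypo.timesteps) == 0:
        return transcript
    start_sec = hypo.timesteps[0].item() * frame_stride_sec
    return f"{transcript} (starts at ~{start_sec:.2f}s)"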


class CTCDecoder:
    """
    .. devices:: CPU

    CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].

    Note:
        To build the decoder, please use the factory function :py:func:`ctc_decoder`.

    Args:
        nbest (int): number of best decodings to return
        lexicon (Dict or None): lexicon mapping of words to spellings, or ``None`` for lexicon-free decoding
        word_dict (_Dictionary): dictionary of words
        tokens_dict (_Dictionary): dictionary of tokens
        lm (_LM): language model
        decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): parameters used for beam search decoding
        blank_token (str): token corresponding to blank
        sil_token (str): token corresponding to silence
        unk_word (str): word corresponding to unknown
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: _LM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)

        if lexicon:
            # Build a trie over the lexicon so the decoder can constrain token
            # sequences to valid word spellings, each scored with the LM.
            unk_word = word_dict.get_index(unk_word)
            vocab_size = self.tokens_dict.index_size()
            trie = _Trie(vocab_size, silence)
            start_state = lm.start(False)

            for word, spellings in lexicon.items():
                word_idx = self.word_dict.get_index(word)
                _, score = lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
                    trie.insert(spelling_idx, word_idx, score)
            trie.smear(_SmearingMode.MAX)

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_word,
                [],
                False,  # word level LM
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, [])

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Normalizes a raw token sequence: collapses repeated tokens, then removes blanks."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Returns frame numbers corresponding to non-blank tokens."""
        timesteps = []
        for i, idx in enumerate(idxs):
            if idx == self.blank:
                continue
            if i == 0 or idx != idxs[i - 1]:
                timesteps.append(i)
        return torch.IntTensor(timesteps)

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        # Overriding the signature so that the return type is correct on Sphinx
        """__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
List[List[torchaudio.models.decoder.CTCHypothesis]]

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distributions over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length in
                the time axis of the output Tensor in each batch.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.
        """
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")
        if emissions.is_cuda:
            raise RuntimeError("emissions must be a CPU tensor.")
        if lengths is not None and lengths.is_cuda:
            raise RuntimeError("lengths must be a CPU tensor.")

        B, T, N = emissions.size()
        if lengths is None:
            lengths = torch.full((B,), T)
        float_bytes = 4
        hypos = []

        for b in range(B):
            # Pass a raw pointer to the b-th emission matrix into the Flashlight decoder.
            emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, lengths[b], N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    CTCHypothesis(
                        tokens=self._get_tokens(result.tokens),
                        words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                        score=result.score,
                        timesteps=self._get_timesteps(result.tokens),
                    )
                    for result in nbest_results
                ]
            )
        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
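

# Illustrative sketch, not part of the original module: decode a batch and read
# out the best transcript. Assumes `decoder` was built with ctc_decoder() and
# `emissions` is a CPU float32 tensor of shape (batch, frame, num_tokens).
def _example_best_transcript(decoder: CTCDecoder, emissions: torch.FloatTensor) -> str:
    best = decoder(emissions)[0][0]  # top hypothesis for the first utterance
    tokens = decoder.idxs_to_tokens(best.tokens)  # e.g. ["h", "e", "l", "l", "o", "|"]
    return " ".join(best.words) if best.words else "".join(tokens)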


def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """
    Builds CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].

    Args:
        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
            decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str or None, optional): file containing language model, or `None` if not using a language model
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypotheses to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypotheses (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        CTCDecoder: decoder

    Example
        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses
    """
    tokens_dict = _Dictionary(tokens)

    if lexicon is not None:
        lexicon = _load_words(lexicon)
        word_dict = _create_word_dict(lexicon)
        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
        decoder_options = _LexiconDecoderOptions(
            beam_size=beam_size,
            beam_size_token=beam_size_token or tokens_dict.index_size(),
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            word_score=word_score,
            unk_score=unk_score,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )
    else:
        # Lexicon-free decoding: treat each token as its own single-token "word".
        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
        d[unk_word] = [[unk_word]]
        word_dict = _create_word_dict(d)
        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
        decoder_options = _LexiconFreeDecoderOptions(
            beam_size=beam_size,
            beam_size_token=beam_size_token or tokens_dict.index_size(),
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )

    return CTCDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=lm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
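

# Illustrative sketch, not part of the original module: build a lexicon-free
# decoder by passing lexicon=None, so each token acts as its own "word". The
# token list below is a made-up assumption for demonstration only.
def _example_lexicon_free_decoder() -> CTCDecoder:
    tokens = ["-", "|", "e", "t", "a", "o", "n", "i", "h", "s"]
    return ctc_decoder(lexicon=None, tokens=tokens, beam_size=25)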


def _get_filenames(model: str) -> _PretrainedFiles:
    if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )
    prefix = f"decoder-assets/{model}"
    return _PretrainedFiles(
        lexicon=f"{prefix}/lexicon.txt",
        tokens=f"{prefix}/tokens.txt",
        lm=f"{prefix}/lm.bin" if model != "librispeech" else None,
    )


def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for the CTC decoder.

    Args:
        model (str): pretrained language model to download.
            Options: ["librispeech-3-gram", "librispeech-4-gram", "librispeech"]

    Returns:
        Object with the following attributes:

            lm:
                path corresponding to downloaded language model, or `None` if the model is not associated with an LM
            lexicon:
                path corresponding to downloaded lexicon file
            tokens:
                path corresponding to downloaded tokens file
    """
    files = _get_filenames(model)
    lexicon_file = download_asset(files.lexicon)
    tokens_file = download_asset(files.tokens)
    if files.lm is not None:
        lm_file = download_asset(files.lm)
    else:
        lm_file = None

    return _PretrainedFiles(
        lexicon=lexicon_file,
        tokens=tokens_file,
        lm=lm_file,
    )
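

# Illustrative end-to-end sketch, not part of the original module: download the
# LibriSpeech 4-gram assets and build a lexicon decoder from them. Requires
# network access; the lm_weight value is a tuning assumption, not a prescribed default.
def _example_pretrained_decoder() -> CTCDecoder:
    files = download_pretrained_files("librispeech-4-gram")
    return ctc_decoder(lexicon=files.lexicon, tokens=files.tokens, lm=files.lm, lm_weight=3.23)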