Shortcuts

Source code for torchaudio.models.decoder._ctc_decoder

import itertools as it
from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Union

import torch
from torchaudio._torchaudio_decoder import (
    _create_word_dict,
    _CriterionType,
    _Dictionary,
    _KenLM,
    _LexiconDecoder,
    _LexiconDecoderOptions,
    _LexiconFreeDecoder,
    _LexiconFreeDecoderOptions,
    _LM,
    _load_words,
    _SmearingMode,
    _Trie,
    _ZeroLM,
)
from torchaudio.utils import download_asset

# Names exported as the public API of this module.
__all__ = [
    "CTCHypothesis",
    "CTCDecoder",
    "ctc_decoder",
    "download_pretrained_files",
]


# Record of the lexicon / tokens / language-model files of a pretrained bundle.
# Holds remote asset keys (see ``_get_filenames``) or downloaded local paths
# (see ``download_pretrained_files``); ``lm`` may be ``None``.
_PretrainedFiles = namedtuple(typename="PretrainedFiles", field_names="lexicon tokens lm")


class CTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CTC beam search decoder :py:class:`CTCDecoder`.

    :ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
        `L` is the length of the output sequence
    :ivar List[str] words: List of predicted words
    :ivar float score: Score corresponding to hypothesis
    :ivar torch.IntTensor timesteps: Timesteps corresponding to the tokens. Shape `(L, )`,
        where `L` is the length of the output sequence
    """

    # Token IDs after collapsing repeats and removing blanks.
    tokens: torch.LongTensor
    # Decoded words, in order.
    words: List[str]
    # Decoder score of this hypothesis.
    score: float
    # Frame index for each entry of ``tokens``.
    timesteps: torch.IntTensor
class CTCDecoder:
    """
    .. devices:: CPU

    CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].

    Note:
        To build the decoder, please use the factory function :py:func:`ctc_decoder`.

    Args:
        nbest (int): number of best decodings to return
        lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
        word_dict (_Dictionary): dictionary of words
        tokens_dict (_Dictionary): dictionary of tokens
        lm (_LM): language model
        decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
            parameters used for beam search decoding
        blank_token (str): token corresponding to blank
        sil_token (str): token corresponding to silence
        unk_word (str): word corresponding to unknown
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: _LM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)

        if lexicon:
            # Keep the ``unk_word`` string intact; the decoder wants its index.
            unk_word_idx = word_dict.get_index(unk_word)

            vocab_size = self.tokens_dict.index_size()
            trie = _Trie(vocab_size, silence)
            start_state = lm.start(False)

            # Insert every spelling of every lexicon word into the trie, each tagged
            # with the LM score of that word taken from the LM start state.
            for word, spellings in lexicon.items():
                word_idx = self.word_dict.get_index(word)
                _, score = lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
                    trie.insert(spelling_idx, word_idx, score)
            trie.smear(_SmearingMode.MAX)

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_word_idx,
                [],
                False,  # word level LM
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, [])

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse runs of repeated token IDs, then drop blank tokens."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Return the frame index at which each run of a non-blank token starts."""
        timesteps = []
        for i, idx in enumerate(idxs):
            if idx == self.blank:
                continue
            if i == 0 or idx != idxs[i - 1]:
                timesteps.append(i)
        return torch.IntTensor(timesteps)

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        # Overriding the signature so that the return type is correct on Sphinx
        """__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
            List[List[torchaudio.models.decoder.CTCHypothesis]]

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length
                along the time axis of each sequence in the batch. Values must not exceed `frame`.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.

        Raises:
            ValueError: if ``emissions`` is not float32 or not 3D, or ``lengths`` exceeds the frame count.
            RuntimeError: if ``emissions`` or ``lengths`` is not on the CPU.
        """
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        # The native decoder dereferences raw host pointers, so any non-CPU device
        # (CUDA, MPS, ...) must be rejected, not just CUDA.
        if emissions.device.type != "cpu":
            raise RuntimeError("emissions must be a CPU tensor.")

        if lengths is not None and lengths.device.type != "cpu":
            raise RuntimeError("lengths must be a CPU tensor.")

        if emissions.dim() != 3:
            raise ValueError(
                f"emissions must be a 3D tensor of shape (batch, frame, num_tokens). Found {tuple(emissions.shape)}"
            )

        # Each batch element is handed to the decoder as a dense (frame, num_tokens)
        # block starting at a raw pointer, so the memory must be contiguous; this is
        # a no-op for tensors that already are.
        emissions = emissions.contiguous()

        B, T, N = emissions.size()
        if lengths is None:
            lengths = torch.full((B,), T)
        elif bool((lengths > T).any()):
            # A length past T would make the decoder read beyond this element's frames.
            raise ValueError(f"lengths must not exceed the number of frames ({T}).")

        float_bytes = emissions.element_size()
        hypos = []

        for b in range(B):
            emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, lengths[b], N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    CTCHypothesis(
                        tokens=self._get_tokens(result.tokens),
                        # Negative word indices are skipped.
                        words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                        score=result.score,
                        timesteps=self._get_timesteps(result.tokens),
                    )
                    for result in nbest_results
                ]
            )

        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """
    Builds CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].

    Args:
        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
            decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str or None, optional): file containing language model, or `None` if not using a language model
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        CTCDecoder: decoder

    Example
        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
    """
    tokens_dict = _Dictionary(tokens)
    use_lexicon = lexicon is not None

    # Unless the caller capped it, every token is considered at each decode step.
    token_beam = beam_size_token or tokens_dict.index_size()

    if use_lexicon:
        word_spellings = _load_words(lexicon)
    else:
        # Lexicon-free: each token acts as a word spelled by itself, plus ``unk_word``.
        word_spellings = {}
        for position in range(tokens_dict.index_size()):
            entry = tokens_dict.get_entry(position)
            word_spellings[entry] = [[entry]]
        word_spellings[unk_word] = [[unk_word]]

    word_dict = _create_word_dict(word_spellings)
    language_model = _KenLM(lm, word_dict) if lm else _ZeroLM()

    if use_lexicon:
        options = _LexiconDecoderOptions(
            beam_size=beam_size,
            beam_size_token=token_beam,
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            word_score=word_score,
            unk_score=unk_score,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )
    else:
        options = _LexiconFreeDecoderOptions(
            beam_size=beam_size,
            beam_size_token=token_beam,
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )

    return CTCDecoder(
        nbest=nbest,
        lexicon=word_spellings if use_lexicon else None,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=language_model,
        decoder_options=options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
def _get_filenames(model: str) -> _PretrainedFiles:
    """Map a pretrained bundle name to the asset keys of its lexicon, tokens and optional LM.

    Raises:
        ValueError: if ``model`` is not one of the supported bundle names.
    """
    supported = ("librispeech", "librispeech-3-gram", "librispeech-4-gram")
    if model not in supported:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )

    prefix = f"decoder-assets/{model}"
    # The plain "librispeech" bundle has no language-model file.
    lm_key = None if model == "librispeech" else f"{prefix}/lm.bin"
    return _PretrainedFiles(
        lexicon=f"{prefix}/lexicon.txt",
        tokens=f"{prefix}/tokens.txt",
        lm=lm_key,
    )
def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for CTC decoder.

    Args:
        model (str): pretrained language model to download.
            Options: ["librispeech-3-gram", "librispeech-4-gram", "librispeech"]

    Returns:
        Object with the following attributes
            lm:
                path corresponding to downloaded language model, or `None` if the model is not associated
                with an lm
            lexicon:
                path corresponding to downloaded lexicon file
            tokens:
                path corresponding to downloaded tokens file
    """
    asset_keys = _get_filenames(model)

    # Download order: lexicon, tokens, then the LM only when the bundle has one.
    local_lexicon = download_asset(asset_keys.lexicon)
    local_tokens = download_asset(asset_keys.tokens)
    local_lm = None if asset_keys.lm is None else download_asset(asset_keys.lm)

    return _PretrainedFiles(lexicon=local_lexicon, tokens=local_tokens, lm=local_lm)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources