Source code for torchaudio.models.decoder._ctc_decoder
from __future__ import annotations
import itertools as it
from abc import abstractmethod
from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import torch
from flashlight.lib.text.decoder import (
CriterionType as _CriterionType,
LexiconDecoder as _LexiconDecoder,
LexiconDecoderOptions as _LexiconDecoderOptions,
LexiconFreeDecoder as _LexiconFreeDecoder,
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
LM as _LM,
LMState as _LMState,
SmearingMode as _SmearingMode,
Trie as _Trie,
ZeroLM as _ZeroLM,
)
from flashlight.lib.text.dictionary import (
create_word_dict as _create_word_dict,
Dictionary as _Dictionary,
load_words as _load_words,
)
from torchaudio.utils import download_asset
try:
from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
except Exception:
try:
from flashlight.lib.text.decoder import KenLM as _KenLM
except Exception:
_KenLM = None
__all__ = [
"CTCHypothesis",
"CTCDecoder",
"CTCDecoderLM",
"CTCDecoderLMState",
"ctc_decoder",
"download_pretrained_files",
]
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
vocab_size = tokens_dict.index_size()
trie = _Trie(vocab_size, silence)
start_state = lm.start(False)
for word, spellings in lexicon.items():
word_idx = word_dict.get_index(word)
_, score = lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idx = [tokens_dict.get_index(token) for token in spelling]
trie.insert(spelling_idx, word_idx, score)
trie.smear(_SmearingMode.MAX)
return trie
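# Illustrative sketch of the lexicon structure that `_construct_trie` expects: a
# mapping from each word to one or more token spellings. The words and spellings
# below are hypothetical, not taken from any shipped asset:
#
#     lexicon = {
#         "hello": [["h", "e", "l", "l", "o", "|"]],
#         "world": [["w", "o", "r", "l", "d", "|"]],
#     }
#
# Each spelling is inserted into the trie with the word's LM score, and
# `smear(MAX)` propagates the best score toward the root so that partial
# spellings can be pruned against it during beam search.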
def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
word_dict = None
if lm_dict is not None:
word_dict = _Dictionary(lm_dict)
if lexicon and word_dict is None:
word_dict = _create_word_dict(lexicon)
    elif not lexicon and word_dict is None and isinstance(lm, str):
d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
d[unk_word] = [[unk_word]]
word_dict = _create_word_dict(d)
return word_dict
class CTCHypothesis(NamedTuple):
r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`."""
tokens: torch.LongTensor
"""Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
words: List[str]
"""List of predicted words.
Note:
This attribute is only applicable if a lexicon is provided to the decoder. If
decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and
:func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
"""
score: float
"""Score corresponding to hypothesis"""
timesteps: torch.IntTensor
"""Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence"""
class CTCDecoderLMState(_LMState):
"""Language model state."""
@property
def children(self) -> Dict[int, CTCDecoderLMState]:
"""Map of indices to LM states"""
return super().children
    def child(self, usr_index: int) -> CTCDecoderLMState:
"""Returns child corresponding to usr_index, or creates and returns a new state if input index
is not found.
Args:
usr_index (int): index corresponding to child state
Returns:
CTCDecoderLMState: child state corresponding to usr_index
"""
return super().child(usr_index)
    def compare(self, state: CTCDecoderLMState) -> int:
"""Compare two language model states.
Args:
state (CTCDecoderLMState): LM state to compare against
Returns:
int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
"""
        return super().compare(state)
class CTCDecoderLM(_LM):
"""Language model base class for creating custom language models to use with the decoder."""
    @abstractmethod
def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
"""Initialize or reset the language model.
Args:
            start_with_nothing (bool): whether or not to start the sentence with the sil token.
Returns:
CTCDecoderLMState: starting state
"""
raise NotImplementedError
    @abstractmethod
def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
"""Evaluate the language model based on the current LM state and new word.
Args:
state (CTCDecoderLMState): current LM state
usr_token_idx (int): index of the word
Returns:
(CTCDecoderLMState, float)
CTCDecoderLMState:
new LM state
float:
score
"""
raise NotImplementedError
    @abstractmethod
def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
"""Evaluate end for language model based on current LM state.
Args:
state (CTCDecoderLMState): current LM state
Returns:
(CTCDecoderLMState, float)
CTCDecoderLMState:
new LM state
float:
score
"""
raise NotImplementedError
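# Minimal sketch of a custom language model (a hypothetical uniform LM that
# scores every token 0.0, mirroring ZeroLM). A real subclass would derive the
# scores from an actual language model instead:
#
#     class UniformLM(CTCDecoderLM):
#         def __init__(self):
#             CTCDecoderLM.__init__(self)
#
#         def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
#             return CTCDecoderLMState()
#
#         def score(self, state: CTCDecoderLMState, usr_token_idx: int):
#             return state.child(usr_token_idx), 0.0
#
#         def finish(self, state: CTCDecoderLMState):
#             return state, 0.0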
class CTCDecoder:
"""CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.
.. devices:: CPU
Note:
To build the decoder, please use the factory function :func:`ctc_decoder`.
"""
def __init__(
self,
nbest: int,
lexicon: Optional[Dict],
word_dict: _Dictionary,
tokens_dict: _Dictionary,
lm: CTCDecoderLM,
decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
blank_token: str,
sil_token: str,
unk_word: str,
) -> None:
"""
Args:
nbest (int): number of best decodings to return
lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
word_dict (_Dictionary): dictionary of words
tokens_dict (_Dictionary): dictionary of tokens
lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
parameters used for beam search decoding
            blank_token (str): token corresponding to blank
sil_token (str): token corresponding to silence
unk_word (str): word corresponding to unknown
"""
self.nbest = nbest
self.word_dict = word_dict
self.tokens_dict = tokens_dict
self.blank = self.tokens_dict.get_index(blank_token)
silence = self.tokens_dict.get_index(sil_token)
transitions = []
if lexicon:
trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
unk_word = word_dict.get_index(unk_word)
token_lm = False # use word level LM
self.decoder = _LexiconDecoder(
decoder_options,
trie,
lm,
silence,
self.blank,
unk_word,
transitions,
token_lm,
)
else:
self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
        # https://github.com/pytorch/audio/issues/3218
        # If `lm` is passed as a temporary (rvalue-like) reference, the LM object
        # gets garbage collected, and later calls into it fail.
        # Holding a reference here ensures that the LM object is not deleted as
        # long as the decoder is alive.
        # https://github.com/pybind/pybind11/discussions/4013
self.lm = lm
def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
"""Returns frame numbers corresponding to non-blank tokens."""
timesteps = []
for i, idx in enumerate(idxs):
if idx == self.blank:
continue
if i == 0 or idx != idxs[i - 1]:
timesteps.append(i)
return torch.IntTensor(timesteps)
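    # Worked example of the CTC collapse performed by `_get_tokens` and
    # `_get_timesteps`, assuming the blank index is 0:
    #
    #     frame index:  0  1  2  3  4  5
    #     raw tokens:   0  3  3  0  5  5   ->  tokens    = [3, 5]
    #                                          timesteps = [1, 4]
    #
    # Repeated tokens are merged, blanks are dropped, and each surviving token
    # keeps the index of the frame where its run started.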
    def decode_begin(self):
"""Initialize the internal state of the decoder.
See :py:meth:`decode_step` for the usage.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
"""
self.decoder.decode_begin()
    def decode_end(self):
"""Finalize the internal state of the decoder.
See :py:meth:`decode_step` for the usage.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
"""
self.decoder.decode_end()
    def decode_step(self, emissions: torch.FloatTensor):
        """Perform incremental decoding on top of the current internal state.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
Args:
emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model.
Example:
>>> decoder = torchaudio.models.decoder.ctc_decoder(...)
>>> decoder.decode_begin()
>>> decoder.decode_step(emission1)
>>> decoder.decode_step(emission2)
>>> decoder.decode_end()
>>> result = decoder.get_final_hypothesis()
"""
if emissions.dtype != torch.float32:
raise ValueError("emissions must be float32.")
if not emissions.is_cpu:
raise RuntimeError("emissions must be a CPU tensor.")
if not emissions.is_contiguous():
raise RuntimeError("emissions must be contiguous.")
if emissions.ndim != 2:
raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")
T, N = emissions.size()
self.decoder.decode_step(emissions.data_ptr(), T, N)
def _to_hypo(self, results) -> List[CTCHypothesis]:
return [
CTCHypothesis(
tokens=self._get_tokens(result.tokens),
words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
score=result.score,
timesteps=self._get_timesteps(result.tokens),
)
for result in results
]
    def get_final_hypothesis(self) -> List[CTCHypothesis]:
"""Get the final hypothesis
Returns:
List[CTCHypothesis]:
List of sorted best hypotheses.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
"""
results = self.decoder.get_all_final_hypothesis()
return self._to_hypo(results[: self.nbest])
    def __call__(
self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
) -> List[List[CTCHypothesis]]:
"""
Performs batched offline decoding.
.. note::
This method performs offline decoding in one go. To perform incremental decoding,
please refer to :py:meth:`decode_step`.
Args:
emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length,
                along the time axis, of each emission sequence in the batch.
Returns:
List[List[CTCHypothesis]]:
List of sorted best hypotheses for each audio sequence in the batch.
"""
if emissions.dtype != torch.float32:
raise ValueError("emissions must be float32.")
if not emissions.is_cpu:
raise RuntimeError("emissions must be a CPU tensor.")
if not emissions.is_contiguous():
raise RuntimeError("emissions must be contiguous.")
if emissions.ndim != 3:
raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")
if lengths is not None and not lengths.is_cpu:
raise RuntimeError("lengths must be a CPU tensor.")
B, T, N = emissions.size()
if lengths is None:
lengths = torch.full((B,), T)
float_bytes = 4
hypos = []
for b in range(B):
emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, lengths[b], N)
hypos.append(self._to_hypo(results[: self.nbest]))
return hypos
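    # Illustrative sketch of batched decoding with padded emissions (the
    # acoustic model and its outputs are assumed, not defined in this module):
    #
    #     emissions, lengths = acoustic_model(waveforms)  # (B, T, num_tokens), (B,)
    #     hypos = decoder(emissions.cpu().float().contiguous(), lengths.cpu())
    #     for sample_hypos in hypos:
    #         print(" ".join(sample_hypos[0].words))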
    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
"""
Map raw token IDs into corresponding tokens
Args:
idxs (LongTensor): raw token IDs generated from decoder
Returns:
List: tokens corresponding to the input IDs
"""
return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
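# Illustrative sketch of recovering a transcript without a lexicon, where
# `hypothesis.words` is empty and tokens must be joined manually (`decoder` and
# `hypothesis` are assumed to exist; "|" is the default silence token):
#
#     tokens = decoder.idxs_to_tokens(hypothesis.tokens)  # e.g. ["h", "i", "|"]
#     transcript = "".join(tokens).replace("|", " ").strip()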
def ctc_decoder(
lexicon: Optional[str],
tokens: Union[str, List[str]],
    lm: Optional[Union[str, CTCDecoderLM]] = None,
lm_dict: Optional[str] = None,
nbest: int = 1,
beam_size: int = 50,
beam_size_token: Optional[int] = None,
beam_threshold: float = 50,
lm_weight: float = 2,
word_score: float = 0,
unk_score: float = float("-inf"),
sil_score: float = 0,
log_add: bool = False,
blank_token: str = "-",
sil_token: str = "|",
unk_word: str = "<unk>",
) -> CTCDecoder:
"""Builds an instance of :class:`CTCDecoder`.
Args:
lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space-separated spelling. If `None`, uses lexicon-free
decoding.
tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line
        lm (str, CTCDecoderLM, or None, optional): either a path to a KenLM language model, a
            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
(Default: None)
nbest (int, optional): number of best decodings to return (Default: 1)
beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
beam_size_token (int, optional): max number of tokens to consider at each decode step.
If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypotheses (Default: 50)
lm_weight (float, optional): weight of language model (Default: 2)
word_score (float, optional): word insertion score (Default: 0)
unk_score (float, optional): unknown word insertion score (Default: -inf)
sil_score (float, optional): silence insertion score (Default: 0)
log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
blank_token (str, optional): token corresponding to blank (Default: "-")
sil_token (str, optional): token corresponding to silence (Default: "|")
unk_word (str, optional): word corresponding to unknown (Default: "<unk>")
Returns:
CTCDecoder: decoder
Example
>>> decoder = ctc_decoder(
>>> lexicon="lexicon.txt",
>>> tokens="tokens.txt",
>>> lm="kenlm.bin",
>>> )
>>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
"""
    if lm_dict is not None and not isinstance(lm_dict, str):
raise ValueError("lm_dict must be None or str type.")
tokens_dict = _Dictionary(tokens)
# decoder options
if lexicon:
lexicon = _load_words(lexicon)
decoder_options = _LexiconDecoderOptions(
beam_size=beam_size,
beam_size_token=beam_size_token or tokens_dict.index_size(),
beam_threshold=beam_threshold,
lm_weight=lm_weight,
word_score=word_score,
unk_score=unk_score,
sil_score=sil_score,
log_add=log_add,
criterion_type=_CriterionType.CTC,
)
else:
decoder_options = _LexiconFreeDecoderOptions(
beam_size=beam_size,
beam_size_token=beam_size_token or tokens_dict.index_size(),
beam_threshold=beam_threshold,
lm_weight=lm_weight,
sil_score=sil_score,
log_add=log_add,
criterion_type=_CriterionType.CTC,
)
# construct word dict and language model
word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)
    if isinstance(lm, str):
if _KenLM is None:
raise RuntimeError(
"flashlight-text is installed, but KenLM is not installed. "
"Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
)
lm = _KenLM(lm, word_dict)
elif lm is None:
lm = _ZeroLM()
return CTCDecoder(
nbest=nbest,
lexicon=lexicon,
word_dict=word_dict,
tokens_dict=tokens_dict,
lm=lm,
decoder_options=decoder_options,
blank_token=blank_token,
sil_token=sil_token,
unk_word=unk_word,
)
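# Illustrative sketch of a lexicon-free configuration (the file name is
# hypothetical). Without a lexicon, each token acts as its own word, so any
# language model scores token sequences rather than word sequences:
#
#     decoder = ctc_decoder(
#         lexicon=None,
#         tokens="tokens.txt",
#         lm=None,
#         nbest=3,
#         beam_size=100,
#     )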
def _get_filenames(model: str) -> _PretrainedFiles:
if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
raise ValueError(
f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
)
prefix = f"decoder-assets/{model}"
return _PretrainedFiles(
lexicon=f"{prefix}/lexicon.txt",
tokens=f"{prefix}/tokens.txt",
lm=f"{prefix}/lm.bin" if model != "librispeech" else None,
)
def download_pretrained_files(model: str) -> _PretrainedFiles:
"""
Retrieves pretrained data files used for :func:`ctc_decoder`.
Args:
model (str): pretrained language model to download.
Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.
Returns:
Object with the following attributes
* ``lm``: path corresponding to downloaded language model,
or ``None`` if the model is not associated with an lm
* ``lexicon``: path corresponding to downloaded lexicon file
* ``tokens``: path corresponding to downloaded tokens file
"""
files = _get_filenames(model)
lexicon_file = download_asset(files.lexicon)
tokens_file = download_asset(files.tokens)
if files.lm is not None:
lm_file = download_asset(files.lm)
else:
lm_file = None
return _PretrainedFiles(
lexicon=lexicon_file,
tokens=tokens_file,
lm=lm_file,
)
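# Illustrative end-to-end sketch combining the pretrained assets with
# `ctc_decoder` (`emissions` is an assumed acoustic-model output):
#
#     files = download_pretrained_files("librispeech-4-gram")
#     decoder = ctc_decoder(
#         lexicon=files.lexicon,
#         tokens=files.tokens,
#         lm=files.lm,
#     )
#     hypos = decoder(emissions)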