.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/asr_inference_with_ctc_decoder_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_asr_inference_with_ctc_decoder_tutorial.py:


ASR Inference with CTC Decoder
==============================

**Author**: `Caroline Chen `__

This tutorial shows how to perform speech recognition inference using a
CTC beam search decoder with lexicon constraint and KenLM language model
support. We demonstrate this on a pretrained wav2vec 2.0 model trained
using CTC loss.

.. GENERATED FROM PYTHON SOURCE LINES 15-41

Overview
--------

Beam search decoding works by iteratively expanding text hypotheses (beams)
with the next possible characters and keeping only the hypotheses with the
highest scores at each time step. A language model can be incorporated into
the scoring computation, and adding a lexicon constraint restricts the next
possible tokens for the hypotheses so that only words from the lexicon can
be generated.

The underlying implementation is ported from `Flashlight `__'s
beam search decoder. A mathematical formula for the decoder optimization can
be found in the `Wav2Letter paper `__, and a more detailed algorithm
can be found in this `blog `__.

Running ASR inference using a CTC beam search decoder with a KenLM language
model and lexicon constraint requires the following components:

- Acoustic Model: model predicting phonetics from audio waveforms
- Tokens: the possible predicted tokens from the acoustic model
- Lexicon: mapping between possible words and their corresponding token sequences
- KenLM: n-gram language model trained with the `KenLM library `__

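
The search itself can be illustrated without any library support. Below is a toy CTC prefix beam search written from scratch. The helper names are hypothetical and are not torchaudio APIs. It uses no language model and no lexicon. Each candidate label sequence is scored by summing the probabilities of all CTC paths that collapse to it, and only the ``beam_size`` best prefixes survive each frame. The two-frame example is a case where best-path (greedy) decoding and beam search disagree.

.. code-block:: python

    import math
    from collections import defaultdict
    from typing import Dict, List, Tuple

    # Hypothetical toy decoder for illustration; not torchaudio's or Flashlight's code.
    NEG_INF = -math.inf


    def logsumexp(*values: float) -> float:
        """Numerically stable log(sum(exp(v))) that tolerates -inf inputs."""
        m = max(values)
        if m == NEG_INF:
            return NEG_INF
        return m + math.log(sum(math.exp(v - m) for v in values))


    def greedy_ctc(log_probs: List[List[float]], blank: int = 0) -> Tuple[int, ...]:
        """Best path: argmax per frame, merge repeats, drop blanks."""
        path = [max(range(len(frame)), key=frame.__getitem__) for frame in log_probs]
        merged = [tok for i, tok in enumerate(path) if i == 0 or tok != path[i - 1]]
        return tuple(tok for tok in merged if tok != blank)


    def ctc_prefix_beam_search(log_probs: List[List[float]], beam_size: int, blank: int = 0) -> Tuple[int, ...]:
        """Toy CTC prefix beam search without LM or lexicon."""
        # Each prefix keeps two log-scores: (paths ending in blank, paths ending in a token).
        beams: Dict[Tuple[int, ...], Tuple[float, float]] = {(): (0.0, NEG_INF)}
        for frame in log_probs:
            nxt: Dict[Tuple[int, ...], Tuple[float, float]] = defaultdict(lambda: (NEG_INF, NEG_INF))
            for prefix, (p_b, p_nb) in beams.items():
                for tok, lp in enumerate(frame):
                    if tok == blank:
                        # A blank leaves the prefix unchanged.
                        b, nb = nxt[prefix]
                        nxt[prefix] = (logsumexp(b, p_b + lp, p_nb + lp), nb)
                        continue
                    extended = prefix + (tok,)
                    b, nb = nxt[extended]
                    if prefix and tok == prefix[-1]:
                        # A repeated token starts a new symbol only after a blank ...
                        nxt[extended] = (b, logsumexp(nb, p_b + lp))
                        # ... otherwise it merges into the same symbol.
                        b_same, nb_same = nxt[prefix]
                        nxt[prefix] = (b_same, logsumexp(nb_same, p_nb + lp))
                    else:
                        nxt[extended] = (b, logsumexp(nb, p_b + lp, p_nb + lp))
            # Prune: keep only the `beam_size` highest-scoring prefixes.
            ranked = sorted(nxt.items(), key=lambda kv: logsumexp(*kv[1]), reverse=True)
            beams = dict(ranked[:beam_size])
        return max(beams, key=lambda p: logsumexp(*beams[p]))


    # Two frames over the labels ["-" (blank), "a"], given as log-probabilities.
    frames = [[math.log(0.6), math.log(0.4)], [math.log(0.6), math.log(0.4)]]
    print(greedy_ctc(frames))                           # () : best single path is blank, blank (0.36)
    print(ctc_prefix_beam_search(frames, beam_size=2))  # (1,) : "a" collects 0.24 + 0.24 + 0.16 = 0.64

Flashlight's decoder also applies the lexicon trie and the KenLM score while it expands hypotheses. Treat this sketch only as a conceptual model of the search, not as its implementation.
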
.. GENERATED FROM PYTHON SOURCE LINES 44-50

Preparation
-----------

First, we import the necessary utilities and fetch the data that we are
working with.

.. GENERATED FROM PYTHON SOURCE LINES 50-80

.. code-block:: default

    import time
    from typing import List

    import IPython
    import matplotlib.pyplot as plt
    import torch
    import torchaudio

    try:
        from torchaudio.models.decoder import ctc_decoder
    except ModuleNotFoundError:
        try:
            import google.colab

            print(
                """
                To enable running this notebook in Google Colab, install nightly
                torch and torchaudio builds by adding the following code block to the top
                of the notebook before running it:

                !pip3 uninstall -y torch torchvision torchaudio
                !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
                """
            )
        except ModuleNotFoundError:
            pass
        raise

.. GENERATED FROM PYTHON SOURCE LINES 81-91

Acoustic Model and Data
~~~~~~~~~~~~~~~~~~~~~~~

We use the pretrained `Wav2Vec 2.0 `__ Base model that is finetuned on
10 min of the `LibriSpeech dataset `__, which can be loaded using
:py:mod:`torchaudio.pipelines`. For more detail on running Wav2Vec 2.0 speech
recognition pipelines in torchaudio, please refer to `this tutorial `__.

.. GENERATED FROM PYTHON SOURCE LINES 91-96

.. code-block:: default

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
    acoustic_model = bundle.get_model()

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth


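
The acoustic model emits one vector of token scores per frame, not per audio sample. Here is a rough sketch of how the frame count and frame duration follow from simple arithmetic. It assumes the standard wav2vec 2.0 convolutional feature extractor: seven unpadded layers with kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2. It also assumes 16 kHz audio, taken to equal ``bundle.sample_rate``. The constants and helper names are assumptions for illustration only.

.. code-block:: python

    import math

    # Hypothetical illustration. Assumed wav2vec 2.0 feature-extractor layers
    # as (kernel_size, stride), and an assumed 16 kHz input rate.
    CONV_LAYERS = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2
    SAMPLE_RATE = 16_000


    def num_frames(num_samples: int) -> int:
        """Number of emission frames for a waveform of ``num_samples`` samples."""
        length = num_samples
        for kernel_size, stride in CONV_LAYERS:
            length = (length - kernel_size) // stride + 1  # unpadded 1-D convolution
        return length


    def frame_to_seconds(frame_index: int) -> float:
        """Approximate time of a frame; the total stride is 5 * 2**6 = 320 samples."""
        total_stride = math.prod(stride for _, stride in CONV_LAYERS)
        return frame_index * total_stride / SAMPLE_RATE


    print(num_frames(SAMPLE_RATE))  # 49 frames for one second of audio
    print(frame_to_seconds(100))    # 2.0

Under these assumptions, each frame advances by 320 samples (20 ms). This is why frame indices, such as the timesteps returned by the decoder, map to waveform positions by a near-constant factor.
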
.. GENERATED FROM PYTHON SOURCE LINES 111-115

The transcript corresponding to this audio file is

::

   i really was very much afraid of showing him how much shocked i was at some parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 115-122

.. code-block:: default

    waveform, sample_rate = torchaudio.load(speech_file)

    # Resample to the rate expected by the acoustic model, if needed.
    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

.. GENERATED FROM PYTHON SOURCE LINES 123-131

Files and Data for Decoder
~~~~~~~~~~~~~~~~~~~~~~~~~~

Next, we load in our token, lexicon, and KenLM data, which are used by the
decoder to predict words from the acoustic model output. Pretrained files for
the LibriSpeech dataset can be downloaded through torchaudio, or the user can
provide their own files.

.. GENERATED FROM PYTHON SOURCE LINES 134-151

Tokens
^^^^^^

The tokens are the possible symbols that the acoustic model can predict,
including the blank and silence symbols. They can be passed in either as a
file, where each line contains the token whose index is that line's
position, or as a list of tokens, each mapping to a unique index.

::

   # tokens.txt
   _
   |
   e
   t
   ...

.. GENERATED FROM PYTHON SOURCE LINES 151-156

.. code-block:: default

    tokens = [label.lower() for label in bundle.get_labels()]
    print(tokens)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    ['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']

.. GENERATED FROM PYTHON SOURCE LINES 157-174

Lexicon
^^^^^^^

The lexicon is a mapping from words to their corresponding token sequences,
and is used to restrict the search space of the decoder to only words from
the lexicon. The expected format of the lexicon file is one line per word,
with the word followed by its space-separated tokens.

::

   # lexicon.txt
   a a |
   able a b l e |
   about a b o u t |
   ...
   ...

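
To make the token and lexicon formats concrete, here is a small pure-Python sketch. The helper names are hypothetical and are not torchaudio APIs. The sketch maps tokens to indices, parses lexicon lines in the format above, and stores the spellings in a prefix trie. This shows how a lexicon constraint can limit which token may follow a partial word. The decoder builds its own trie internally, and its details may differ.

.. code-block:: python

    from typing import Dict, List

    # Hypothetical illustration, not torchaudio's implementation.
    # Tokens as printed above: a token's position in the list is its index.
    TOKENS = ["-", "|", "e", "t", "a", "o", "n", "i", "h", "s", "r", "d", "l", "u", "m",
              "w", "c", "f", "g", "y", "p", "b", "v", "k", "'", "x", "j", "q", "z"]
    TOKEN_TO_ID = {token: index for index, token in enumerate(TOKENS)}

    # Lexicon lines in the format shown above: a word followed by its tokens.
    LEXICON_TXT = """\
    a a |
    able a b l e |
    about a b o u t |
    """

    WORDS_KEY = "<words>"  # sentinel key storing the words that end at a trie node


    def parse_lexicon(text: str) -> Dict[str, List[str]]:
        lexicon = {}
        for line in text.splitlines():
            if line.strip():
                word, *spelling = line.split()
                lexicon[word] = spelling
        return lexicon


    def build_trie(lexicon: Dict[str, List[str]]) -> dict:
        root: dict = {}
        for word, spelling in lexicon.items():
            unknown = [token for token in spelling if token not in TOKEN_TO_ID]
            if unknown:
                raise ValueError(f"{word!r} uses tokens missing from the token set: {unknown}")
            node = root
            for token in spelling:
                node = node.setdefault(token, {})
            node.setdefault(WORDS_KEY, []).append(word)
        return root


    def allowed_next_tokens(trie: dict, prefix: List[str]) -> List[str]:
        """Tokens that can follow ``prefix`` while still spelling some lexicon word."""
        node = trie
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return sorted(token for token in node if token != WORDS_KEY)


    trie = build_trie(parse_lexicon(LEXICON_TXT))
    print(allowed_next_tokens(trie, ["a"]))       # ['b', '|']
    print(allowed_next_tokens(trie, ["a", "b"]))  # ['l', 'o']
    print(allowed_next_tokens(trie, ["x"]))       # []

After ``a``, only ``b`` (toward "able" or "about") or the word boundary ``|`` (completing "a") keep the hypothesis inside the lexicon.
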
.. GENERATED FROM PYTHON SOURCE LINES 177-188

KenLM
^^^^^

This is an n-gram language model trained with the `KenLM library `__.
Either the ``.arpa`` or the binarized ``.bin`` LM can be used, but the binary
format is recommended for faster loading.

The language model used in this tutorial is a 4-gram KenLM trained using
`LibriSpeech `__.

.. GENERATED FROM PYTHON SOURCE LINES 191-200

Downloading Pretrained Files
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Pretrained files for the LibriSpeech dataset can be downloaded using
:py:func:`download_pretrained_files `.

Note: this cell may take a couple of minutes to run, as the language
model can be large.

.. GENERATED FROM PYTHON SOURCE LINES 200-208

.. code-block:: default

    from torchaudio.models.decoder import download_pretrained_files

    files = download_pretrained_files("librispeech-4-gram")

    print(files)

Beam Search Decoder
~~~~~~~~~~~~~~~~~~~

The decoder is constructed with the factory function
:py:func:`torchaudio.models.decoder.ctc_decoder`. In addition to the
previously mentioned components, it also takes in various beam search
decoding parameters and token/word parameters.

This decoder can also be run without a language model by passing ``None``
to the ``lm`` parameter.

.. GENERATED FROM PYTHON SOURCE LINES 227-244

.. code-block:: default

    from torchaudio.models.decoder import ctc_decoder

    LM_WEIGHT = 3.23  # weight of the language model score
    WORD_SCORE = -0.26  # score added when a word finishes

    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        nbest=3,
        beam_size=1500,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

.. GENERATED FROM PYTHON SOURCE LINES 245-250

Greedy Decoder
~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 250-276

.. code-block:: default

    class GreedyCTCDecoder(torch.nn.Module):
        def __init__(self, labels, blank=0):
            super().__init__()
            self.labels = labels
            self.blank = blank

        def forward(self, emission: torch.Tensor) -> List[str]:
            """Given a sequence emission over labels, get the best path

            Args:
              emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

            Returns:
              List[str]: The resulting transcript, as a list of words
            """
            indices = torch.argmax(emission, dim=-1)  # [num_seq,]
            indices = torch.unique_consecutive(indices, dim=-1)  # merge repeated labels
            indices = [i for i in indices.tolist() if i != self.blank]  # drop blanks
            joined = "".join([self.labels[i] for i in indices])
            return joined.replace("|", " ").strip().split()


    greedy_decoder = GreedyCTCDecoder(tokens)

.. GENERATED FROM PYTHON SOURCE LINES 277-289

Run Inference
-------------

Now that we have the data, acoustic model, and decoder, we can perform
inference. For each utterance in the batch, the beam search decoder returns
a list of :py:class:`torchaudio.models.decoder.CTCHypothesis` objects. Each
hypothesis consists of the predicted token IDs, the corresponding words, the
hypothesis score, and the timesteps corresponding to the token IDs. Recall
that the transcript corresponding to the waveform is

::

   i really was very much afraid of showing him how much shocked i was at some parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 289-296

.. code-block:: default

    actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
    actual_transcript = actual_transcript.split()

    with torch.inference_mode():
        emission, _ = acoustic_model(waveform)

.. GENERATED FROM PYTHON SOURCE LINES 297-299

The greedy decoder gives the following result.

.. GENERATED FROM PYTHON SOURCE LINES 299-308

.. code-block:: default

    greedy_result = greedy_decoder(emission[0])
    greedy_transcript = " ".join(greedy_result)
    # Word error rate: word-level edit distance divided by the reference length.
    greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

    print(f"Transcript: {greedy_transcript}")
    print(f"WER: {greedy_wer}")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
    WER: 0.38095238095238093

.. GENERATED FROM PYTHON SOURCE LINES 309-311

Using the beam search decoder:

.. GENERATED FROM PYTHON SOURCE LINES 311-322

.. code-block:: default

    beam_search_result = beam_search_decoder(emission)
    beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
    beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
        actual_transcript
    )

    print(f"Transcript: {beam_search_transcript}")
    print(f"WER: {beam_search_wer}")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
    WER: 0.047619047619047616

.. GENERATED FROM PYTHON SOURCE LINES 323-328

We see that the lexicon-constrained beam search decoder produces a more
accurate result consisting of real words, while the greedy decoder can
predict incorrectly spelled words like “affrayd” and “shoktd”.

.. GENERATED FROM PYTHON SOURCE LINES 331-336

Timestep Alignments
-------------------

Recall that one of the components of the resulting hypotheses is the
timesteps corresponding to the token IDs.

.. GENERATED FROM PYTHON SOURCE LINES 336-344

.. code-block:: default

    timesteps = beam_search_result[0][0].timesteps
    predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)

    print(predicted_tokens, len(predicted_tokens))
    print(timesteps, timesteps.shape[0])

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    ['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
    tensor([ 0, 31, 33, 36, 39, 41, 42, 44, 46, 48, 49, 52, 54, 58, 64,
            66, 69, 73, 74, 76, 80, 82, 84, 86, 88, 94, 97, 107, 111, 112,
            116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157, 159, 161, 162,
            166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191, 193, 198, 201, 202,
            203, 205, 207, 212, 213, 216, 222, 224, 230, 250, 251, 254, 256, 261, 262,
            264, 267, 270, 276, 277, 281, 284, 288, 289, 292, 295, 297, 299, 300, 303,
            305, 307, 310, 311, 324, 325, 329, 331, 353], dtype=torch.int32) 99

.. GENERATED FROM PYTHON SOURCE LINES 345-347

Below, we visualize the token timestep alignments relative to the original
waveform.

.. GENERATED FROM PYTHON SOURCE LINES 347-375

.. code-block:: default

    def plot_alignments(waveform, emission, tokens, timesteps):
        fig, ax = plt.subplots(figsize=(32, 10))

        ax.plot(waveform)

        # Number of waveform samples per emission frame.
        ratio = waveform.shape[0] / emission.shape[1]
        word_start = 0

        for i in range(len(tokens)):
            # A word starts at the token that follows a word boundary "|".
            if i != 0 and tokens[i - 1] == "|":
                word_start = timesteps[i]
            if tokens[i] != "|":
                plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
            elif i != 0:
                # A word ends at the next boundary; shade its span.
                word_end = timesteps[i]
                ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")

        # Relabel the x axis from sample indices to seconds.
        xticks = ax.get_xticks()
        plt.xticks(xticks, xticks / bundle.sample_rate)
        ax.set_xlabel("time (sec)")
        ax.set_xlim(0, waveform.shape[0])


    plot_alignments(waveform[0], emission, predicted_tokens, timesteps)

.. image-sg:: /tutorials/images/sphx_glr_asr_inference_with_ctc_decoder_tutorial_001.png
   :alt: asr inference with ctc decoder tutorial
   :srcset: /tutorials/images/sphx_glr_asr_inference_with_ctc_decoder_tutorial_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 376-384

Beam Search Decoder Parameters
------------------------------

In this section, we go into more depth on some of the decoder parameters and
their tradeoffs. For the full list of customizable parameters, please refer
to the :py:func:`documentation `.

.. GENERATED FROM PYTHON SOURCE LINES 387-390

Helper Function
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 390-402

.. code-block:: default

    def print_decoded(decoder, emission, param, param_value):
        """Decode ``emission``, then print the best transcript, its score, and the decoding time."""
        start_time = time.monotonic()
        result = decoder(emission)
        decode_time = time.monotonic() - start_time

        transcript = " ".join(result[0][0].words).lower().strip()
        score = result[0][0].score
        print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")

.. GENERATED FROM PYTHON SOURCE LINES 403-411

nbest
~~~~~

This parameter indicates the number of best hypotheses to return. Returning
several hypotheses is not possible with the greedy decoder, which produces a
single best path. For instance, by setting ``nbest=3`` when constructing the
beam search decoder earlier, we can now access the hypotheses with the top 3
scores.

.. GENERATED FROM PYTHON SOURCE LINES 411-418

.. code-block:: default

    for i in range(3):
        transcript = " ".join(beam_search_result[0][i].words).strip()
        score = beam_search_result[0][i].score
        print(f"{transcript} (score: {score})")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.8238231825794)
    i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.8580900895563)
    i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.015467226502)

.. GENERATED FROM PYTHON SOURCE LINES 419-433

beam size
~~~~~~~~~

The ``beam_size`` parameter determines the maximum number of best hypotheses
to hold after each decoding step. Larger beam sizes explore a wider range of
possible hypotheses, which can produce hypotheses with higher scores. However,
they are computationally more expensive and stop providing gains beyond a
certain point.

In the example below, decoding quality improves as we increase the beam size
from 1 to 5 to 50. Notice, however, that a beam size of 500 produces the same
output as a beam size of 50 while increasing the computation time.

.. GENERATED FROM PYTHON SOURCE LINES 433-449

.. code-block:: default

    beam_sizes = [1, 5, 50, 500]

    for beam_size in beam_sizes:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_size=beam_size,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam size", beam_size)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    beam size 1  : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.2201 secs)
    beam size 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0646 secs)
    beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2912 secs)
    beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.7506 secs)

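
The effect of ``beam_size`` can be seen in a single generic beam search step. The sketch below is deliberately simplified and hypothetical. It is not CTC-aware (no blank handling or repeat merging) and uses no LM or lexicon. Each hypothesis is extended by every token, scores are added in the log domain, and only the ``beam_size`` best candidates are kept. The work per step therefore grows with the beam size.

.. code-block:: python

    import heapq
    from typing import Dict, List, Tuple

    # Hypothetical illustration: generic (non-CTC) beam search without LM or lexicon.
    Hypothesis = Tuple[str, float]  # (token sequence so far, log-score)


    def beam_step(beam: List[Hypothesis], frame_scores: Dict[str, float], beam_size: int) -> List[Hypothesis]:
        """Expand every hypothesis by every token, then keep the ``beam_size`` best."""
        candidates = [
            (text + token, score + token_score)
            for text, score in beam
            for token, token_score in frame_scores.items()
        ]
        # Cost per step grows with len(beam) * len(frame_scores).
        return heapq.nlargest(beam_size, candidates, key=lambda hyp: hyp[1])


    frames = [
        {"a": -0.25, "b": -0.5, "c": -3.0},
        {"a": -1.0, "b": -0.125, "c": -3.0},
    ]
    beam = [("", 0.0)]
    for frame_scores in frames:
        beam = beam_step(beam, frame_scores, beam_size=2)

    print(beam)      # [('ab', -0.375), ('bb', -0.625)]
    print(beam[:1])  # nbest=1: the final beam is already sorted by score

With ``beam_size=1``, this sketch reduces to choosing the single best extension at every step. Because the final beam is sorted, returning the ``nbest`` best hypotheses here amounts to taking its first ``nbest`` entries. In this sketch, that means ``nbest`` can never usefully exceed ``beam_size``.
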
.. GENERATED FROM PYTHON SOURCE LINES 450-458

beam size token
~~~~~~~~~~~~~~~

The ``beam_size_token`` parameter corresponds to the number of tokens to
consider when expanding each hypothesis at a decoding step. Exploring a
larger number of next possible tokens increases the range of potential
hypotheses at the cost of computation.

.. GENERATED FROM PYTHON SOURCE LINES 458-475

.. code-block:: default

    num_tokens = len(tokens)
    beam_size_tokens = [1, 5, 10, num_tokens]

    for beam_size_token in beam_size_tokens:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_size_token=beam_size_token,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    beam size token 1  : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.3286 secs)
    beam size token 5  : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.2777 secs)
    beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2314 secs)
    beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4088 secs)

.. GENERATED FROM PYTHON SOURCE LINES 476-486

beam threshold
~~~~~~~~~~~~~~

The ``beam_threshold`` parameter is used to prune the stored hypothesis set
at each decoding step. It removes hypotheses whose scores are more than
``beam_threshold`` below the highest-scoring hypothesis. Smaller thresholds
prune more hypotheses and reduce the search space, but the threshold must be
large enough that plausible hypotheses are not pruned.

.. GENERATED FROM PYTHON SOURCE LINES 486-502

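
Both ``beam_size_token`` and ``beam_threshold`` are pruning rules. The hypothetical helpers below sketch each rule in isolation, assuming log-domain scores where higher is better. The first keeps only the best few tokens of a frame as expansion candidates. The second drops every hypothesis that falls more than ``beam_threshold`` below the current best. In the real decoder, these rules act inside the search loop together with ``beam_size``.

.. code-block:: python

    import heapq
    from typing import Dict, List, Tuple

    # Hypothetical illustration of the two pruning rules; not torchaudio APIs.


    def top_tokens(frame_scores: Dict[str, float], beam_size_token: int) -> Dict[str, float]:
        """Keep only the ``beam_size_token`` best-scoring tokens of one frame."""
        return dict(heapq.nlargest(beam_size_token, frame_scores.items(), key=lambda kv: kv[1]))


    def threshold_prune(hyps: List[Tuple[str, float]], beam_threshold: float) -> List[Tuple[str, float]]:
        """Drop hypotheses scoring more than ``beam_threshold`` below the best one."""
        best = max(score for _, score in hyps)
        return [(text, score) for text, score in hyps if score >= best - beam_threshold]


    frame_scores = {"a": -0.5, "b": -2.0, "c": -0.25}
    print(top_tokens(frame_scores, beam_size_token=2))  # {'c': -0.25, 'a': -0.5}

    hyps = [("x", 10.0), ("y", 6.0), ("z", 4.0)]
    print(threshold_prune(hyps, beam_threshold=5.0))  # [('x', 10.0), ('y', 6.0)]
    print(threshold_prune(hyps, beam_threshold=1.0))  # [('x', 10.0)]

Unlike ``beam_size``, the threshold rule keeps a variable number of hypotheses. When one hypothesis clearly dominates, as with ``beam_threshold=1.0`` above, almost everything else is discarded.
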
.. code-block:: default

    beam_thresholds = [1, 5, 10, 25]

    for beam_threshold in beam_thresholds:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_threshold=beam_threshold,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    beam threshold 1  : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0337 secs)
    beam threshold 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0850 secs)
    beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3094 secs)
    beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2865 secs)

.. GENERATED FROM PYTHON SOURCE LINES 503-512

language model weight
~~~~~~~~~~~~~~~~~~~~~

The ``lm_weight`` parameter is the weight applied to the language model
score, which is accumulated with the acoustic model score to determine the
overall score. Larger weights encourage the decoder to predict the next
words based on the language model, while smaller weights give more weight to
the acoustic model score instead.

.. GENERATED FROM PYTHON SOURCE LINES 512-527

.. code-block:: default

    lm_weights = [0, LM_WEIGHT, 15]

    for lm_weight in lm_weights:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            lm_weight=lm_weight,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    lm weight 0  : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.3061 secs)
    lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3269 secs)
    lm weight 15 : was there in his was at some of what he said (score: 2918.98; 0.3175 secs)

.. GENERATED FROM PYTHON SOURCE LINES 528-538

additional parameters
~~~~~~~~~~~~~~~~~~~~~

Additional parameters that can be optimized include the following:

- ``word_score``: score to add when a word finishes
- ``unk_score``: score to add when an unknown word appears
- ``sil_score``: score to add when a silence token appears
- ``log_add``: whether to use log add for lexicon Trie smearing

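
To see how ``lm_weight`` and the additional scores interact, here is a simplified, hypothetical scoring function. It adds the acoustic score, the language model score scaled by ``lm_weight``, and fixed per-event scores for completed words, unknown words, and silences. The real decoder applies these terms incrementally during the search, so this only mirrors the descriptions above. The ``combine`` helper contrasts log-add (the log of the summed probabilities) with taking the maximum. It assumes max is the alternative when ``log_add`` is disabled.

.. code-block:: python

    import math

    # Hypothetical illustration of score composition; not the decoder's actual bookkeeping.


    def hypothesis_score(acoustic, lm, num_words, lm_weight, word_score,
                         num_unk=0, unk_score=0.0, num_sil=0, sil_score=0.0):
        """Acoustic score + weighted LM score + per-word, per-unknown, per-silence terms."""
        return (acoustic + lm_weight * lm + word_score * num_words
                + unk_score * num_unk + sil_score * num_sil)


    def combine(scores, log_add):
        """Merge alternative log-scores with log-add (log of summed probabilities) or max."""
        if not log_add:
            return max(scores)
        m = max(scores)
        return m + math.log(sum(math.exp(s - m) for s in scores))


    # Two made-up hypotheses with the same number of words.
    acoustic_favored = dict(acoustic=-3.0, lm=-8.0, num_words=2)
    lm_favored = dict(acoustic=-4.0, lm=-5.0, num_words=2)
    for lm_weight in (0.0, 1.0):
        a = hypothesis_score(**acoustic_favored, lm_weight=lm_weight, word_score=-0.25)
        b = hypothesis_score(**lm_favored, lm_weight=lm_weight, word_score=-0.25)
        print(lm_weight, a, b)
    # 0.0 -3.5 -4.5    the acoustically better hypothesis wins
    # 1.0 -11.5 -9.5   with a larger LM weight, the LM-preferred hypothesis wins

    print(combine([-1.0, -1.0], log_add=False))  # -1.0
    print(combine([-1.0, -1.0], log_add=True))   # about -0.307 (= -1 + log 2)

Raising ``lm_weight`` flips the ranking toward the hypothesis the language model prefers. This is the same trend seen in the ``lm_weight`` sweep above, where a very large weight lets the language model override the audio.
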