.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/asr_inference_with_ctc_decoder_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_asr_inference_with_ctc_decoder_tutorial.py:


ASR Inference with CTC Decoder
==============================

**Author**: `Caroline Chen `__

This tutorial shows how to perform speech recognition inference using a
CTC beam search decoder with lexicon constraint and KenLM language model
support. We demonstrate this on a pretrained wav2vec 2.0 model trained
using CTC loss.

.. GENERATED FROM PYTHON SOURCE LINES 15-42

Overview
--------

Beam search decoding works by iteratively expanding text hypotheses (beams)
with next possible characters, and maintaining only the hypotheses with the
highest scores at each time step. A language model can be incorporated into
the scoring computation, and adding a lexicon constraint restricts the
next possible tokens for the hypotheses so that only words from the lexicon
can be generated.

The underlying implementation is ported from `Flashlight `__'s
beam search decoder. A mathematical formula for the decoder optimization can be
found in the `Wav2Letter paper `__, and
a more detailed algorithm can be found in this `blog `__.

Running ASR inference using a CTC beam search decoder with a language
model and lexicon constraint requires the following components:

-  Acoustic Model: model predicting phonetics from audio waveforms
-  Tokens: the possible predicted tokens from the acoustic model
-  Lexicon: mapping between possible words and their corresponding
   token sequences
-  Language Model (LM): n-gram language model trained with the `KenLM
   library `__, or a custom language
   model that inherits from :py:class:`~torchaudio.models.decoder.CTCDecoderLM`
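
To make the beam search idea from the overview concrete, here is a minimal pure-Python sketch. It is not the Flashlight or torchaudio implementation: it omits CTC blank and repeat collapsing, the language model, and the lexicon, and simply sums per-frame log-probabilities. The function name ``toy_beam_search`` and the two-label vocabulary are hypothetical, chosen only for illustration.

```python
import math
from typing import List, Tuple


def toy_beam_search(
    log_probs: List[List[float]], labels: List[str], beam_size: int
) -> List[Tuple[str, float]]:
    """Keep the `beam_size` best label sequences after every frame.

    Simplified on purpose: no CTC blank/repeat collapsing, no language
    model, no lexicon. A hypothesis score is the sum of its per-frame
    log-probabilities.
    """
    beams = [("", 0.0)]  # (text so far, accumulated log-probability)
    for frame in log_probs:
        # Expand every surviving hypothesis with every possible next label.
        candidates = [
            (text + labels[i], score + lp)
            for text, score in beams
            for i, lp in enumerate(frame)
        ]
        # Prune: keep only the highest-scoring hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


# Two frames over a hypothetical two-label vocabulary.
frames = [
    [math.log(0.6), math.log(0.4)],
    [math.log(0.3), math.log(0.7)],
]
best = toy_beam_search(frames, labels=["a", "b"], beam_size=2)
print(best[0][0])  # prints "ab"
```

With ``beam_size=1`` the same function degenerates into greedy per-frame selection, which is one way to see why larger beams can recover hypotheses that a greedy pass discards.
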
.. GENERATED FROM PYTHON SOURCE LINES 45-51

Acoustic Model and Set Up
-------------------------

First, we import the necessary utilities and fetch the data that we are
working with.

.. GENERATED FROM PYTHON SOURCE LINES 51-58

.. code-block:: default

    import torch
    import torchaudio

    print(torch.__version__)
    print(torchaudio.__version__)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2.4.0.dev20240326
    2.2.0.dev20240328

.. GENERATED FROM PYTHON SOURCE LINES 60-69

.. code-block:: default

    import time
    from typing import List

    import IPython
    import matplotlib.pyplot as plt
    from torchaudio.models.decoder import ctc_decoder
    from torchaudio.utils import download_asset

.. GENERATED FROM PYTHON SOURCE LINES 70-78

We use the pretrained `Wav2Vec 2.0 `__
Base model that is fine-tuned on 10 minutes of the `LibriSpeech
dataset `__, which can be loaded in using
:data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M`.
For more detail on running Wav2Vec 2.0 speech
recognition pipelines in torchaudio, please refer to `this
tutorial <./speech_recognition_pipeline_tutorial.html>`__.

.. GENERATED FROM PYTHON SOURCE LINES 79-84

.. code-block:: default

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
    acoustic_model = bundle.get_model()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth
    0%| | 0.00/360M [00:00

.. GENERATED FROM PYTHON SOURCE LINES 94-100

The transcript corresponding to this audio file is:

.. code-block::

    i really was very much afraid of showing him how much shocked i was at some parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 100-107

.. code-block:: default

    waveform, sample_rate = torchaudio.load(speech_file)

    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

.. GENERATED FROM PYTHON SOURCE LINES 108-116

Files and Data for Decoder
--------------------------

Next, we load in our token, lexicon, and language model data, which are used
by the decoder to predict words from the acoustic model output. Pretrained
files for the LibriSpeech dataset can be downloaded through torchaudio,
or the user can provide their own files.

.. GENERATED FROM PYTHON SOURCE LINES 119-136

Tokens
~~~~~~

The tokens are the possible symbols that the acoustic model can predict,
including the blank and silence symbols. They can be passed in either as a
file, with one token per line so that the line position gives the token's
index, or as a list of tokens, each mapping to a unique index.

.. code-block::

    # tokens.txt
    _
    |
    e
    t
    ...

.. GENERATED FROM PYTHON SOURCE LINES 136-141

.. code-block:: default

    tokens = [label.lower() for label in bundle.get_labels()]
    print(tokens)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']

.. GENERATED FROM PYTHON SOURCE LINES 142-159

Lexicon
~~~~~~~

The lexicon is a mapping from words to their corresponding token
sequences, and is used to restrict the search space of the decoder to
only words from the lexicon. The expected format of the lexicon file is
one line per word, with the word followed by its space-separated tokens.

.. code-block::

    # lexicon.txt
    a a |
    able a b l e |
    about a b o u t |
    ...
    ...
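
As a rough illustration of the lexicon format described above, the following sketch generates lexicon lines from a word list. It assumes a character-level token set in which every letter of every word is a valid token and ``|`` marks the word boundary, as in the example above. The helper name ``build_lexicon_lines`` is hypothetical and is not part of torchaudio.

```python
from typing import Iterable, List


def build_lexicon_lines(words: Iterable[str], word_boundary: str = "|") -> List[str]:
    """Return one lexicon line per unique word.

    Each line is the word followed by its space-separated characters and
    a trailing word-boundary token. Assumes every character is a token
    known to the acoustic model.
    """
    lines = []
    for word in sorted(set(words)):
        spelling = " ".join(word)  # "able" -> "a b l e"
        lines.append(f"{word} {spelling} {word_boundary}")
    return lines


lexicon_lines = build_lexicon_lines(["about", "a", "able"])
print("\n".join(lexicon_lines))
# a a |
# able a b l e |
# about a b o u t |
```

Writing these lines to a text file gives something in the shape the decoder's ``lexicon`` argument expects. For a subword or phoneme token set, the spelling step would need to be replaced by the matching tokenizer.
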
.. GENERATED FROM PYTHON SOURCE LINES 162-170

Language Model
~~~~~~~~~~~~~~

A language model can be used in decoding to improve the results, by
factoring in a language model score that represents the likelihood of
the sequence into the beam search computation. Below, we outline the
different forms of language models that are supported for decoding.

.. GENERATED FROM PYTHON SOURCE LINES 172-178

No Language Model
^^^^^^^^^^^^^^^^^

To create a decoder instance without a language model, set ``lm=None``
when initializing the decoder.

.. GENERATED FROM PYTHON SOURCE LINES 180-191

KenLM
^^^^^

This is an n-gram language model trained with the `KenLM
library `__. Either the ``.arpa`` or
the binarized ``.bin`` LM can be used, but the binary format is
recommended for faster loading.

The language model used in this tutorial is a 4-gram KenLM trained using
`LibriSpeech `__.

.. GENERATED FROM PYTHON SOURCE LINES 193-204

Custom Language Model
^^^^^^^^^^^^^^^^^^^^^

Users can define their own custom language model in Python, whether
it be a statistical or neural network language model, using
:py:class:`~torchaudio.models.decoder.CTCDecoderLM` and
:py:class:`~torchaudio.models.decoder.CTCDecoderLMState`.

For instance, the following code creates a basic wrapper around a PyTorch
``torch.nn.Module`` language model.

.. GENERATED FROM PYTHON SOURCE LINES 204-240
.. code-block:: default

    from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState


    class CustomLM(CTCDecoderLM):
        """Create a Python wrapper around `language_model` to feed to the decoder."""

        def __init__(self, language_model: torch.nn.Module):
            CTCDecoderLM.__init__(self)
            self.language_model = language_model
            self.sil = -1  # index for silent token in the language model
            self.states = {}

            language_model.eval()

        def start(self, start_with_nothing: bool = False):
            state = CTCDecoderLMState()
            with torch.no_grad():
                score = self.language_model(self.sil)

            self.states[state] = score
            return state

        def score(self, state: CTCDecoderLMState, token_index: int):
            outstate = state.child(token_index)
            if outstate not in self.states:
                score = self.language_model(token_index)
                self.states[outstate] = score
            score = self.states[outstate]

            return outstate, score

        def finish(self, state: CTCDecoderLMState):
            return self.score(state, self.sil)

.. GENERATED FROM PYTHON SOURCE LINES 241-250

Downloading Pretrained Files
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Pretrained files for the LibriSpeech dataset can be downloaded using
:py:func:`~torchaudio.models.decoder.download_pretrained_files`.

Note: this cell may take a couple of minutes to run, as the language
model can be large

.. GENERATED FROM PYTHON SOURCE LINES 250-258

.. code-block:: default

    from torchaudio.models.decoder import download_pretrained_files

    files = download_pretrained_files("librispeech-4-gram")

    print(files)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%| | 0.00/4.97M [00:00 List[str]:
            """Given a sequence emission over labels, get the best path

            Args:
              emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

            Returns:
              List[str]: The resulting transcript
            """
            indices = torch.argmax(emission, dim=-1)  # [num_seq,]
            indices = torch.unique_consecutive(indices, dim=-1)
            indices = [i for i in indices if i != self.blank]
            joined = "".join([self.labels[i] for i in indices])
            return joined.replace("|", " ").strip().split()


    greedy_decoder = GreedyCTCDecoder(tokens)

.. GENERATED FROM PYTHON SOURCE LINES 323-337

Run Inference
-------------

Now that we have the data, acoustic model, and decoder, we can perform
inference. The output of the beam search decoder is of type
:py:class:`~torchaudio.models.decoder.CTCHypothesis`, consisting of the
predicted token IDs, corresponding words (if a lexicon is provided), hypothesis score,
and timesteps corresponding to the token IDs. Recall that the transcript corresponding to the
waveform is:

.. code-block::

    i really was very much afraid of showing him how much shocked i was at some parts of what he said

.. GENERATED FROM PYTHON SOURCE LINES 337-344

.. code-block:: default

    actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
    actual_transcript = actual_transcript.split()

    emission, _ = acoustic_model(waveform)

.. GENERATED FROM PYTHON SOURCE LINES 345-347

The greedy decoder gives the following result.

.. GENERATED FROM PYTHON SOURCE LINES 347-356

.. code-block:: default

    greedy_result = greedy_decoder(emission[0])
    greedy_transcript = " ".join(greedy_result)
    greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

    print(f"Transcript: {greedy_transcript}")
    print(f"WER: {greedy_wer}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
    WER: 0.38095238095238093

.. GENERATED FROM PYTHON SOURCE LINES 357-359

Using the beam search decoder:

.. GENERATED FROM PYTHON SOURCE LINES 359-370
.. code-block:: default

    beam_search_result = beam_search_decoder(emission)
    beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
    beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
        actual_transcript
    )

    print(f"Transcript: {beam_search_transcript}")
    print(f"WER: {beam_search_wer}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
    WER: 0.047619047619047616

.. GENERATED FROM PYTHON SOURCE LINES 371-389

.. note::

   The :py:attr:`~torchaudio.models.decoder.CTCHypothesis.words`
   field of the output hypotheses will be empty if no lexicon
   is provided to the decoder. To obtain a transcript with lexicon-free
   decoding, retrieve the token indices, convert them to the original
   tokens, then join them together.

   .. code::

      tokens_str = "".join(beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens))
      transcript = " ".join(tokens_str.split("|"))

We see that the lexicon-constrained beam search decoder produces a more
accurate transcript consisting of real words, while the greedy decoder can
predict incorrectly spelled words like “affrayd” and “shoktd”.

.. GENERATED FROM PYTHON SOURCE LINES 391-399

Incremental decoding
~~~~~~~~~~~~~~~~~~~~

If the input speech is long, one can decode the emission incrementally.

You need to first initialize the internal state of the decoder with
:py:meth:`~torchaudio.models.decoder.CTCDecoder.decode_begin`.

.. GENERATED FROM PYTHON SOURCE LINES 399-402

.. code-block:: default

    beam_search_decoder.decode_begin()

.. GENERATED FROM PYTHON SOURCE LINES 403-407

Then, you can pass emissions to
:py:meth:`~torchaudio.models.decoder.CTCDecoder.decode_step`.
Here we use the same emission but pass it to the decoder one frame
at a time.

.. GENERATED FROM PYTHON SOURCE LINES 407-411
.. code-block:: default

    for t in range(emission.size(1)):
        beam_search_decoder.decode_step(emission[0, t:t + 1, :])

.. GENERATED FROM PYTHON SOURCE LINES 412-414

Finally, finalize the internal state of the decoder, and retrieve the
result.

.. GENERATED FROM PYTHON SOURCE LINES 414-418

.. code-block:: default

    beam_search_decoder.decode_end()
    beam_search_result_inc = beam_search_decoder.get_final_hypothesis()

.. GENERATED FROM PYTHON SOURCE LINES 419-421

The result of incremental decoding is identical to batch decoding.

.. GENERATED FROM PYTHON SOURCE LINES 421-432

.. code-block:: default

    beam_search_transcript_inc = " ".join(beam_search_result_inc[0].words).strip()
    beam_search_wer_inc = torchaudio.functional.edit_distance(
        actual_transcript, beam_search_result_inc[0].words) / len(actual_transcript)

    print(f"Transcript: {beam_search_transcript_inc}")
    print(f"WER: {beam_search_wer_inc}")

    assert beam_search_result[0][0].words == beam_search_result_inc[0].words
    assert beam_search_result[0][0].score == beam_search_result_inc[0].score
    torch.testing.assert_close(beam_search_result[0][0].timesteps, beam_search_result_inc[0].timesteps)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
    WER: 0.047619047619047616

.. GENERATED FROM PYTHON SOURCE LINES 433-438

Timestep Alignments
-------------------

Recall that one of the components of the resulting Hypotheses is timesteps
corresponding to the token IDs.

.. GENERATED FROM PYTHON SOURCE LINES 438-446

.. code-block:: default

    timesteps = beam_search_result[0][0].timesteps
    predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)

    print(predicted_tokens, len(predicted_tokens))
    print(timesteps, timesteps.shape[0])

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    ['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
    tensor([ 0, 31, 33, 36, 39, 41, 42, 44, 46, 48, 49, 52, 54, 58,
            64, 66, 69, 73, 74, 76, 80, 82, 84, 86, 88, 94, 97, 107,
            111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157,
            159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191,
            193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250,
            251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289,
            292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331,
            353], dtype=torch.int32) 99

.. GENERATED FROM PYTHON SOURCE LINES 447-449

Below, we visualize the token timestep alignments relative to the original waveform.

.. GENERATED FROM PYTHON SOURCE LINES 449-492
.. code-block:: default

    def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):

        t = torch.arange(waveform.size(0)) / sample_rate
        ratio = waveform.size(0) / emission.size(1) / sample_rate

        chars = []
        words = []
        word_start = None
        for token, timestep in zip(tokens, timesteps * ratio):
            if token == "|":
                if word_start is not None:
                    words.append((word_start, timestep))
                word_start = None
            else:
                chars.append((token, timestep))
                if word_start is None:
                    word_start = timestep

        fig, axes = plt.subplots(3, 1)

        def _plot(ax, xlim):
            ax.plot(t, waveform)
            for token, timestep in chars:
                ax.annotate(token.upper(), (timestep, 0.5))
            for word_start, word_end in words:
                ax.axvspan(word_start, word_end, alpha=0.1, color="red")
            ax.set_ylim(-0.6, 0.7)
            ax.set_yticks([0])
            ax.grid(True, axis="y")
            ax.set_xlim(xlim)

        _plot(axes[0], (0.3, 2.5))
        _plot(axes[1], (2.5, 4.7))
        _plot(axes[2], (4.7, 6.9))
        axes[2].set_xlabel("time (sec)")
        fig.tight_layout()


    plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)

.. image-sg:: /tutorials/images/sphx_glr_asr_inference_with_ctc_decoder_tutorial_001.png
   :alt: asr inference with ctc decoder tutorial
   :srcset: /tutorials/images/sphx_glr_asr_inference_with_ctc_decoder_tutorial_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 493-501

Beam Search Decoder Parameters
------------------------------

In this section, we go a little bit more in depth about some different
parameters and tradeoffs. For the full list of customizable parameters,
please refer to the
:py:func:`documentation `.

.. GENERATED FROM PYTHON SOURCE LINES 504-507

Helper Function
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 507-519
.. code-block:: default

    def print_decoded(decoder, emission, param, param_value):
        start_time = time.monotonic()
        result = decoder(emission)
        decode_time = time.monotonic() - start_time

        transcript = " ".join(result[0][0].words).lower().strip()
        score = result[0][0].score
        print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")

.. GENERATED FROM PYTHON SOURCE LINES 520-528

nbest
~~~~~

This parameter indicates the number of best hypotheses to return, a
capability that the greedy decoder does not offer. For
instance, by setting ``nbest=3`` when constructing the beam search
decoder earlier, we can now access the hypotheses with the top 3 scores.

.. GENERATED FROM PYTHON SOURCE LINES 528-535

.. code-block:: default

    for i in range(3):
        transcript = " ".join(beam_search_result[0][i].words).strip()
        score = beam_search_result[0][i].score
        print(f"{transcript} (score: {score})")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.824109642502)
    i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.858373688456)
    i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.0157600045172)

.. GENERATED FROM PYTHON SOURCE LINES 536-550

beam size
~~~~~~~~~

The ``beam_size`` parameter determines the maximum number of best
hypotheses to hold after each decoding step. Using larger beam sizes
allows for exploring a larger range of possible hypotheses, which can
produce hypotheses with higher scores, but it is computationally more
expensive and does not provide additional gains beyond a certain point.

In the example below, we see improvement in decoding quality as we
increase beam size from 1 to 5 to 50, but notice how using a beam size
of 500 provides the same output as beam size 50 while increasing the
computation time.
.. GENERATED FROM PYTHON SOURCE LINES 550-566

.. code-block:: default

    beam_sizes = [1, 5, 50, 500]

    for beam_size in beam_sizes:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_size=beam_size,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam size", beam_size)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    beam size 1  : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.0439 secs)
    beam size 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0498 secs)
    beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.1585 secs)
    beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.5238 secs)

.. GENERATED FROM PYTHON SOURCE LINES 567-575

beam size token
~~~~~~~~~~~~~~~

The ``beam_size_token`` parameter corresponds to the number of tokens to
consider for expanding each hypothesis at each decoding step. Exploring a
larger number of next possible tokens increases the range of potential
hypotheses at the cost of computation.

.. GENERATED FROM PYTHON SOURCE LINES 575-592

.. code-block:: default

    num_tokens = len(tokens)
    beam_size_tokens = [1, 5, 10, num_tokens]

    for beam_size_token in beam_size_tokens:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_size_token=beam_size_token,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    beam size token 1  : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.1512 secs)
    beam size token 5  : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.1711 secs)
    beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.1901 secs)
    beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2206 secs)

.. GENERATED FROM PYTHON SOURCE LINES 593-603

beam threshold
~~~~~~~~~~~~~~

The ``beam_threshold`` parameter is used to prune the stored hypotheses
set at each decoding step, removing hypotheses whose scores fall more
than ``beam_threshold`` below the highest-scoring
hypothesis. There is a balance between choosing smaller thresholds to
prune more hypotheses and reduce the search space, and choosing a large
enough threshold such that plausible hypotheses are not pruned.

.. GENERATED FROM PYTHON SOURCE LINES 603-619

.. code-block:: default

    beam_thresholds = [1, 5, 10, 25]

    for beam_threshold in beam_thresholds:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            beam_threshold=beam_threshold,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    beam threshold 1  : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0263 secs)
    beam threshold 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0491 secs)
    beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2012 secs)
    beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2237 secs)

.. GENERATED FROM PYTHON SOURCE LINES 620-629

language model weight
~~~~~~~~~~~~~~~~~~~~~

The ``lm_weight`` parameter is the weight assigned to the language
model score, which is accumulated with the acoustic model score to
determine the overall score. Larger weights encourage the decoder to
predict next words based on the language model, while smaller weights
give more weight to the acoustic model score instead.

.. GENERATED FROM PYTHON SOURCE LINES 629-644

.. code-block:: default

    lm_weights = [0, LM_WEIGHT, 15]

    for lm_weight in lm_weights:
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            lm_weight=lm_weight,
            word_score=WORD_SCORE,
        )

        print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    lm weight 0  : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.2452 secs)
    lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2519 secs)
    lm weight 15 : was there in his was at some of what he said (score: 2918.99; 0.2304 secs)
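
The effect of ``lm_weight`` can be illustrated with a small standalone sketch. It assumes a plain weighted sum, ``acoustic + lm_weight * lm``, as the combined score; the exact scoring inside the decoder includes further terms and is not reproduced here. The two hypotheses and all score values are invented purely for illustration.

```python
from typing import Dict, Tuple

# Hypothetical (acoustic_score, lm_score) pairs in the log domain.
# The numbers are invented: the misspelled hypothesis matches the audio
# slightly better, while the correctly spelled one is far more likely
# under the language model.
hypotheses: Dict[str, Tuple[float, float]] = {
    "parte of what he seid": (-4.0, -30.0),
    "part of what he said": (-6.0, -12.0),
}


def combined_score(acoustic: float, lm: float, lm_weight: float) -> float:
    """Weighted sum assumed for this sketch: acoustic + lm_weight * lm."""
    return acoustic + lm_weight * lm


def best_hypothesis(lm_weight: float) -> str:
    """Return the hypothesis with the highest combined score."""
    return max(hypotheses, key=lambda h: combined_score(*hypotheses[h], lm_weight))


for w in (0.0, 1.0):
    print(f"lm_weight={w}: {best_hypothesis(w)}")
# lm_weight=0.0: parte of what he seid
# lm_weight=1.0: part of what he said
```

With a weight of zero the acoustic score alone decides the ranking. A large enough weight lets the language model dominate, which is consistent with the ``lm weight 15`` output above drifting away from the audio.
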
.. GENERATED FROM PYTHON SOURCE LINES 645-655

additional parameters
~~~~~~~~~~~~~~~~~~~~~

Additional parameters that can be optimized include the following:

- ``word_score``: score to add when word finishes
- ``unk_score``: unknown word appearance score to add
- ``sil_score``: silence appearance score to add
- ``log_add``: whether to use log add for lexicon Trie smearing

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 51.586 seconds)

.. _sphx_glr_download_tutorials_asr_inference_with_ctc_decoder_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: asr_inference_with_ctc_decoder_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: asr_inference_with_ctc_decoder_tutorial.ipynb `

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_