
ASR Inference with CTC Decoder

Author: Caroline Chen

This tutorial shows how to perform speech recognition inference using a CTC beam search decoder with lexicon constraint and KenLM language model support. We demonstrate this on a pretrained wav2vec 2.0 model trained using CTC loss.

Overview

Beam search decoding works by iteratively expanding text hypotheses (beams) with next possible characters, and maintaining only the hypotheses with the highest scores at each time step. A language model can be incorporated into the scoring computation, and adding a lexicon constraint restricts the next possible tokens for the hypotheses so that only words from the lexicon can be generated.
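To make the expand-and-prune loop concrete, here is a minimal pure-Python sketch of CTC prefix beam search. It assumes per-frame probabilities that are already normalized and uses no language model or lexicon, so it is a simplified stand-in, not the Flashlight decoder that torchaudio ports. The function name ctc_prefix_beam_search is hypothetical. Each prefix tracks two probabilities (paths ending in a blank and paths ending in a non-blank token), because that distinction decides whether a repeated token collapses or starts a new character.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def ctc_prefix_beam_search(
    probs: Sequence[Sequence[float]], beam_size: int, blank: int = 0
) -> Tuple[Tuple[int, ...], float]:
    """Return the most probable collapsed token sequence and its probability.

    probs[t][v] is the (already normalized) probability of token v at frame t.
    Simplified sketch: no language model, no lexicon constraint.
    """
    # prefix -> [prob of paths ending in blank, prob of paths ending in non-blank]
    beams: Dict[Tuple[int, ...], List[float]] = {(): [1.0, 0.0]}
    for frame in probs:
        next_beams: Dict[Tuple[int, ...], List[float]] = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_blank, p_nonblank) in beams.items():
            for token, p in enumerate(frame):
                if token == blank:
                    # A blank never changes the collapsed prefix.
                    next_beams[prefix][0] += (p_blank + p_nonblank) * p
                    continue
                extended = prefix + (token,)
                if prefix and token == prefix[-1]:
                    # Repeating the last token with no blank in between collapses into it...
                    next_beams[prefix][1] += p_nonblank * p
                    # ...while repeating it after a blank starts a new token.
                    next_beams[extended][1] += p_blank * p
                else:
                    next_beams[extended][1] += (p_blank + p_nonblank) * p
        # Prune: keep only the beam_size highest-probability prefixes.
        ranked = sorted(next_beams.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = dict(ranked[:beam_size])
    best_prefix, (p_b, p_nb) = max(beams.items(), key=lambda kv: sum(kv[1]))
    return best_prefix, p_b + p_nb


if __name__ == "__main__":
    frames = [[0.6, 0.4], [0.6, 0.4]]  # token 0 is blank, token 1 is "a"
    print(ctc_prefix_beam_search(frames, beam_size=1))
    print(ctc_prefix_beam_search(frames, beam_size=2))
```

In this two-frame example, the blank is the best token at every frame, so a greedy decoder (and beam_size=1 here) outputs the empty sequence. With beam_size=2 the search finds "a", because the paths "a a", "a -", and "- a" all collapse to it and together carry probability 0.64, versus 0.36 for the all-blank path.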

The underlying implementation is ported from Flashlight’s beam search decoder. A mathematical formula for the decoder optimization can be found in the Wav2Letter paper, and a more detailed algorithm can be found in this blog.

Running ASR inference using a CTC beam search decoder with a KenLM language model and lexicon constraint requires the following components:

  • Acoustic Model: a model that predicts phonetics from audio waveforms

  • Tokens: the possible predicted tokens from the acoustic model

  • Lexicon: a mapping between possible words and their corresponding token sequences

  • KenLM: n-gram language model trained with the KenLM library

Preparation

First, we import the necessary utilities and fetch the data that we are working with.

import time
from typing import List

import IPython
import matplotlib.pyplot as plt
import torch
import torchaudio

try:
    from torchaudio.models.decoder import ctc_decoder
except ModuleNotFoundError:
    try:
        import google.colab

        print(
            """
            To enable running this notebook in Google Colab, install nightly
            torch and torchaudio builds by adding the following code block to the top
            of the notebook before running it:

            !pip3 uninstall -y torch torchvision torchaudio
            !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
            """
        )
    except ModuleNotFoundError:
        pass
    raise

Acoustic Model and Data

We use the Wav2Vec 2.0 Base model that is pretrained on LibriSpeech and fine-tuned for ASR on 10 minutes of transcribed audio from the Libri-Light dataset. It can be loaded using torchaudio.pipelines. For more detail on running Wav2Vec 2.0 speech recognition pipelines in torchaudio, please refer to this tutorial.

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()

Out:

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth

100%|##########| 360M/360M [00:04<00:00, 87.6MB/s]

We will load a sample from the LibriSpeech test-other dataset.

hub_dir = torch.hub.get_dir()

speech_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/1688-142285-0007.wav"
speech_file = f"{hub_dir}/speech.wav"

torch.hub.download_url_to_file(speech_url, speech_file)

IPython.display.Audio(speech_file)

Out:

100%|##########| 441k/441k [00:00<00:00, 7.35MB/s]


The transcript corresponding to this audio file is:

i really was very much afraid of showing him how much shocked i was at some parts of what he said

waveform, sample_rate = torchaudio.load(speech_file)

if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

Files and Data for Decoder

Next, we load in our token, lexicon, and KenLM data, which are used by the decoder to predict words from the acoustic model output. Pretrained files for the LibriSpeech dataset can be downloaded through torchaudio, or the user can provide their own files.

Tokens

The tokens are the possible symbols that the acoustic model can predict, including the blank and silent symbols. They can be passed in either as a file, where the token on each line corresponds to that line's index, or as a list of tokens, each mapping to a unique index.

# tokens.txt
_
|
e
t
...

tokens = [label.lower() for label in bundle.get_labels()]
print(tokens)

Out:

['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']
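As a rough illustration of the token file format described above, the following sketch parses tokens.txt-style lines into a list and an index mapping. The helper names load_tokens and token_to_index are hypothetical, not torchaudio APIs. In the printed list, index 0 holds the blank symbol, which is why GreedyCTCDecoder defaults to blank=0.

```python
from typing import Dict, Iterable, List


def load_tokens(lines: Iterable[str]) -> List[str]:
    """Parse tokens.txt-style lines; the token on line i receives index i."""
    tokens = []
    for line in lines:
        token = line.rstrip("\r\n")  # keep the symbol itself, drop the newline
        # Ignore empty lines (this sketch assumes they only appear at the end of the file).
        if token:
            tokens.append(token)
    return tokens


def token_to_index(tokens: List[str]) -> Dict[str, int]:
    """Map each token to its index, rejecting duplicate tokens."""
    mapping = {token: index for index, token in enumerate(tokens)}
    if len(mapping) != len(tokens):
        raise ValueError("each token must map to a unique index")
    return mapping


if __name__ == "__main__":
    tokens = load_tokens(["-\n", "|\n", "e\n", "t\n"])
    print(tokens, token_to_index(tokens))
```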

Lexicon

The lexicon is a mapping from words to their corresponding token sequences, and is used to restrict the search space of the decoder to only words from the lexicon. The expected format of the lexicon file is one line per word, with the word followed by its space-separated tokens.

# lexicon.txt
a a |
able a b l e |
about a b o u t |
...
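The following sketch parses lexicon lines of this form and checks every spelling against the token set, since a spelling that uses an unknown token could never be produced from the acoustic model's output. The name load_lexicon is hypothetical, and storing a repeated word as an alternative spelling is an assumption about the format rather than documented torchaudio behavior.

```python
from typing import Dict, Iterable, List, Set


def load_lexicon(lines: Iterable[str], token_set: Set[str]) -> Dict[str, List[List[str]]]:
    """Parse 'word tok tok ... |' lines into a word -> list-of-spellings mapping."""
    lexicon: Dict[str, List[List[str]]] = {}
    for line_no, line in enumerate(lines, start=1):
        fields = line.split()
        if not fields:
            continue  # skip empty lines
        word, spelling = fields[0], fields[1:]
        unknown = [token for token in spelling if token not in token_set]
        if unknown:
            raise ValueError(f"line {line_no}: tokens {unknown} are not in the token set")
        # Assumption: a repeated word is stored as an alternative spelling.
        lexicon.setdefault(word, []).append(spelling)
    return lexicon


if __name__ == "__main__":
    token_set = {"-", "|", "a", "b", "e", "l", "o", "t", "u"}
    lines = ["a a |\n", "able a b l e |\n", "about a b o u t |\n"]
    print(load_lexicon(lines, token_set))
```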

KenLM

This is an n-gram language model trained with the KenLM library. Either the .arpa or the binarized .bin LM can be used, but the binary format is recommended for faster loading.

The language model used in this tutorial is a 4-gram KenLM trained using LibriSpeech.
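To give a feel for what an n-gram model contributes, here is a toy word-bigram model that scores a sentence in log10, the convention used by ARPA files. This is a deliberately simplified illustration: ToyBigramLM is hypothetical, it uses unsmoothed maximum-likelihood estimates, and unseen bigrams score negative infinity, whereas KenLM models such as the 4-gram used here are higher order and use smoothing with backoff.

```python
import math
from collections import Counter
from typing import Iterable


class ToyBigramLM:
    """Unsmoothed maximum-likelihood word-bigram model, scored in log10 (illustration only)."""

    def __init__(self, sentences: Iterable[str]):
        self.bigrams: Counter = Counter()
        self.contexts: Counter = Counter()
        for sentence in sentences:
            words = ["<s>"] + sentence.split() + ["</s>"]  # sentence boundary markers
            for prev, cur in zip(words, words[1:]):
                self.bigrams[(prev, cur)] += 1
                self.contexts[prev] += 1

    def log10_prob(self, prev: str, cur: str) -> float:
        count = self.bigrams[(prev, cur)]
        if count == 0:
            return float("-inf")  # unseen bigram: no smoothing in this toy model
        return math.log10(count / self.contexts[prev])

    def score(self, sentence: str) -> float:
        words = ["<s>"] + sentence.split() + ["</s>"]
        return sum(self.log10_prob(p, c) for p, c in zip(words, words[1:]))


if __name__ == "__main__":
    lm = ToyBigramLM(["a b", "a c"])
    print(lm.score("a b"), lm.score("b a"))
```

With this two-sentence corpus, P(b | a) is 1/2 and the other bigrams in "a b" have probability 1, so "a b" scores log10(0.5), about -0.301, while "b a" scores negative infinity because it starts with an unseen bigram.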

Downloading Pretrained Files

Pretrained files for the LibriSpeech dataset can be downloaded using download_pretrained_files.

Note: this cell may take a couple of minutes to run, as the language model can be large.

from torchaudio.models.decoder import download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")

print(files)

Out:

100%|##########| 4.97M/4.97M [00:00<00:00, 69.4MB/s]

100%|##########| 57.0/57.0 [00:00<00:00, 41.5kB/s]

100%|##########| 2.91G/2.91G [00:50<00:00, 62.4MB/s]
PretrainedFiles(lexicon='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt', tokens='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt', lm='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin')

Construct Decoders

In this tutorial, we construct both a beam search decoder and a greedy decoder for comparison.

Beam Search Decoder

The decoder can be constructed using the factory function ctc_decoder. In addition to the previously mentioned components, it also takes in various beam search decoding parameters and token/word parameters.

This decoder can also be run without a language model by passing None as the lm parameter.

from torchaudio.models.decoder import ctc_decoder

LM_WEIGHT = 3.23
WORD_SCORE = -0.26

beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)
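The lm_weight and word_score arguments enter the hypothesis score. As a rough sketch, assuming the objective from the Wav2Letter paper (acoustic log-probability, plus a weighted LM log-probability, plus a per-word bonus or penalty), the combination looks like this. The function hypothesis_score and the candidate scores are hypothetical; the actual Flashlight decoder applies further terms, such as those controlled by ctc_decoder's sil_score and unk_score arguments.

```python
def hypothesis_score(
    acoustic_logprob: float,
    lm_logprob: float,
    num_words: int,
    lm_weight: float,
    word_score: float,
) -> float:
    """Acoustic score + weighted LM score + per-word bonus (positive) or penalty (negative)."""
    return acoustic_logprob + lm_weight * lm_logprob + word_score * num_words


if __name__ == "__main__":
    LM_WEIGHT = 3.23
    WORD_SCORE = -0.26
    # Hypothetical scores for two similar-sounding lexicon-valid alternatives.
    candidates = {
        "a part": dict(acoustic_logprob=-12.0, lm_logprob=-2.0, num_words=2),
        "apart": dict(acoustic_logprob=-11.5, lm_logprob=-4.0, num_words=1),
    }
    for weight in (0.0, LM_WEIGHT):
        best = max(
            candidates,
            key=lambda text: hypothesis_score(
                lm_weight=weight, word_score=WORD_SCORE, **candidates[text]
            ),
        )
        print(f"lm_weight={weight}: {best}")
```

With these made-up numbers, "apart" wins on acoustics alone (lm_weight=0.0), while "a part" wins once the LM score is weighted by 3.23.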

Greedy Decoder

class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        return joined.replace("|", " ").strip().split()


greedy_decoder = GreedyCTCDecoder(tokens)
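The same collapse rule can be written without torch. This stdlib-only sketch (greedy_ctc_collapse is a hypothetical name) works on nested lists of per-frame scores and makes the order of operations explicit: take the best token per frame, merge consecutive repeats, and only then drop blanks. That order is what lets a blank separate a genuinely doubled letter, such as the "l l" in "really".

```python
from itertools import groupby
from typing import List, Sequence


def greedy_ctc_collapse(
    frame_scores: Sequence[Sequence[float]], labels: Sequence[str], blank: int = 0
) -> List[str]:
    # 1. Best-scoring token index at each frame.
    best = [max(range(len(scores)), key=scores.__getitem__) for scores in frame_scores]
    # 2. Merge consecutive repeats, then 3. drop blanks.
    collapsed = [index for index, _ in groupby(best) if index != blank]
    text = "".join(labels[index] for index in collapsed)
    # 4. "|" is the word delimiter.
    return text.replace("|", " ").split()


if __name__ == "__main__":
    labels = ["-", "|", "a", "b"]
    path = [2, 2, 0, 2, 1, 3, 3]  # "a a - a | b b"
    frames = [[1.0 if j == i else 0.0 for j in range(len(labels))] for i in path]
    print(greedy_ctc_collapse(frames, labels))
```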

Run Inference

Now that we have the data, acoustic model, and decoder, we can perform inference. The output of the beam search decoder is of type torchaudio.models.decoder.CTCHypothesis, consisting of the predicted token IDs, the corresponding words, the hypothesis score, and the timesteps corresponding to the token IDs. Recall that the transcript corresponding to the waveform is:

i really was very much afraid of showing him how much shocked i was at some parts of what he said

actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = actual_transcript.split()

emission, _ = acoustic_model(waveform)

The greedy decoder gives the following result.

greedy_result = greedy_decoder(emission[0])
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)

print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}")

Out:

Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
WER: 0.38095238095238093

Using the beam search decoder:

beam_search_result = beam_search_decoder(emission)
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
    actual_transcript
)

print(f"Transcript: {beam_search_transcript}")
print(f"WER: {beam_search_wer}")

Out:

Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
WER: 0.047619047619047616

We see that the lexicon-constrained beam search decoder produces a more accurate transcript consisting of real words, while the greedy decoder can predict misspelled words such as “affrayd” and “shoktd”.
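The WER values above divide a word-level edit distance by the number of reference words. For readers who want to see that computation without torchaudio, here is a minimal dynamic-programming sketch, assuming unit costs for substitutions, insertions, and deletions. word_edit_distance and wer are hypothetical helper names; this is not torchaudio.functional.edit_distance itself.

```python
from typing import Sequence


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance over words, keeping only one DP row at a time."""
    prev = list(range(len(hyp) + 1))  # distance from an empty ref prefix
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,  # delete the reference word
                cur[j - 1] + 1,  # insert the hypothesis word
                prev[j - 1] + (r != h),  # substitute (free when the words match)
            )
        prev = cur
    return prev[-1]


def wer(ref: Sequence[str], hyp: Sequence[str]) -> float:
    return word_edit_distance(ref, hyp) / len(ref)


if __name__ == "__main__":
    ref = "i really was very much afraid of showing him how much shocked i was at some parts of what he said".split()
    greedy = "i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid".split()
    print(word_edit_distance(ref, greedy), wer(ref, greedy))
```

For the greedy transcript, seven words are substituted and "howmuch" also forces one deletion, giving 8 errors over 21 reference words.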

Timestep Alignments

Recall that one of the components of the resulting hypotheses is the timesteps corresponding to the token IDs.

timesteps = beam_search_result[0][0].timesteps
predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)

print(predicted_tokens, len(predicted_tokens))
print(timesteps, timesteps.shape[0])

Out:

['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
tensor([  0,  31,  33,  36,  39,  41,  42,  44,  46,  48,  49,  52,  54,  58,
         64,  66,  69,  73,  74,  76,  80,  82,  84,  86,  88,  94,  97, 107,
        111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157,
        159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191,
        193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250,
        251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289,
        292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331,
        353], dtype=torch.int32) 99
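The timesteps are frame indices into the emission, not seconds. A sketch that groups them into word spans, mirroring the logic of plot_alignments (a word starts at its first token's frame and ends at the frame of the following "|"), could look like this. The helper word_segments is hypothetical, and a word not followed by a delimiter is dropped. In the tutorial's setting, sec_per_frame would be waveform.shape[1] / emission.shape[1] / bundle.sample_rate, that is, the ratio used in plot_alignments divided by the sample rate; the 0.02 below is an assumed value for illustration.

```python
from typing import List, Sequence, Tuple


def word_segments(
    tokens: Sequence[str],
    timesteps: Sequence[int],
    sec_per_frame: float,
    delimiter: str = "|",
) -> List[Tuple[str, float, float]]:
    """Group token timesteps into (word, start_sec, end_sec) spans."""
    segments = []
    chars: List[str] = []
    start = 0
    for token, frame in zip(tokens, timesteps):
        if token == delimiter:
            if chars:  # close the current word at the delimiter's frame
                segments.append(("".join(chars), start * sec_per_frame, frame * sec_per_frame))
            chars = []
        else:
            if not chars:  # first token of a new word
                start = frame
            chars.append(token)
    return segments


if __name__ == "__main__":
    # First tokens and timesteps from the printed output above.
    tokens = ["|", "i", "|", "r", "e", "a", "l", "l", "y", "|"]
    timesteps = [0, 31, 33, 36, 39, 41, 42, 44, 46, 48]
    for word, start, end in word_segments(tokens, timesteps, sec_per_frame=0.02):
        print(f"{word}: {start:.2f}s to {end:.2f}s")
```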

Below, we visualize the token timestep alignments relative to the original waveform.

def plot_alignments(waveform, emission, tokens, timesteps):
    fig, ax = plt.subplots(figsize=(32, 10))

    ax.plot(waveform)

    ratio = waveform.shape[0] / emission.shape[1]
    word_start = 0

    for i in range(len(tokens)):
        if i != 0 and tokens[i - 1] == "|":
            word_start = timesteps[i]
        if tokens[i] != "|":
            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
        elif i != 0:
            word_end = timesteps[i]
            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")

    xticks = ax.get_xticks()
    plt.xticks(xticks, xticks / bundle.sample_rate)
    ax.set_xlabel("time (sec)")
    ax.set_xlim(0, waveform.shape[0])


plot_alignments(waveform[0], emission, predicted_tokens, timesteps)

Beam Search Decoder Parameters

In this section, we go into more depth on some of the decoder parameters and their tradeoffs. For the full list of customizable parameters, please refer to the documentation.

Helper Function

def print_decoded(decoder, emission, param, param_value):
    start_time = time.monotonic()
    result = decoder(emission)
    decode_time = time.monotonic() - start_time

    transcript = " ".join(result[0][0].words).lower().strip()
    score = result[0][0].score
    print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")

nbest

This parameter indicates the number of best hypotheses to return, which is not possible with the greedy decoder. For instance, by setting nbest=3 when constructing the beam search decoder earlier, we can now access the hypotheses with the top 3 scores.

for i in range(3):
    transcript = " ".join(beam_search_result[0][i].words).strip()
    score = beam_search_result[0][i].score
    print(f"{transcript} (score: {score})")

Out:

i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.8238231825794)
i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.8580900895563)
i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.015467226502)
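Once decoding finishes, the final beam already holds several scored hypotheses, so returning the n best amounts to a top-n selection. Here is a minimal sketch, assuming hypotheses are (transcript, score) pairs where higher scores are better; n_best is a hypothetical helper, not part of torchaudio. It also illustrates that no more hypotheses can be returned than survive in the beam.

```python
import heapq
from typing import List, Sequence, Tuple


def n_best(hypotheses: Sequence[Tuple[str, float]], nbest: int) -> List[Tuple[str, float]]:
    """Return up to nbest (transcript, score) pairs, highest score first."""
    return heapq.nlargest(nbest, hypotheses, key=lambda hyp: hyp[1])


if __name__ == "__main__":
    # Scores rounded from the output above, plus one made-up weaker hypothesis.
    final_beam = [
        ("... at some parts of what he said", 3697.86),
        ("... at some part of what he said", 3699.82),
        ("i reply was ...", 3695.02),
        ("hypothetical weaker hypothesis", 3600.00),
    ]
    for transcript, score in n_best(final_beam, nbest=3):
        print(f"{transcript} (score: {score})")
```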

beam size

The beam_size parameter determines the maximum number of best hypotheses to hold after each decoding step. Using larger beam sizes allows for exploring a larger range of possible hypotheses which can produce hypotheses with higher scores, but it is computationally more expensive and does not provide additional gains beyond a certain point.

In the example below, we see improvement in decoding quality as we increase beam size from 1 to 5 to 50, but notice how using a beam size of 500 provides the same output as beam size 50 while increasing the computation time.

beam_sizes = [1, 5, 50, 500]

for beam_size in beam_sizes:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size=beam_size,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam size", beam_size)

Out:

beam size 1  : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.2201 secs)
beam size 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0646 secs)
beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2912 secs)
beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.7506 secs)

beam size token

The beam_size_token parameter corresponds to the number of tokens to consider for expanding each hypothesis at the decoding step. Exploring a larger number of next possible tokens increases the range of potential hypotheses at the cost of computation.

num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]

for beam_size_token in beam_size_tokens:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size_token=beam_size_token,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)

Out:

beam size token 1  : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.3286 secs)
beam size token 5  : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.2777 secs)
beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2314 secs)
beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4088 secs)
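The following minimal sketch shows what beam_size_token limits. As a simplification, it assumes that candidate tokens are ranked only by their per-frame acoustic score. `expand_hypothesis` and the scores below are hypothetical and are not part of torchaudio.

```python
import heapq
from typing import Dict, List, Tuple


def expand_hypothesis(
    text: str,
    score: float,
    frame_scores: Dict[str, float],
    beam_size_token: int,
) -> List[Tuple[str, float]]:
    """Extend one hypothesis with only the `beam_size_token` best tokens for this frame."""
    # Rank this frame's tokens by their (hypothetical) acoustic score and keep the top k.
    top_tokens = heapq.nlargest(beam_size_token, frame_scores.items(), key=lambda kv: kv[1])
    return [(text + token, score + token_score) for token, token_score in top_tokens]


# Hypothetical log scores for one frame over a 4-token vocabulary.
frame_scores = {"a": -0.2, "b": -1.0, "c": -3.0, "|": -2.0}

narrow = expand_hypothesis("x", -1.0, frame_scores, beam_size_token=2)
full = expand_hypothesis("x", -1.0, frame_scores, beam_size_token=len(frame_scores))
```

Passing the vocabulary size corresponds to the num_tokens setting above, where no candidate token is skipped.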

beam threshold

The beam_threshold parameter prunes the stored set of hypotheses at each decoding step. It removes hypotheses whose scores are more than beam_threshold below the highest-scoring hypothesis. A smaller threshold prunes more hypotheses and reduces the search space, but the threshold must be large enough that plausible hypotheses are not pruned.

beam_thresholds = [1, 5, 10, 25]

for beam_threshold in beam_thresholds:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_threshold=beam_threshold,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)

Out:

beam threshold 1  : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0337 secs)
beam threshold 5  : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0850 secs)
beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3094 secs)
beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2865 secs)
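The pruning rule described above can be sketched as a filter applied at one decoding step. `prune_by_threshold` and the hypothesis scores are hypothetical. They were chosen only to show how the threshold changes which hypotheses survive.

```python
from typing import List, Tuple


def prune_by_threshold(
    hyps: List[Tuple[str, float]], beam_threshold: float
) -> List[Tuple[str, float]]:
    """Drop hypotheses scoring more than `beam_threshold` below the current best."""
    if not hyps:
        return []
    best = max(score for _, score in hyps)
    return [(text, score) for text, score in hyps if score >= best - beam_threshold]


# Hypothetical partial hypotheses and scores at a single decoding step.
hyps = [("i really", 10.0), ("i rely", 7.0), ("i ila", 2.0)]

for threshold in [1, 5, 10]:
    kept = prune_by_threshold(hyps, threshold)
    print(threshold, [text for text, _ in kept])
```

A threshold of 1 keeps only the leader, 5 also keeps "i rely", and 10 keeps all three.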

language model weight

The lm_weight parameter is the weight applied to the language model score before it is added to the acoustic model score to determine the overall score. Larger weights make the decoder rely more on the language model when choosing the next words. Smaller weights give more weight to the acoustic model score instead. As the example below shows, a weight that is too large (15) lets the language model override the acoustic evidence.

lm_weights = [0, LM_WEIGHT, 15]

for lm_weight in lm_weights:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        lm_weight=lm_weight,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)

Out:

lm weight 0  : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.3061 secs)
lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3269 secs)
lm weight 15 : was there in his was at some of what he said (score: 2918.98; 0.3175 secs)
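The weighted combination can be illustrated with three hypothetical hypotheses whose acoustic and language model log scores are invented for this example. `Hypothesis`, `combined_score`, and `best_hypothesis` are illustrative helpers, not torchaudio APIs. Each additional word adds another negative log probability, so in this toy a heavily weighted language model favors the shortest output. This resembles the truncated transcript at lm weight 15 above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Hypothesis:
    text: str
    acoustic: float  # hypothetical acoustic-model log score
    lm: float  # hypothetical language-model log probability


def combined_score(hyp: Hypothesis, lm_weight: float) -> float:
    # Overall score = acoustic score + lm_weight * language model score.
    return hyp.acoustic + lm_weight * hyp.lm


def best_hypothesis(hyps: List[Hypothesis], lm_weight: float) -> str:
    return max(hyps, key=lambda h: combined_score(h, lm_weight)).text


hyps = [
    Hypothesis("i rely was", acoustic=-5.0, lm=-12.0),
    Hypothesis("i really was", acoustic=-6.0, lm=-9.0),
    Hypothesis("was", acoustic=-20.0, lm=-3.0),
]

for lm_weight in [0, 1, 15]:
    print(lm_weight, best_hypothesis(hyps, lm_weight))
```

This sketch leaves out word_score. A per-word bonus of that kind is one way to offset the language model's preference for fewer words.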

additional parameters

Additional parameters that can be tuned include the following:

  • word_score: score added when a word finishes

  • unk_score: score added when an unknown word appears

  • sil_score: score added when silence appears

  • log_add: whether to use log add for lexicon trie smearing
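The following simplified sketch shows where word_score, unk_score, and sil_score enter the score. It assumes a hypothesis can be summarized as a sequence of completed-word, unknown-word, and silence events. `score_events` and this event encoding are hypothetical. The real decoder applies these scores incrementally during the search rather than after the fact.

```python
import math
from typing import List, Tuple

# Each event is (kind, lm_logprob); kind is "word", "unk", or "sil".
Event = Tuple[str, float]


def score_events(
    events: List[Event],
    lm_weight: float,
    word_score: float,
    unk_score: float,
    sil_score: float,
) -> float:
    """Simplified accumulation of the per-event scores on top of the weighted LM score."""
    total = 0.0
    for kind, lm_logprob in events:
        if kind == "word":
            # A lexicon word finished: weighted LM score plus the word bonus.
            total += lm_weight * lm_logprob + word_score
        elif kind == "unk":
            # A word outside the lexicon was hypothesized (simplified: LM term omitted).
            total += unk_score
        elif kind == "sil":
            # A silence token was emitted.
            total += sil_score
        else:
            raise ValueError(f"unknown event kind: {kind}")
    return total


events = [("word", -2.0), ("sil", 0.0), ("word", -3.0)]
total = score_events(events, lm_weight=2.0, word_score=0.5, unk_score=-math.inf, sil_score=-0.1)

# An unk_score of -inf rules out any hypothesis that contains an unknown word.
with_unk = score_events(
    events + [("unk", 0.0)], lm_weight=2.0, word_score=0.5, unk_score=-math.inf, sil_score=-0.1
)
```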

Total running time of the script: (3 minutes 5.586 seconds)
