ASR Inference with CTC Decoder
Author: Caroline Chen
This tutorial shows how to perform speech recognition inference using a CTC beam search decoder with lexicon constraint and KenLM language model support. We demonstrate this on a pretrained wav2vec 2.0 model trained using CTC loss.
Overview
Beam search decoding works by iteratively expanding text hypotheses (beams) with next possible characters, and maintaining only the hypotheses with the highest scores at each time step. A language model can be incorporated into the scoring computation, and adding a lexicon constraint restricts the next possible tokens for the hypotheses so that only words from the lexicon can be generated.
The underlying implementation is ported from Flashlight’s beam search decoder. A mathematical formula for the decoder optimization can be found in the Wav2Letter paper, and a more detailed algorithm can be found in this blog.
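Here is a minimal pure-Python sketch of CTC prefix beam search on a toy three-frame emission matrix. It illustrates the beam idea under simplifying assumptions and is not the Flashlight decoder that torchaudio wraps. It takes per-frame log probabilities as input and uses no lexicon and no language model. The helper name `ctc_prefix_beam_search` is hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, Sequence, Tuple

NEG_INF = float("-inf")


def logsumexp(*xs: float) -> float:
    """Numerically stable log(sum(exp(x))) that tolerates -inf inputs."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_prefix_beam_search(
    log_probs: Sequence[Sequence[float]],
    labels: Sequence[str],
    beam_size: int = 3,
    blank: int = 0,
) -> Tuple[str, float]:
    """Hypothetical helper: decode [num_frames][num_labels] log probabilities."""
    # Each prefix (tuple of label ids) keeps two log scores:
    # paths ending in blank, and paths ending in a non-blank label.
    beams: Dict[Tuple[int, ...], Tuple[float, float]] = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        nxt: Dict[Tuple[int, ...], Tuple[float, float]] = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beams.items():
            for c, lp in enumerate(frame):
                if c == blank:
                    # Blank keeps the prefix unchanged; the path now ends in blank
                    b, nb = nxt[prefix]
                    nxt[prefix] = (logsumexp(b, p_b + lp, p_nb + lp), nb)
                    continue
                extended = prefix + (c,)
                b, nb = nxt[extended]
                if prefix and prefix[-1] == c:
                    # A repeated label only starts a new symbol after a blank
                    nxt[extended] = (b, logsumexp(nb, p_b + lp))
                    # Otherwise the repeat collapses into the same prefix
                    s_b, s_nb = nxt[prefix]
                    nxt[prefix] = (s_b, logsumexp(s_nb, p_nb + lp))
                else:
                    nxt[extended] = (b, logsumexp(nb, p_b + lp, p_nb + lp))
        # Prune: keep only the beam_size highest-scoring prefixes
        ranked = sorted(nxt.items(), key=lambda kv: logsumexp(*kv[1]), reverse=True)
        beams = dict(ranked[:beam_size])
    best_prefix, (p_b, p_nb) = max(beams.items(), key=lambda kv: logsumexp(*kv[1]))
    return "".join(labels[i] for i in best_prefix), logsumexp(p_b, p_nb)


# Toy emissions over labels ["-", "a", "b"] for three frames
toy_probs = [[0.1, 0.8, 0.1], [0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
toy_log_probs = [[math.log(p) for p in frame] for frame in toy_probs]
toy_text, toy_score = ctc_prefix_beam_search(toy_log_probs, ["-", "a", "b"], beam_size=10)
print(toy_text, round(math.exp(toy_score), 3))  # ab 0.656
```

Each prefix keeps separate scores for paths ending in blank and paths ending in a non-blank label. This is what allows a doubled character such as the "ll" in "really" to survive CTC's rule that collapses repeated labels.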
Running ASR inference using a CTC beam search decoder with a KenLM language model and lexicon constraint requires the following components:
Acoustic Model: model predicting phonetics from audio waveforms
Tokens: the possible predicted tokens from the acoustic model
Lexicon: mapping between possible words and their corresponding token sequences
KenLM: n-gram language model trained with the KenLM library
Preparation
First, we import the necessary utilities and fetch the data that we are working with.
import time
from typing import List
import IPython
import matplotlib.pyplot as plt
import torch
import torchaudio
try:
    from torchaudio.models.decoder import ctc_decoder
except ModuleNotFoundError:
    try:
        import google.colab
        print(
            """
To enable running this notebook in Google Colab, install nightly
torch and torchaudio builds by adding the following code block to the top
of the notebook before running it:
!pip3 uninstall -y torch torchvision torchaudio
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
"""
        )
    except ModuleNotFoundError:
        pass
    raise
Acoustic Model and Data
We use the pretrained Wav2Vec 2.0 Base model fine-tuned on 10 minutes of LibriSpeech data, which can be loaded using torchaudio.pipelines. For more detail on running Wav2Vec 2.0 speech recognition pipelines in torchaudio, please refer to this tutorial.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()
Out:
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth
We will load a sample from the LibriSpeech test-other dataset.
hub_dir = torch.hub.get_dir()
speech_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/1688-142285-0007.wav"
speech_file = f"{hub_dir}/speech.wav"
torch.hub.download_url_to_file(speech_url, speech_file)
IPython.display.Audio(speech_file)
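The transcript corresponding to this audio file is “i really was very much afraid of showing him how much shocked i was at some parts of what he said”. We load the waveform and, if necessary, resample it to the sample rate the acoustic model expects.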
waveform, sample_rate = torchaudio.load(speech_file)
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
Files and Data for Decoder
Next, we load the token, lexicon, and KenLM files, which the decoder uses to predict words from the acoustic model output. Pretrained files for the LibriSpeech dataset can be downloaded through torchaudio, or users can provide their own files.
Tokens
The tokens are the possible symbols that the acoustic model can predict, including the blank and silence symbols. They can be passed in either as a file, where the token on each line maps to the index of that line, or as a list of tokens, each mapping to a unique index.
# tokens.txt
_
|
e
t
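...
The tokens for the acoustic model used in this tutorial can be taken from the pipeline bundle:
tokens = [label.lower() for label in bundle.get_labels()]
print(tokens)
Out: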
['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']
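The file layout described above, where the token index equals the line number, is easy to reproduce. The hypothetical helpers below are a standard-library sketch, not torchaudio's own loader. They write a token list to disk and read it back into index-to-token and token-to-index lookups.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple


def write_tokens(tokens: List[str], path: Path) -> None:
    # Line i of the file holds the token whose index is i
    path.write_text("\n".join(tokens) + "\n")


def read_tokens(path: Path) -> Tuple[List[str], Dict[str, int]]:
    # Skip empty lines; every remaining line is one token
    idx_to_token = [line for line in path.read_text().splitlines() if line]
    token_to_idx = {tok: i for i, tok in enumerate(idx_to_token)}
    if len(token_to_idx) != len(idx_to_token):
        raise ValueError("tokens must be unique")
    return idx_to_token, token_to_idx


with tempfile.TemporaryDirectory() as tmp:
    tokens_path = Path(tmp) / "tokens.txt"
    write_tokens(["-", "|", "e", "t", "a"], tokens_path)
    idx_to_token, token_to_idx = read_tokens(tokens_path)

print(idx_to_token[2], token_to_idx["|"])  # e 1
```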
Lexicon
The lexicon is a mapping from words to their corresponding token sequences, and is used to restrict the search space of the decoder to only words from the lexicon. The expected format of the lexicon file is one line per word, with a word followed by its space-split tokens.
# lexicon.txt
a a |
able a b l e |
about a b o u t |
...
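A lexicon in this format can be parsed in a few lines of Python. The sketch below uses a hypothetical helper that is not part of torchaudio. It maps each word to its spellings and rejects spellings that use tokens missing from the token set. For illustration, it also assumes a word may appear on several lines with different spellings.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set


def parse_lexicon(lines: Iterable[str], token_set: Set[str]) -> Dict[str, List[List[str]]]:
    """Hypothetical helper: map each word to its list of token spellings."""
    lexicon: Dict[str, List[List[str]]] = defaultdict(list)
    for line in lines:
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        word, spelling = fields[0], fields[1:]
        unknown = [t for t in spelling if t not in token_set]
        if unknown:
            raise ValueError(f"{word!r} uses unknown tokens {unknown}")
        # This sketch allows the same word on several lines (multiple spellings)
        lexicon[word].append(spelling)
    return dict(lexicon)


toy_tokens = ["-", "|", "e", "t", "a", "o", "u", "b", "l"]
toy_lexicon = parse_lexicon(["a a |", "able a b l e |", "about a b o u t |"], set(toy_tokens))
print(toy_lexicon["able"])  # [['a', 'b', 'l', 'e', '|']]
```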
KenLM
This is an n-gram language model trained with the KenLM library. Both the .arpa and the binarized .bin LM can be used, but the binary format is recommended for faster loading.
The language model used in this tutorial is a 4-gram KenLM trained using LibriSpeech.
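The sketch below gives some intuition for what the language model contributes. It parses a tiny, hand-written bigram model in the text ARPA format, which stores log10 probabilities plus backoff weights, and scores a sentence with standard backoff. The numbers are invented, only bigrams are handled, and unknown words are not supported. KenLM's own implementation is far more elaborate.

```python
from typing import Dict, List, Tuple

# A tiny, made-up bigram model in ARPA format (log10 probabilities)
TOY_ARPA = """\\data\\
ngram 1=4
ngram 2=2

\\1-grams:
-99.0\t<s>\t-0.3
-0.5\ti\t-0.2
-0.7\twas\t-0.1
-1.0\t</s>

\\2-grams:
-0.1\t<s> i
-0.4\ti was

\\end\\
"""

NGram = Tuple[str, ...]


def parse_arpa(text: str) -> Tuple[Dict[NGram, float], Dict[NGram, float]]:
    probs: Dict[NGram, float] = {}
    backoffs: Dict[NGram, float] = {}
    order = 0
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("ngram ") or line in ("\\data\\", "\\end\\"):
            continue
        if line.endswith("-grams:"):
            order = int(line[1 : line.index("-")])  # "\2-grams:" -> 2
            continue
        fields = line.split()
        ngram = tuple(fields[1 : 1 + order])
        probs[ngram] = float(fields[0])
        if len(fields) > 1 + order:
            backoffs[ngram] = float(fields[1 + order])
    return probs, backoffs


def sentence_log10_prob(words: List[str], probs: Dict[NGram, float], backoffs: Dict[NGram, float]) -> float:
    total, history = 0.0, "<s>"
    for word in words + ["</s>"]:
        if (history, word) in probs:
            total += probs[(history, word)]  # bigram seen in training
        else:
            # Back off: history's backoff weight plus the unigram probability
            total += backoffs.get((history,), 0.0) + probs[(word,)]
        history = word
    return total


arpa_probs, arpa_backoffs = parse_arpa(TOY_ARPA)
print(round(sentence_log10_prob(["i", "was"], arpa_probs, arpa_backoffs), 2))  # -1.6
```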
Downloading Pretrained Files
Pretrained files for the LibriSpeech dataset can be downloaded using download_pretrained_files.
Note: this cell may take a couple of minutes to run, as the language model file is large (about 2.9 GB).
from torchaudio.models.decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
print(files)
Out:
PretrainedFiles(lexicon='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt', tokens='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt', lm='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin')
Construct Decoders
In this tutorial, we construct both a beam search decoder and a greedy decoder for comparison.
Beam Search Decoder
The decoder can be constructed using the factory function ctc_decoder. In addition to the previously mentioned components, it also takes in various beam search decoding parameters and token/word parameters.
This decoder can also be run without a language model by passing None to the lm parameter.
from torchaudio.models.decoder import ctc_decoder
LM_WEIGHT = 3.23
WORD_SCORE = -0.26
beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)
Greedy Decoder
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank
    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
        Returns:
          List[str]: The resulting transcript, as a list of words
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        # Collapse consecutive repeats, then drop blank tokens
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        # "|" marks word boundaries
        return joined.replace("|", " ").strip().split()
greedy_decoder = GreedyCTCDecoder(tokens)
Run Inference
Now that we have the data, acoustic model, and decoder, we can perform inference. For each utterance in the batch, the beam search decoder returns a list of the best hypotheses, each of type torchaudio.models.decoder.CTCHypothesis, consisting of the predicted token IDs, corresponding words, hypothesis score, and timesteps corresponding to the token IDs. Recall the transcript corresponding to the waveform is
actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = actual_transcript.split()
emission, _ = acoustic_model(waveform)
The greedy decoder gives the following result.
greedy_result = greedy_decoder(emission[0])
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}")
Out:
Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
WER: 0.38095238095238093
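The WER above is the word-level edit distance divided by the number of words in the reference. The plain Levenshtein sketch below performs that computation with a hypothetical helper, not torchaudio's implementation. On the greedy result it finds 8 errors out of 21 reference words.

```python
from typing import Sequence


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance over words (substitutions, insertions, deletions)."""
    # prev[j] holds the distance between the current reference prefix and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,  # delete r
                cur[j - 1] + 1,  # insert h
                prev[j - 1] + (r != h),  # substitute (free when the words match)
            )
        prev = cur
    return prev[-1]


def wer(ref: str, hyp: str) -> float:
    ref_words = ref.split()
    return word_edit_distance(ref_words, hyp.split()) / len(ref_words)


reference = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
greedy_hyp = "i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid"
print(wer(reference, greedy_hyp))  # 0.38095238095238093
```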
Using the beam search decoder:
beam_search_result = beam_search_decoder(emission)
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
actual_transcript
)
print(f"Transcript: {beam_search_transcript}")
print(f"WER: {beam_search_wer}")
Out:
Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
WER: 0.047619047619047616
We see that the lexicon-constrained beam search decoder produces a more accurate transcript consisting of real words (a WER of about 0.05 versus 0.38), while the greedy decoder can predict incorrectly spelled words like “affrayd” and “shoktd”.
Timestep Alignments
Recall that each resulting hypothesis also contains the timesteps corresponding to its token IDs.
timesteps = beam_search_result[0][0].timesteps
predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)
print(predicted_tokens, len(predicted_tokens))
print(timesteps, timesteps.shape[0])
Out:
['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
tensor([ 0, 31, 33, 36, 39, 41, 42, 44, 46, 48, 49, 52, 54, 58,
64, 66, 69, 73, 74, 76, 80, 82, 84, 86, 88, 94, 97, 107,
111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157,
159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191,
193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250,
251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289,
292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331,
353], dtype=torch.int32) 99
Below, we visualize the token timestep alignments relative to the original waveform.
def plot_alignments(waveform, emission, tokens, timesteps):
    fig, ax = plt.subplots(figsize=(32, 10))
    ax.plot(waveform)
    ratio = waveform.shape[0] / emission.shape[1]
    word_start = 0
    for i in range(len(tokens)):
        if i != 0 and tokens[i - 1] == "|":
            word_start = timesteps[i]
        if tokens[i] != "|":
            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
        elif i != 0:
            word_end = timesteps[i]
            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
    xticks = ax.get_xticks()
    plt.xticks(xticks, xticks / bundle.sample_rate)
    ax.set_xlabel("time (sec)")
    ax.set_xlim(0, waveform.shape[0])
plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
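The samples-per-frame ratio used in plot_alignments can also turn token timesteps into word time spans in seconds. The hypothetical helper below assumes words are delimited by "|". To match the shaded spans in the plot, it also treats a word as starting at its first token and ending at the timestep of the following "|".

```python
from typing import List, Sequence, Tuple


def word_segments(
    tokens: Sequence[str],
    timesteps: Sequence[int],
    num_samples: int,
    num_frames: int,
    sample_rate: int,
) -> List[Tuple[str, float, float]]:
    """Return (word, start_sec, end_sec) for each "|"-delimited word."""
    ratio = num_samples / num_frames  # audio samples per emission frame
    segments: List[Tuple[str, float, float]] = []
    chars: List[str] = []
    start = 0
    for tok, t in zip(tokens, timesteps):
        if tok == "|":
            if chars:  # a "|" closes the current word, if any
                segments.append(("".join(chars), start * ratio / sample_rate, t * ratio / sample_rate))
                chars = []
            continue
        if not chars:
            start = t  # first token of a new word
        chars.append(tok)
    return segments


# Toy example: 50 frames covering 1 second of 16 kHz audio (320 samples per frame)
print(word_segments(["|", "h", "i", "|"], [0, 2, 3, 5], 16000, 50, 16000))
```

With the variables from this tutorial, it could be called as word_segments(predicted_tokens, timesteps.tolist(), waveform.shape[1], emission.shape[1], bundle.sample_rate).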
Beam Search Decoder Parameters
In this section, we go a little more in depth about some different parameters and tradeoffs. For the full list of customizable parameters, please refer to the documentation.
Helper Function
def print_decoded(decoder, emission, param, param_value):
    start_time = time.monotonic()
    result = decoder(emission)
    decode_time = time.monotonic() - start_time
    transcript = " ".join(result[0][0].words).lower().strip()
    score = result[0][0].score
    print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")
nbest
This parameter indicates the number of best hypotheses to return, something the greedy decoder cannot provide. For instance, by setting nbest=3 when constructing the beam search decoder earlier, we can now access the hypotheses with the top 3 scores.
for i in range(3):
    transcript = " ".join(beam_search_result[0][i].words).strip()
    score = beam_search_result[0][i].score
    print(f"{transcript} (score: {score})")
Out:
i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.8238231825794)
i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.8580900895563)
i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.015467226502)
beam size
The beam_size parameter determines the maximum number of best hypotheses to hold after each decoding step. Using larger beam sizes allows exploring a larger range of possible hypotheses, which can produce hypotheses with higher scores, but it is computationally more expensive and does not provide additional gains beyond a certain point.
In the example below, we see improvement in decoding quality as we increase the beam size from 1 to 5 to 50, but notice how a beam size of 500 produces the same output as beam size 50 while increasing the computation time.
beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size=beam_size,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "beam size", beam_size)
Out:
beam size 1 : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.2201 secs)
beam size 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0646 secs)
beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2912 secs)
beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.7506 secs)
beam size token
The beam_size_token parameter corresponds to the number of tokens to consider for expanding each hypothesis at the decoding step. Exploring a larger number of next possible tokens increases the range of potential hypotheses at the cost of computation.
num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size_token=beam_size_token,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)
Out:
beam size token 1 : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.3286 secs)
beam size token 5 : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.2777 secs)
beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2314 secs)
beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4088 secs)
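Conceptually, beam_size_token limits how many next tokens are tried for each hypothesis at each frame. The simplified sketch below shows that per-frame selection by keeping the k highest-scoring token indices. The helper and the frame scores are hypothetical, and the real decoder's candidate handling may differ in its details.

```python
import heapq
from typing import List, Sequence, Tuple


def top_tokens(frame_scores: Sequence[float], beam_size_token: int) -> List[Tuple[int, float]]:
    """Keep the beam_size_token highest-scoring (token_index, score) pairs for one frame."""
    return heapq.nlargest(beam_size_token, enumerate(frame_scores), key=lambda pair: pair[1])


toy_frame = [-0.2, -3.1, -1.5, -0.9, -4.0]  # made-up per-token log scores
print(top_tokens(toy_frame, 2))  # [(0, -0.2), (3, -0.9)]
```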
beam threshold
The beam_threshold parameter is used to prune the set of stored hypotheses at each decoding step, removing hypotheses whose scores are more than beam_threshold below the highest-scoring hypothesis. There is a balance between choosing smaller thresholds to prune more hypotheses and reduce the search space, and choosing a threshold large enough that plausible hypotheses are not pruned.
beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_threshold=beam_threshold,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)
Out:
beam threshold 1 : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0337 secs)
beam threshold 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0850 secs)
beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3094 secs)
beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2865 secs)
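The two pruning rules can be sketched together. The sketch first drops hypotheses that fall more than beam_threshold below the best score, then keeps at most beam_size of the survivors. The helper and the scores are made up for illustration and are not taken from the decoder.

```python
from typing import Dict, List, Tuple


def prune(hyps: Dict[str, float], beam_size: int, beam_threshold: float) -> List[Tuple[str, float]]:
    """Drop hypotheses more than beam_threshold below the best, then keep the top beam_size."""
    best = max(hyps.values())
    survivors = [(h, s) for h, s in hyps.items() if s >= best - beam_threshold]
    survivors.sort(key=lambda hs: hs[1], reverse=True)
    return survivors[:beam_size]


toy_hyps = {"i really": -4.0, "i rely": -6.5, "i reily": -12.0, "i you": -20.0}
print(prune(toy_hyps, beam_size=3, beam_threshold=5))  # [('i really', -4.0), ('i rely', -6.5)]
print(prune(toy_hyps, beam_size=3, beam_threshold=25))
```

With the tight threshold, only two hypotheses survive even though the beam has room for three. With the loose threshold, beam_size becomes the limiting rule.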
language model weight
The lm_weight parameter is the weight assigned to the language model score, which is accumulated with the acoustic model score to determine the overall scores. Larger weights encourage the model to predict next words based on the language model, while smaller weights give more weight to the acoustic model score instead.
lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        lm_weight=lm_weight,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)
Out:
lm weight 0 : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.3061 secs)
lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3269 secs)
lm weight 15 : was there in his was at some of what he said (score: 2918.98; 0.3175 secs)
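The sketch below gives a simplified view of why a large weight can favor short transcripts. It assumes each hypothesis score combines the acoustic score, lm_weight times the language model score, and word_score times the number of words, following the general form of the Wav2Letter decoder objective. It ignores silence and unknown-word terms and uses made-up scores. Every extra word lowers the LM log probability, so raising lm_weight eventually lets a shorter candidate win, much like the lm weight 15 output above.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    text: str
    am_score: float  # acoustic log score (made-up number)
    lm_score: float  # language model log probability (made-up number)

    @property
    def num_words(self) -> int:
        return len(self.text.split())


def total_score(c: Candidate, lm_weight: float, word_score: float) -> float:
    # Assumed simplified objective: AM + lm_weight * LM + word_score * |words|
    return c.am_score + lm_weight * c.lm_score + word_score * c.num_words


toy_candidates = [
    Candidate("i really was very much afraid", am_score=-10.0, lm_score=-12.0),
    Candidate("was very afraid", am_score=-40.0, lm_score=-6.0),
]
for weight in (0.0, 3.23, 15.0):
    best = max(toy_candidates, key=lambda c: total_score(c, weight, word_score=-0.26))
    print(weight, best.text)
# 0.0 and 3.23 pick the longer candidate; 15.0 picks "was very afraid"
```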
additional parameters
Additional parameters that can be optimized include the following:
word_score: score to add when a word finishes
unk_score: unknown word appearance score to add
sil_score: silence appearance score to add
log_add: whether to use log add for lexicon Trie smearing