ASR Inference with CTC Decoder¶
Author: Caroline Chen
This tutorial shows how to perform speech recognition inference using a CTC beam search decoder with lexicon constraint and KenLM language model support. We demonstrate this on a pretrained wav2vec 2.0 model trained using CTC loss.
Overview¶
Beam search decoding works by iteratively expanding text hypotheses (beams) with next possible characters, and maintaining only the hypotheses with the highest scores at each time step. A language model can be incorporated into the scoring computation, and adding a lexicon constraint restricts the next possible tokens for the hypotheses so that only words from the lexicon can be generated.
The underlying implementation is ported from Flashlight’s beam search decoder. A mathematical formula for the decoder optimization can be found in the Wav2Letter paper, and a more detailed algorithm can be found in this blog.
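To make the idea of maintaining beams concrete, here is a minimal CTC prefix beam search in pure Python. It is a simplified textbook sketch, not Flashlight's or torchaudio's implementation. It assumes `probs` is a list of per-frame probability lists that are already softmax-normalized, with index 0 as the blank. It uses no lexicon and no language model, and it works in plain probabilities rather than log scores for readability. After each frame it keeps only the `beam_size` most probable prefixes.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Prefix = Tuple[int, ...]


def ctc_prefix_beam_search(
    probs: Sequence[Sequence[float]], beam_size: int = 3, blank: int = 0
) -> List[Tuple[Prefix, float]]:
    """Return (prefix, probability) pairs, best first.

    probs[t][k] is the probability of token k at frame t. Each prefix keeps two
    probabilities: p_b (alignments ending in blank) and p_nb (ending in a token).
    """
    beams: Dict[Prefix, Tuple[float, float]] = {(): (1.0, 0.0)}
    for frame in probs:
        next_beams: Dict[Prefix, List[float]] = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for token, p in enumerate(frame):
                if token == blank:
                    # A blank leaves the prefix unchanged; it now ends in blank.
                    next_beams[prefix][0] += (p_b + p_nb) * p
                    continue
                extended = prefix + (token,)
                if prefix and token == prefix[-1]:
                    # Repeating the last token without a blank in between collapses into it...
                    next_beams[prefix][1] += p_nb * p
                    # ...whereas a blank in between starts a new occurrence.
                    next_beams[extended][1] += p_b * p
                else:
                    next_beams[extended][1] += (p_b + p_nb) * p
        # Prune: keep only the beam_size most probable prefixes.
        ranked = sorted(next_beams.items(), key=lambda kv: kv[1][0] + kv[1][1], reverse=True)
        beams = {prefix: (p_b, p_nb) for prefix, (p_b, p_nb) in ranked[:beam_size]}
    final = [(prefix, p_b + p_nb) for prefix, (p_b, p_nb) in beams.items()]
    return sorted(final, key=lambda item: item[1], reverse=True)


# Two frames over {blank=0, "a"=1}: greedy decoding picks blank twice (empty output).
print(ctc_prefix_beam_search([[0.6, 0.4], [0.6, 0.4]]))
```

In this toy input, greedy decoding outputs nothing. The prefix beam search instead sums the three alignments of "a" (0.16 + 0.24 + 0.24 = 0.64) and ranks that prefix above the empty one (0.36). This is the kind of gain that scoring whole hypotheses provides.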
Running ASR inference using a CTC beam search decoder with a language model and lexicon constraint requires the following components:
Acoustic Model: model predicting phonetics from audio waveforms
Tokens: the possible predicted tokens from the acoustic model
Lexicon: mapping between possible words and their corresponding token sequences
Language Model (LM): n-gram language model trained with the KenLM library, or a custom language model that inherits CTCDecoderLM
Acoustic Model and Set Up¶
First, we import the necessary utilities and fetch the data that we are working with.
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
1.13.0
0.13.0
import time
from typing import List
import IPython
import matplotlib.pyplot as plt
from torchaudio.models.decoder import ctc_decoder
from torchaudio.utils import download_asset
We use the Wav2Vec 2.0 Base model, pretrained on unlabeled LibriSpeech audio and finetuned for ASR on 10 minutes of transcribed Libri-Light audio. It can be loaded using torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M.
For more detail on running Wav2Vec 2.0 speech recognition pipelines in torchaudio, please refer to this tutorial.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth
We will load a sample from the LibriSpeech test-other dataset.
speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")
IPython.display.Audio(speech_file)
The transcript corresponding to this audio file is:
i really was very much afraid of showing him how much shocked i was at some parts of what he said
waveform, sample_rate = torchaudio.load(speech_file)

# Resample to the rate the acoustic model expects, if necessary
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
Files and Data for Decoder¶
Next, we load in our token, lexicon, and language model data, which are used by the decoder to predict words from the acoustic model output. Pretrained files for the LibriSpeech dataset can be downloaded through torchaudio, or the user can provide their own files.
Tokens¶
The tokens are the possible symbols that the acoustic model can predict, including the blank and silent symbols. They can be passed in either as a file, where each line contains the token for the corresponding index, or as a list of tokens, each mapping to a unique index.
# tokens.txt
-
|
e
t
...
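In this tutorial, we build the token list from the labels of the acoustic model bundle:
tokens = [label.lower() for label in bundle.get_labels()]
print(tokens)
['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']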
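The sketch below illustrates the file format described above. It reads tokens.txt-style lines into an index-to-token list and the reverse mapping, so that the token on line i gets index i. `load_tokens` is a hypothetical helper, not part of torchaudio; the decoder parses the file itself.

```python
from typing import Dict, Iterable, List, Tuple


def load_tokens(lines: Iterable[str]) -> Tuple[List[str], Dict[str, int]]:
    """Parse tokens.txt-style lines: the token on line i gets index i."""
    idx_to_token = [line.strip() for line in lines if line.strip()]
    token_to_idx = {tok: i for i, tok in enumerate(idx_to_token)}
    if len(token_to_idx) != len(idx_to_token):
        raise ValueError("duplicate token in token list")
    return idx_to_token, token_to_idx


# A few lines in the same format as tokens.txt (blank "-", word boundary "|").
idx_to_token, token_to_idx = load_tokens(["-", "|", "e", "t"])
print(token_to_idx["|"])  # 1
```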
Lexicon¶
The lexicon is a mapping from words to their corresponding token sequences. It is used to restrict the search space of the decoder to only words from the lexicon. The expected format of the lexicon file is one line per word, with the word followed by its space-separated tokens.
# lexicon.txt
a a |
able a b l e |
about a b o u t |
...
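One way to picture the lexicon constraint is as a prefix tree (trie) over the token spellings: at each step, only tokens that continue some lexicon word are allowed. The sketch below is a simplified, hypothetical pure-Python illustration. `parse_lexicon`, `build_trie`, and `allowed_next_tokens` are illustrative names. Flashlight's decoder builds and scores its own trie internally, and that trie is not shown here.

```python
from typing import Dict, List, Set


def parse_lexicon(lines: List[str]) -> Dict[str, List[str]]:
    """Each line: a word followed by its space-separated tokens."""
    lexicon = {}
    for line in lines:
        parts = line.split()
        if parts:
            lexicon[parts[0]] = parts[1:]
    return lexicon


def build_trie(lexicon: Dict[str, List[str]]) -> dict:
    """Nested dicts keyed by token; the key "<word>" marks a complete spelling."""
    root: dict = {}
    for word, spelling in lexicon.items():
        node = root
        for token in spelling:
            node = node.setdefault(token, {})
        node["<word>"] = word
    return root


def allowed_next_tokens(trie: dict, prefix: List[str]) -> Set[str]:
    """Tokens that extend `prefix` toward at least one lexicon word."""
    node = trie
    for token in prefix:
        if token not in node:
            return set()  # prefix does not start any lexicon word
        node = node[token]
    return {key for key in node if key != "<word>"}


lexicon = parse_lexicon(["a a |", "able a b l e |", "about a b o u t |"])
trie = build_trie(lexicon)
print(sorted(allowed_next_tokens(trie, ["a"])))       # ['b', '|']
print(sorted(allowed_next_tokens(trie, ["a", "b"])))  # ['l', 'o']
```

After the tokens "a b", only "l" (toward "able") or "o" (toward "about") may follow. A misspelling such as "affrayd" can therefore never be produced when decoding under the lexicon.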
Language Model¶
A language model can be used in decoding to improve the results by factoring a language model score, which represents the likelihood of the sequence, into the beam search computation. Below, we outline the different forms of language models that are supported for decoding.
No Language Model¶
To create a decoder instance without a language model, set lm=None when initializing the decoder.
KenLM¶
This is an n-gram language model trained with the KenLM library. Both the .arpa and the binarized .bin LM can be used, but the binary format is recommended for faster loading.
The language model used in this tutorial is a 4-gram KenLM trained using LibriSpeech.
Custom Language Model¶
Users can define their own custom language model in Python, whether it be a statistical or neural network language model, using CTCDecoderLM and CTCDecoderLMState.
For instance, the following code creates a basic wrapper around a PyTorch torch.nn.Module language model.
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState


class CustomLM(CTCDecoderLM):
    """Create a Python wrapper around `language_model` to feed to the decoder."""

    def __init__(self, language_model: torch.nn.Module):
        CTCDecoderLM.__init__(self)
        self.language_model = language_model
        self.sil = -1  # index for silent token in the language model
        self.states = {}  # cache: LM state -> score

        language_model.eval()

    def start(self, start_with_nothing: bool = False):
        # Create the initial (root) state and score the silence token from it
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.language_model(self.sil)

        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        # Move to the child state for token_index, computing its score only once
        outstate = state.child(token_index)
        if outstate not in self.states:
            with torch.no_grad():
                score = self.language_model(token_index)
            self.states[outstate] = score
        score = self.states[outstate]

        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        # Score the end of the sentence as a transition to the silence token
        return self.score(state, self.sil)
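The wrapper above relies on LM states: each hypothesis carries a state, and `score` returns the child state together with the score for the next unit. The sketch below shows the same start/score/finish pattern without any torchaudio dependency, using a hypothetical pure-Python bigram model whose state is simply the previous word. It does not subclass `CTCDecoderLM`. The class name, the add-one smoothing, and the `<s>`/`</s>` markers are all assumptions made for illustration.

```python
import math
from collections import Counter
from typing import List, Tuple


class ToyBigramLM:
    """Hypothetical bigram LM with add-one smoothing; the state is the last word."""

    def __init__(self, sentences: List[List[str]]):
        self.unigrams: Counter = Counter()  # how often each word is used as context
        self.bigrams: Counter = Counter()
        vocab = {"</s>"}
        for words in sentences:
            padded = ["<s>"] + words + ["</s>"]
            vocab.update(words)
            self.unigrams.update(padded[:-1])
            self.bigrams.update(zip(padded[:-1], padded[1:]))
        self.vocab_size = len(vocab)  # words that can be predicted, including </s>

    def start(self) -> str:
        return "<s>"  # initial state: beginning of sentence

    def score(self, state: str, word: str) -> Tuple[str, float]:
        # log P(word | state) with add-one smoothing; the child state is `word`.
        numerator = self.bigrams[(state, word)] + 1
        denominator = self.unigrams[state] + self.vocab_size
        return word, math.log(numerator / denominator)

    def finish(self, state: str) -> Tuple[str, float]:
        return self.score(state, "</s>")


def sentence_logprob(lm: ToyBigramLM, words: List[str]) -> float:
    """Walk the states exactly as a decoder would: start, score each word, finish."""
    state, total = lm.start(), 0.0
    for word in words:
        state, s = lm.score(state, word)
        total += s
    _, s = lm.finish(state)
    return total + s


lm = ToyBigramLM([["how", "much"], ["how", "much", "shocked"], ["so", "much"]])
print(sentence_logprob(lm, ["how", "much"]), sentence_logprob(lm, ["much", "how"]))
```

The familiar order "how much" scores higher than "much how". This per-step preference is what a language model contributes to the beam search scores.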
Downloading Pretrained Files¶
Pretrained files for the LibriSpeech dataset can be downloaded using download_pretrained_files().
Note: this cell may take a couple of minutes to run, as the language model is large (about 2.9 GB).
from torchaudio.models.decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
print(files)
PretrainedFiles(lexicon='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt', tokens='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt', lm='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin')
Construct Decoders¶
In this tutorial, we construct both a beam search decoder and a greedy decoder for comparison.
Beam Search Decoder¶
The decoder can be constructed using the factory function ctc_decoder(). In addition to the previously mentioned components, it also takes in various beam search decoding parameters and token/word parameters.
This decoder can also be run without a language model by passing None to the lm parameter.
LM_WEIGHT = 3.23
WORD_SCORE = -0.26
beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)
Greedy Decoder¶
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path.

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The resulting transcript, split into words
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)  # merge repeated labels
        indices = [i for i in indices if i != self.blank]  # drop blanks
        joined = "".join([self.labels[i] for i in indices])
        return joined.replace("|", " ").strip().split()


greedy_decoder = GreedyCTCDecoder(tokens)
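The core of the greedy decoder is the CTC collapse rule: take the best label per frame, merge consecutive repeats, then drop blanks. The framework-free sketch below applies the same steps to a list of per-frame label indices. `ctc_collapse` and the five-label alphabet are hypothetical. The example shows why a blank is needed between two identical letters, as in the double "l" of "really".

```python
from itertools import groupby
from typing import List, Sequence


def ctc_collapse(frame_labels: Sequence[int], labels: Sequence[str], blank: int = 0) -> List[str]:
    """Merge repeated frame labels, drop blanks, and split on the '|' word boundary."""
    merged = [index for index, _ in groupby(frame_labels)]  # like torch.unique_consecutive
    text = "".join(labels[i] for i in merged if i != blank)
    return text.replace("|", " ").strip().split()


LABELS = ["-", "|", "a", "l", "y"]  # hypothetical alphabet: blank, boundary, letters
print(ctc_collapse([2, 2, 3, 3, 0, 3, 4, 1], LABELS))  # ['ally'] (blank separates the l's)
print(ctc_collapse([2, 2, 3, 3, 3, 4], LABELS))        # ['aly'] (repeats merge)
```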
Run Inference¶
Now that we have the data, acoustic model, and decoder, we can perform inference. For each utterance in the batch, the beam search decoder returns a list of CTCHypothesis objects (the n best hypotheses). Each hypothesis consists of the predicted token IDs, the corresponding words (if a lexicon is provided), the hypothesis score, and the timesteps corresponding to the token IDs. Recall that the transcript corresponding to the waveform is
actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = actual_transcript.split()
emission, _ = acoustic_model(waveform)
The greedy decoder gives the following result.
greedy_result = greedy_decoder(emission[0])
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}")
Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
WER: 0.38095238095238093
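The WER printed above is the word-level edit distance (substitutions, insertions, and deletions) divided by the number of reference words. As a cross-check, the plain dynamic-programming sketch below (`word_error_rate` is a hypothetical name, not a torchaudio function) reproduces the greedy result: 8 edits over 21 reference words.

```python
from typing import Sequence


def word_error_rate(reference: Sequence[str], hypothesis: Sequence[str]) -> float:
    """Levenshtein distance over words, normalized by the reference length."""
    # prev[j]: edit distance between the reference words seen so far and hypothesis[:j]
    prev = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        curr = [i] + [0] * len(hypothesis)
        for j, hyp_word in enumerate(hypothesis, start=1):
            curr[j] = min(
                prev[j] + 1,  # deletion: reference word missing from the hypothesis
                curr[j - 1] + 1,  # insertion: extra hypothesis word
                prev[j - 1] + (ref_word != hyp_word),  # substitution, or free match
            )
        prev = curr
    return prev[-1] / len(reference)


reference = "i really was very much afraid of showing him how much shocked i was at some parts of what he said".split()
greedy = "i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid".split()
print(word_error_rate(reference, greedy))  # 0.38095238095238093
```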
Using the beam search decoder:
beam_search_result = beam_search_decoder(emission)
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
actual_transcript
)
print(f"Transcript: {beam_search_transcript}")
print(f"WER: {beam_search_wer}")
Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
WER: 0.047619047619047616
Note
The words field of the output hypotheses will be empty if no lexicon is provided to the decoder. To retrieve a transcript with lexicon-free decoding, you can do the following: retrieve the token indices, convert them to the original tokens, and join them together.
tokens_str = "".join(beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens))
transcript = " ".join(tokens_str.split("|"))
We see that the lexicon-constrained beam search decoder produces a more accurate transcript consisting of real words, while the greedy decoder can predict misspelled words like “affrayd” and “shoktd”.
Timestep Alignments¶
Recall that each resulting hypothesis also contains the timesteps corresponding to its token IDs.
timesteps = beam_search_result[0][0].timesteps
predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)
print(predicted_tokens, len(predicted_tokens))
print(timesteps, timesteps.shape[0])
['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
tensor([ 0, 31, 33, 36, 39, 41, 42, 44, 46, 48, 49, 52, 54, 58,
64, 66, 69, 73, 74, 76, 80, 82, 84, 86, 88, 94, 97, 107,
111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157,
159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191,
193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250,
251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289,
292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331,
353], dtype=torch.int32) 99
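Timesteps are emission frame indices, not seconds. The sketch below turns them into word-level start and end times. It follows the same convention as the plot_alignments function that comes next: a word starts at the timestep of its first token and ends at the timestep of the following "|". `word_segments` is a hypothetical helper. `seconds_per_frame` is an input you must supply, for example `waveform.shape[1] / emission.shape[1] / bundle.sample_rate`. The example uses 0.02 s, which assumes wav2vec 2.0's 320-sample frame stride at 16 kHz.

```python
from typing import List, Sequence, Tuple


def word_segments(
    tokens: Sequence[str], timesteps: Sequence[int], seconds_per_frame: float
) -> List[Tuple[str, float, float]]:
    """Group tokens into words and convert frame indices to (word, start_sec, end_sec)."""
    segments = []
    chars: List[str] = []
    start = 0
    for token, t in zip(tokens, timesteps):
        if token == "|":
            if chars:  # a word ends at this boundary token
                segments.append(("".join(chars), start * seconds_per_frame, t * seconds_per_frame))
                chars = []
        else:
            if not chars:  # first token of a new word
                start = t
            chars.append(token)
    return segments


# First ten tokens and timesteps from the output above.
tokens_head = ["|", "i", "|", "r", "e", "a", "l", "l", "y", "|"]
timesteps_head = [0, 31, 33, 36, 39, 41, 42, 44, 46, 48]
print(word_segments(tokens_head, timesteps_head, seconds_per_frame=0.02))
```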
Below, we visualize the token timestep alignments relative to the original waveform.
def plot_alignments(waveform, emission, tokens, timesteps):
    fig, ax = plt.subplots(figsize=(32, 10))

    ax.plot(waveform)

    ratio = waveform.shape[0] / emission.shape[1]  # audio samples per emission frame
    word_start = 0

    for i in range(len(tokens)):
        if i != 0 and tokens[i - 1] == "|":
            word_start = timesteps[i]
        if tokens[i] != "|":
            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
        elif i != 0:
            word_end = timesteps[i]
            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")

    xticks = ax.get_xticks()
    plt.xticks(xticks, xticks / bundle.sample_rate)
    ax.set_xlabel("time (sec)")
    ax.set_xlim(0, waveform.shape[0])
plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
Beam Search Decoder Parameters¶
In this section, we go into more depth on some of the parameters and their tradeoffs. For the full list of customizable parameters, please refer to the documentation.
Helper Function¶
def print_decoded(decoder, emission, param, param_value):
    start_time = time.monotonic()
    result = decoder(emission)
    decode_time = time.monotonic() - start_time

    transcript = " ".join(result[0][0].words).lower().strip()
    score = result[0][0].score
    print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")
nbest¶
This parameter indicates the number of best hypotheses to return, which is not possible with the greedy decoder. For instance, by setting nbest=3 when constructing the beam search decoder earlier, we can now access the hypotheses with the top 3 scores.
for i in range(3):
    transcript = " ".join(beam_search_result[0][i].words).strip()
    score = beam_search_result[0][i].score
    print(f"{transcript} (score: {score})")
i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.8242175269093)
i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.8584784734217)
i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.0158622860877)
beam size¶
The beam_size parameter determines the maximum number of best hypotheses to hold after each decoding step. Using larger beam sizes allows for exploring a larger range of possible hypotheses, which can produce hypotheses with higher scores. However, it is computationally more expensive and does not provide additional gains beyond a certain point.
In the example below, we see improvement in decoding quality as we increase the beam size from 1 to 5 to 50. Notice, however, that a beam size of 500 provides the same output as beam size 50 while increasing the computation time.
beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size=beam_size,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam size", beam_size)
beam size 1 : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.2507 secs)
beam size 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0656 secs)
beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2558 secs)
beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.7189 secs)
beam size token¶
The beam_size_token parameter corresponds to the number of tokens to consider for expanding each hypothesis at the decoding step. Exploring a larger number of next possible tokens increases the range of potential hypotheses at the cost of computation.
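Conceptually, beam_size_token limits how many candidate tokens are tried when extending a hypothesis at each frame, keeping only the highest-scoring ones. The sketch below shows that selection step on a single frame of invented log-probabilities. `top_k_tokens` is a hypothetical helper, and Flashlight's internal handling may differ in detail. Setting k to the full vocabulary size, as with num_tokens in the loop that follows, means every token is considered.

```python
import heapq
from typing import List, Sequence


def top_k_tokens(frame_log_probs: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest-scoring tokens in one emission frame, best first."""
    k = min(k, len(frame_log_probs))
    return heapq.nlargest(k, range(len(frame_log_probs)), key=lambda i: frame_log_probs[i])


frame = [-0.1, -3.0, -2.5, -0.9, -4.2]  # invented log-probabilities for 5 tokens
print(top_k_tokens(frame, 2))  # [0, 3]
```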
num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size_token=beam_size_token,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)
beam size token 1 : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.2628 secs)
beam size token 5 : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.2105 secs)
beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2769 secs)
beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3071 secs)
beam threshold¶
The beam_threshold parameter is used to prune the stored hypotheses set at each decoding step, removing hypotheses whose scores are more than beam_threshold below the highest-scoring hypothesis. There is a balance between choosing smaller thresholds, which prune more hypotheses and reduce the search space, and choosing a threshold large enough that plausible hypotheses are not pruned.
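The pruning rule can be written in a few lines. The sketch below uses a hypothetical helper with invented log-domain scores, where higher is better. It first drops every hypothesis more than `beam_threshold` below the best score and then keeps at most `beam_size` of the survivors. This roughly mirrors how the two parameters interact, but it is an illustration, not the decoder's actual code.

```python
from typing import List, Tuple


def prune(hyps: List[Tuple[str, float]], beam_size: int, beam_threshold: float) -> List[Tuple[str, float]]:
    """Keep hypotheses within beam_threshold of the best score, capped at beam_size."""
    if not hyps:
        return []
    best = max(score for _, score in hyps)
    survivors = [h for h in hyps if h[1] >= best - beam_threshold]
    survivors.sort(key=lambda h: h[1], reverse=True)
    return survivors[:beam_size]


hyps = [("really", -10.0), ("rely", -12.0), ("reply", -14.5), ("ila", -30.0)]
print(prune(hyps, beam_size=3, beam_threshold=5))  # [('really', -10.0), ('rely', -12.0), ('reply', -14.5)]
print(prune(hyps, beam_size=3, beam_threshold=1))  # [('really', -10.0)]
```

With a very small threshold, only the current leader survives. If the leader turns out to be wrong later in the utterance, the decoder cannot recover, which is consistent with the poor transcript at beam threshold 1 below.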
beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_threshold=beam_threshold,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)
beam threshold 1 : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0704 secs)
beam threshold 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0744 secs)
beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2809 secs)
beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3192 secs)
language model weight¶
The lm_weight parameter is the weight assigned to the language model score, which is accumulated with the acoustic model score to determine the overall scores. Larger weights encourage the model to predict next words based on the language model, while smaller weights give more weight to the acoustic model score instead.
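One simplified way to picture how lm_weight and word_score enter the ranking is this: each hypothesis total is its acoustic score, plus lm_weight times its LM score, plus word_score per emitted word. The scores below are invented purely to show how the ranking can flip as lm_weight grows. They are not taken from the decoder, and Flashlight's actual scoring includes further terms (for example sil_score and unk_score).

```python
from typing import Dict, List, Tuple

# Invented per-hypothesis scores (log domain): (acoustic score, LM score, number of words).
HYPOTHESES: Dict[str, Tuple[float, float, int]] = {
    "i rely was": (-4.0, -14.0, 3),
    "i really was": (-5.0, -11.0, 3),
}


def rank(hyps: Dict[str, Tuple[float, float, int]], lm_weight: float, word_score: float) -> List[str]:
    """Sort hypotheses by acoustic + lm_weight * lm + word_score * num_words, best first."""

    def total(item: Tuple[str, Tuple[float, float, int]]) -> float:
        am, lm, num_words = item[1]
        return am + lm_weight * lm + word_score * num_words

    return [text for text, _ in sorted(hyps.items(), key=total, reverse=True)]


print(rank(HYPOTHESES, lm_weight=0.0, word_score=-0.26))   # ['i rely was', 'i really was']
print(rank(HYPOTHESES, lm_weight=3.23, word_score=-0.26))  # ['i really was', 'i rely was']
```

With lm_weight=0, the acoustically stronger but misspelled-sounding "rely" wins. With a moderate weight, the LM's preference for "really" takes over. An excessive weight lets the LM dominate entirely, as the lm weight 15 output below shows.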
lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        lm_weight=lm_weight,
        word_score=WORD_SCORE,
    )

    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)
lm weight 0 : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.3497 secs)
lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4497 secs)
lm weight 15 : was there in his was at some of what he said (score: 2918.99; 0.3410 secs)
additional parameters¶
Additional parameters that can be optimized include the following:
word_score: score to add when a word finishes
unk_score: unknown word appearance score to add
sil_score: silence appearance score to add
log_add: whether to use log add for lexicon Trie smearing
Total running time of the script: ( 3 minutes 0.742 seconds)