
Speech Command Classification with torchaudio

This tutorial will show you how to correctly format an audio dataset and then train/test an audio classifier network on the dataset.

Colab has a GPU option available. In the menu tabs, select “Runtime”, then “Change runtime type”. In the pop-up that follows, you can choose GPU. After the change, your runtime should automatically restart (which means information from executed cells disappears).

First, let’s import the common torch packages, such as torchaudio, which can be installed by following the instructions on the website.

# Uncomment the line corresponding to your "runtime type" to run in Google Colab

# CPU:
# !pip install pydub torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# GPU:
# !pip install pydub torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

Let’s check if a CUDA GPU is available and select our device. Running the network on a GPU will greatly decrease the training/testing runtime.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda

Importing the Dataset

We use torchaudio to download and represent the dataset. Here we use SpeechCommands, which is a dataset of 35 commands spoken by different people. SPEECHCOMMANDS is the torch.utils.data.Dataset version of the dataset. In this dataset, all audio files are about 1 second long (and so about 16000 samples long at 16 kHz).

The actual loading and formatting steps happen when a data point is being accessed, and torchaudio takes care of converting the audio files to tensors. If one wants to load an audio file directly instead, torchaudio.load() can be used. It returns a tuple containing the newly created tensor along with the sampling frequency of the audio file (16kHz for SpeechCommands).

Going back to the dataset, here we create a subclass that splits it into standard training, validation, and testing subsets.
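Conceptually, the split relies on the validation_list.txt and testing_list.txt files that ship with the dataset: the validation and testing subsets are read from those files, and the training subset is every remaining file. The sketch below shows that logic in plain Python. The helper name split_speech_commands and its root argument are hypothetical, not part of torchaudio.

```python
import os


def split_speech_commands(all_paths, validation_list, testing_list, root="."):
    """Hypothetical helper mirroring SubsetSC's logic; returns (train, validation, test).

    all_paths: every audio file path in the dataset.
    validation_list, testing_list: relative paths, one per line, as read from
    validation_list.txt and testing_list.txt (trailing newlines allowed).
    root: assumed dataset directory that the listed paths are relative to.
    """

    def normalize(rel_path):
        # Same normalization SubsetSC applies: join with the root, then normpath
        return os.path.normpath(os.path.join(root, rel_path.strip()))

    validation = [normalize(p) for p in validation_list]
    testing = [normalize(p) for p in testing_list]

    # A set gives O(1) membership tests, which matters for tens of thousands of files
    excluded = set(validation) | set(testing)
    training = [p for p in all_paths if os.path.normpath(p) not in excluded]
    return training, validation, testing


# Toy usage with three made-up files
all_paths = [os.path.join(".", word, "clip.wav") for word in ("yes", "no", "up")]
train, val, test = split_speech_commands(all_paths, ["no/clip.wav\n"], ["up/clip.wav\n"])
print(train)  # only the "yes" clip remains in the training split
```

Building the exclusion set once, as SubsetSC also does, avoids rescanning both lists for every file.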

from torchaudio.datasets import SPEECHCOMMANDS
import os


class SubsetSC(SPEECHCOMMANDS):
    def __init__(self, subset: str = None):
        super().__init__("./", download=True)

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [os.path.normpath(os.path.join(self._path, line.strip())) for line in fileobj]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]


# Create training and testing split of the data. We do not use validation in this tutorial.
train_set = SubsetSC("training")
test_set = SubsetSC("testing")

waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]

A data point in the SPEECHCOMMANDS dataset is a tuple made of a waveform (the audio signal), the sample rate, the utterance (label), the ID of the speaker, and the number of the utterance.

print("Shape of waveform: {}".format(waveform.size()))
print("Sample rate of waveform: {}".format(sample_rate))

plt.plot(waveform.t().numpy());
Shape of waveform: torch.Size([1, 16000])
Sample rate of waveform: 16000


Let’s find the list of labels available in the dataset.

labels = sorted(list(set(datapoint[2] for datapoint in train_set)))
labels
['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']

The 35 audio labels are commands spoken by users. The first few files are people saying “marvin”.

waveform_first, *_ = train_set[0]
ipd.Audio(waveform_first.numpy(), rate=sample_rate)

waveform_second, *_ = train_set[1]
ipd.Audio(waveform_second.numpy(), rate=sample_rate)


The last file is someone saying “visual”.

waveform_last, *_ = train_set[-1]
ipd.Audio(waveform_last.numpy(), rate=sample_rate)


Formatting the Data

This is a good place to apply transformations to the data. For the waveform, we downsample the audio for faster processing without losing too much of the classification power.

We don’t need to apply other transformations here. For some datasets, though, it is common to reduce the number of channels (say, from stereo to mono), either by taking the mean along the channel dimension or by simply keeping only one of the channels. Since SpeechCommands uses a single audio channel, this is not needed here.

new_sample_rate = 8000
transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)
transformed = transform(waveform)

ipd.Audio(transformed.numpy(), rate=new_sample_rate)
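The channel reduction described above is not needed for SpeechCommands, but the two options can be sketched in plain Python on a toy signal stored as one list per channel. With a (channel, time) tensor, the same ideas are waveform.mean(dim=0, keepdim=True) and waveform[:1]. The sketch also shows why resampling from 16 kHz to 8 kHz halves the number of samples: the duration stays the same. The rounding in resampled_length is an assumption; exact rounding is implementation-specific.

```python
import math


def downmix_to_mono(channels):
    """Average equal-length channels (a list of lists) into a single channel."""
    n_channels = len(channels)
    # zip(*channels) yields one tuple per time step with a value from each channel
    return [sum(frame) / n_channels for frame in zip(*channels)]


def keep_first_channel(channels):
    """The cheaper alternative: keep only the first channel."""
    return list(channels[0])


def resampled_length(num_samples, orig_freq, new_freq):
    """Sample count after resampling: duration is preserved, so the count scales by
    new_freq / orig_freq (rounded up here; exact rounding depends on the implementation)."""
    return math.ceil(num_samples * new_freq / orig_freq)


stereo = [[1.0, 3.0, 5.0], [3.0, 5.0, 7.0]]  # toy 2-channel, 3-sample signal
print(downmix_to_mono(stereo))  # [2.0, 4.0, 6.0]
print(keep_first_channel(stereo))  # [1.0, 3.0, 5.0]
print(resampled_length(16000, 16000, 8000))  # 8000
```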


We encode each word using its index in the list of labels.

def label_to_index(word):
    # Return the position of the word in labels
    return torch.tensor(labels.index(word))


def index_to_label(index):
    # Return the word corresponding to the index in labels
    # This is the inverse of label_to_index
    return labels[index]


word_start = "yes"
index = label_to_index(word_start)
word_recovered = index_to_label(index)

print(word_start, "-->", index, "-->", word_recovered)
yes --> tensor(33) --> yes
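label_to_index calls labels.index, which scans the list on every call, and the collate function calls it once per sample. With only 35 labels the cost is negligible, but a precomputed dictionary is the usual idiom. A minimal sketch, assuming a toy four-word label list and returning plain ints (the tutorial wraps the index in torch.tensor):

```python
# Toy label list; the tutorial's `labels` holds 35 sorted words
labels = sorted(["yes", "no", "up", "down"])
label_to_idx = {label: i for i, label in enumerate(labels)}


def encode(word):
    # Dict lookup is O(1); labels.index(word) scans the list on every call
    return label_to_idx[word]


def decode(index):
    # Inverse mapping: position in the sorted list
    return labels[index]


print(labels)  # ['down', 'no', 'up', 'yes']
print(encode("yes"), decode(encode("yes")))  # 3 yes
```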

To turn a list of data points made of audio recordings and utterances into two batched tensors for the model, we implement a collate function, which the PyTorch DataLoader uses to iterate over a dataset in batches. Please see the documentation for more information about working with a collate function.

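In the collate function, we also encode the labels as indices and pad the waveforms to a common length. The resampling is applied later, on the device, inside the training and testing functions.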

def pad_sequence(batch):
    # Make all tensors in a batch the same length by padding with zeros
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    return batch.permute(0, 2, 1)


def collate_fn(batch):

    # A data tuple has the form:
    # waveform, sample_rate, label, speaker_id, utterance_number

    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for waveform, _, label, *_ in batch:
        tensors += [waveform]
        targets += [label_to_index(label)]

    # Group the list of tensors into a batched tensor
    tensors = pad_sequence(tensors)
    targets = torch.stack(targets)

    return tensors, targets


batch_size = 256

if device.type == "cuda":  # compare the device type string, not the torch.device object
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
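Since clips are only about 1 second long, their lengths can differ, which is why the batch must be padded. torch.nn.utils.rnn.pad_sequence pads along the first dimension of each tensor, at the end, hence the transpose from (1, time) to (time, 1) and the final permute back to (batch, 1, time). The sketch below reproduces the padding and label encoding in plain Python; collate and pad_batch are hypothetical names, and each waveform is simplified to a flat list of samples.

```python
def pad_batch(sequences, padding_value=0.0):
    """Right-pad 1-D sequences with padding_value up to the longest length in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]


def collate(batch, label_to_idx):
    """Hypothetical pure-Python analogue of collate_fn.

    Each item is (waveform, sample_rate, label, speaker_id, utterance_number),
    with the waveform simplified to a flat list of samples (one channel).
    """
    waveforms, targets = [], []
    for waveform, _, label, *_ in batch:
        waveforms.append(waveform)
        targets.append(label_to_idx[label])
    return pad_batch(waveforms), targets


batch = [
    ([0.1, 0.2, 0.3], 16000, "yes", "speaker_a", 0),
    ([0.5], 16000, "no", "speaker_b", 0),
]
padded, targets = collate(batch, {"no": 0, "yes": 1})
print(padded)  # [[0.1, 0.2, 0.3], [0.5, 0.0, 0.0]]
print(targets)  # [1, 0]
```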

Define the Network

For this tutorial we will use a convolutional neural network to process the raw audio data. Usually more advanced transforms are applied to the audio data; however, CNNs can be used to accurately process the raw data. The specific architecture is modeled after the M5 network architecture described in this paper. An important aspect of models processing raw audio data is the receptive field of their first layer’s filters. Our model’s first filter has length 80, so when processing audio sampled at 8 kHz the receptive field is around 10 ms (and at 4 kHz, around 20 ms). This size is similar to speech processing applications, which often use receptive fields ranging from 20 ms to 40 ms.
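The receptive field of one first-layer filter is simply kernel_size / sample_rate seconds. The short calculation below, with a hypothetical helper name, reproduces the figures above and shows that without downsampling (16 kHz) the same filter would cover only 5 ms.

```python
def receptive_field_ms(kernel_size, sample_rate):
    """Milliseconds of audio covered by one filter spanning `kernel_size` samples."""
    return 1000.0 * kernel_size / sample_rate


for sr in (16000, 8000, 4000):
    print(f"{sr} Hz -> {receptive_field_ms(80, sr):.1f} ms")
# 16000 Hz -> 5.0 ms
# 8000 Hz -> 10.0 ms
# 4000 Hz -> 20.0 ms
```

This only covers the first layer; later layers, applied after strided convolution and pooling, see progressively longer stretches of audio.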

class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)


model = M5(n_input=transformed.shape[0], n_output=len(labels))
model.to(device)
print(model)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


n = count_parameters(model)
print("Number of parameters: %s" % n)
M5(
  (conv1): Conv1d(1, 32, kernel_size=(80,), stride=(16,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=35, bias=True)
)
Number of parameters: 26915
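To see how the layers above shrink the input, the sketch below traces the time-axis length through M5 for a 1-second clip at 8 kHz (8000 samples), using the standard output-length formula for Conv1d and MaxPool1d with no padding. It also recomputes the parameter count by hand. Both helpers are hypothetical and assume M5's default arguments.

```python
def conv1d_out_len(length, kernel_size, stride=1):
    # Output length of nn.Conv1d with padding=0 and dilation=1 (the M5 settings)
    return (length - kernel_size) // stride + 1


def maxpool1d_out_len(length, kernel_size):
    # nn.MaxPool1d(k) uses stride=k by default; padding=0, ceil_mode=False
    return (length - kernel_size) // kernel_size + 1


def m5_time_lengths(num_samples, stride=16):
    """Time-axis length after each conv/pool stage of M5 with its default layer sizes."""
    trace = []
    length = conv1d_out_len(num_samples, kernel_size=80, stride=stride)
    trace.append(("conv1", length))
    length = maxpool1d_out_len(length, 4)
    trace.append(("pool1", length))
    for i in (2, 3, 4):
        length = conv1d_out_len(length, kernel_size=3)
        trace.append((f"conv{i}", length))
        length = maxpool1d_out_len(length, 4)
        trace.append((f"pool{i}", length))
    return trace


def m5_param_count(n_input=1, n_output=35, n_channel=32):
    """Trainable parameters of M5: conv weights and biases, BatchNorm affine
    weight and bias (running stats are buffers, not parameters), and the Linear layer."""

    def conv(c_in, c_out, k):
        return c_in * c_out * k + c_out

    def bn(c):
        return 2 * c

    c, c2 = n_channel, 2 * n_channel
    return (
        conv(n_input, c, 80) + bn(c)
        + conv(c, c, 3) + bn(c)
        + conv(c, c2, 3) + bn(c2)
        + conv(c2, c2, 3) + bn(c2)
        + c2 * n_output + n_output
    )


for name, length in m5_time_lengths(8000):
    print(name, length)
print("parameters:", m5_param_count())  # 26915
```

For 8 kHz input the time axis reaches length 1 after pool4, so F.avg_pool1d(x, x.shape[-1]) averages over a single step and the model output has shape (batch, 1, 35).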

We will use the same optimization technique used in the paper, an Adam optimizer with weight decay set to 0.0001. At first, we will train with a learning rate of 0.01, but we will use a scheduler to decrease it to 0.001 during training after 20 epochs.

optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)  # reduce the learning rate after 20 epochs by a factor of 10
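StepLR multiplies the learning rate by gamma once every step_size calls to scheduler.step(). Assuming step() is called once per epoch, as the training loop later in this tutorial does, the rate in effect during a given epoch can be sketched with this hypothetical helper:

```python
def step_lr(base_lr, epoch, step_size=20, gamma=0.1):
    """Learning rate in effect during 1-based `epoch`, assuming scheduler.step()
    is called once at the end of every epoch."""
    return base_lr * gamma ** ((epoch - 1) // step_size)


for epoch in (1, 2, 20, 21, 40, 41):
    print(epoch, step_lr(0.01, epoch))
```

Epochs 1 to 20 run at 0.01 and epochs 21 to 40 at 0.001 (up to floating-point rounding), so a run shorter than 21 epochs never sees the reduced rate.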

Training and Testing the Network

Now let’s define a training function that feeds our training data into the model and performs the backward pass and optimization steps. For training, we use the negative log-likelihood loss. The network is then tested after each epoch to see how the accuracy varies during training.

def train(model, epoch, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        data = transform(data)
        output = model(data)

        # negative log-likelihood for a tensor of size (batch x 1 x n_output)
        loss = F.nll_loss(output.squeeze(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        losses.append(loss.item())
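Because M5's forward pass ends with F.log_softmax, applying F.nll_loss to its output (after squeeze() turns (batch, 1, n_output) into (batch, n_output)) amounts to cross-entropy on the pre-softmax scores. The plain-Python sketch below computes the same quantity for a toy batch, assuming the default mean reduction and no class weights; the helper names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of scores (subtract the max first)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood: average of -log p(target) over the batch."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


# A model that spreads probability uniformly over 35 classes has loss log(35)
uniform = [log_softmax([0.0] * 35)]
print(nll_loss(uniform, [7]), math.log(35))

# A confident, correct prediction gives a loss close to 0
confident = [log_softmax([10.0, 0.0, 0.0])]
print(nll_loss(confident, [0]))
```

The uniform case gives a useful reference point: with 35 classes, a model that has learned nothing scores about log(35) ≈ 3.56.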

Now that we have a training function, we need one for testing the network’s accuracy. We set the model to eval() mode and then run inference on the test dataset. Calling eval() sets the training attribute of every module in the network to False. Certain layers, like batch normalization and dropout, behave differently during training, so this step is crucial for getting correct results.

def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    return tensor.argmax(dim=-1)


def test(model, epoch):
    model.eval()
    correct = 0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        data = transform(data)
        output = model(data)

        pred = get_likely_index(output)
        correct += number_of_correct(pred, target)

        # update progress bar
        pbar.update(pbar_update)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")
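The accuracy computation reduces to taking the highest-scoring class per example and counting matches with the target. Since log_softmax is monotonic, the argmax is the same whether it is taken over log-probabilities or raw scores. A plain-Python sketch with hypothetical helper names and toy scores:

```python
def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def count_correct(scores, targets):
    """Number of rows whose highest-scoring class equals the target index."""
    return sum(argmax(row) == t for row, t in zip(scores, targets))


scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # toy (batch, n_classes) scores
targets = [1, 0, 0]
correct = count_correct(scores, targets)
print(f"Accuracy: {correct}/{len(targets)} ({100.0 * correct / len(targets):.0f}%)")
# Accuracy: 2/3 (67%)
```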

Finally, we can train and test the network. The code below sets n_epoch = 2 to keep the runtime short. With the scheduler defined above, the learning rate would drop from 0.01 to 0.001 only after 20 epochs, so n_epoch must be increased beyond 20 to see that reduction. The network is tested after each epoch to see how the accuracy varies during training.

log_interval = 20
n_epoch = 2

pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []

# The transform needs to live on the same device as the model and the data.
transform = transform.to(device)
with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        train(model, epoch, log_interval)
        test(model, epoch)
        scheduler.step()

# Let's plot the training loss versus the number of iterations.
# plt.plot(losses);
# plt.title("training loss");
Train Epoch: 1 [0/84843 (0%)]    Loss: 3.813234
Train Epoch: 1 [5120/84843 (6%)]    Loss: 3.041089
Train Epoch: 1 [10240/84843 (12%)]    Loss: 2.553542
Train Epoch: 1 [15360/84843 (18%)]    Loss: 2.254806
Train Epoch: 1 [20480/84843 (24%)]    Loss: 2.029420
Train Epoch: 1 [25600/84843 (30%)]    Loss: 1.780852
Train Epoch: 1 [30720/84843 (36%)]    Loss: 1.574860
Train Epoch: 1 [35840/84843 (42%)]    Loss: 1.676274
Train Epoch: 1 [40960/84843 (48%)]    Loss: 1.567081
Train Epoch: 1 [46080/84843 (54%)]    Loss: 1.373785
Train Epoch: 1 [51200/84843 (60%)]    Loss: 1.345279
Train Epoch: 1 [56320/84843 (66%)]    Loss: 1.342250
Train Epoch: 1 [61440/84843 (72%)]    Loss: 1.286719
Train Epoch: 1 [66560/84843 (78%)]    Loss: 1.136720
Train Epoch: 1 [71680/84843 (84%)]    Loss: 1.120914
Train Epoch: 1 [76800/84843 (90%)]    Loss: 1.195492
Train Epoch: 1 [81920/84843 (96%)]    Loss: 1.064560
 49%|####8     | 0.9706666666666726/2 [00:33<00:30, 30.06s/it]
 49%|####8     | 0.976000000000006/2 [00:33<00:31, 31.04s/it]
 49%|####9     | 0.9813333333333394/2 [00:33<00:31, 30.67s/it]
 49%|####9     | 0.9866666666666728/2 [00:34<00:30, 30.50s/it]
 50%|####9     | 0.9920000000000062/2 [00:34<00:30, 30.40s/it]
 50%|####9     | 0.9973333333333396/2 [00:34<00:30, 30.27s/it]
Test Epoch: 1   Accuracy: 6052/11005 (55%)

Train Epoch: 2 [0/84843 (0%)]   Loss: 1.176855

 50%|#####     | 1.0026666666666728/2 [00:34<00:30, 30.71s/it]
 50%|#####     | 1.008000000000006/2 [00:34<00:30, 30.75s/it]
 51%|#####     | 1.0133333333333392/2 [00:34<00:30, 30.96s/it]
 51%|#####     | 1.0186666666666724/2 [00:35<00:30, 31.20s/it]
 51%|#####1    | 1.0240000000000056/2 [00:35<00:30, 31.49s/it]
 51%|#####1    | 1.0293333333333388/2 [00:35<00:30, 31.82s/it]
 52%|#####1    | 1.034666666666672/2 [00:35<00:30, 31.77s/it]
 52%|#####2    | 1.0400000000000051/2 [00:35<00:30, 31.86s/it]
 52%|#####2    | 1.0453333333333383/2 [00:35<00:30, 31.54s/it]
 53%|#####2    | 1.0506666666666715/2 [00:36<00:29, 31.41s/it]Train Epoch: 2 [5120/84843 (6%)]  Loss: 1.009675

 53%|#####2    | 1.0560000000000047/2 [00:36<00:30, 31.94s/it]
 53%|#####3    | 1.061333333333338/2 [00:36<00:29, 31.71s/it]
 53%|#####3    | 1.066666666666671/2 [00:36<00:29, 31.66s/it]
 54%|#####3    | 1.0720000000000043/2 [00:36<00:29, 31.47s/it]
 54%|#####3    | 1.0773333333333375/2 [00:36<00:28, 31.37s/it]
 54%|#####4    | 1.0826666666666707/2 [00:37<00:28, 31.34s/it]
 54%|#####4    | 1.0880000000000039/2 [00:37<00:28, 31.36s/it]
 55%|#####4    | 1.093333333333337/2 [00:37<00:28, 31.58s/it]
 55%|#####4    | 1.0986666666666702/2 [00:37<00:28, 31.46s/it]
 55%|#####5    | 1.1040000000000034/2 [00:37<00:28, 31.33s/it]Train Epoch: 2 [10240/84843 (12%)]        Loss: 0.983963

 55%|#####5    | 1.1093333333333366/2 [00:37<00:28, 31.52s/it]
 56%|#####5    | 1.1146666666666698/2 [00:38<00:27, 31.40s/it]
 56%|#####6    | 1.120000000000003/2 [00:38<00:27, 31.33s/it]
 56%|#####6    | 1.1253333333333362/2 [00:38<00:27, 31.24s/it]
 57%|#####6    | 1.1306666666666694/2 [00:38<00:27, 31.35s/it]
 57%|#####6    | 1.1360000000000026/2 [00:38<00:27, 31.60s/it]
 57%|#####7    | 1.1413333333333358/2 [00:38<00:27, 31.82s/it]
 57%|#####7    | 1.146666666666669/2 [00:39<00:27, 31.71s/it]
 58%|#####7    | 1.1520000000000021/2 [00:39<00:26, 31.78s/it]
 58%|#####7    | 1.1573333333333353/2 [00:39<00:26, 31.55s/it]Train Epoch: 2 [15360/84843 (18%)]        Loss: 0.854694

 58%|#####8    | 1.1626666666666685/2 [00:39<00:26, 32.00s/it]
 58%|#####8    | 1.1680000000000017/2 [00:39<00:26, 32.06s/it]
 59%|#####8    | 1.173333333333335/2 [00:39<00:26, 32.09s/it]
 59%|#####8    | 1.178666666666668/2 [00:40<00:26, 32.00s/it]
 59%|#####9    | 1.1840000000000013/2 [00:40<00:26, 32.00s/it]
 59%|#####9    | 1.1893333333333345/2 [00:40<00:25, 31.85s/it]
 60%|#####9    | 1.1946666666666677/2 [00:40<00:25, 31.60s/it]
 60%|######    | 1.2000000000000008/2 [00:40<00:25, 31.99s/it]
 60%|######    | 1.205333333333334/2 [00:41<00:25, 31.91s/it]
 61%|######    | 1.2106666666666672/2 [00:41<00:25, 31.86s/it]Train Epoch: 2 [20480/84843 (24%)]        Loss: 0.923911

 61%|######    | 1.2160000000000004/2 [00:41<00:25, 31.89s/it]
 61%|######1   | 1.2213333333333336/2 [00:41<00:24, 31.84s/it]
 61%|######1   | 1.2266666666666668/2 [00:41<00:24, 31.66s/it]
 62%|######1   | 1.232/2 [00:41<00:24, 31.95s/it]
 62%|######1   | 1.2373333333333332/2 [00:42<00:24, 31.97s/it]
 62%|######2   | 1.2426666666666664/2 [00:42<00:24, 31.91s/it]
 62%|######2   | 1.2479999999999996/2 [00:42<00:23, 31.84s/it]
 63%|######2   | 1.2533333333333327/2 [00:42<00:23, 31.87s/it]
 63%|######2   | 1.258666666666666/2 [00:42<00:23, 31.79s/it]
 63%|######3   | 1.2639999999999991/2 [00:42<00:23, 31.64s/it]Train Epoch: 2 [25600/84843 (30%)]        Loss: 1.085688

 63%|######3   | 1.2693333333333323/2 [00:43<00:23, 32.04s/it]
 64%|######3   | 1.2746666666666655/2 [00:43<00:22, 31.66s/it]
 64%|######3   | 1.2799999999999987/2 [00:43<00:22, 31.72s/it]
 64%|######4   | 1.2853333333333319/2 [00:43<00:22, 31.45s/it]
 65%|######4   | 1.290666666666665/2 [00:43<00:22, 31.38s/it]
 65%|######4   | 1.2959999999999983/2 [00:43<00:22, 31.40s/it]
 65%|######5   | 1.3013333333333315/2 [00:44<00:21, 31.40s/it]
 65%|######5   | 1.3066666666666646/2 [00:44<00:21, 31.34s/it]
 66%|######5   | 1.3119999999999978/2 [00:44<00:21, 31.12s/it]
 66%|######5   | 1.317333333333331/2 [00:44<00:21, 31.31s/it] Train Epoch: 2 [30720/84843 (36%)]        Loss: 0.894427

 66%|######6   | 1.3226666666666642/2 [00:44<00:21, 31.47s/it]
 66%|######6   | 1.3279999999999974/2 [00:44<00:20, 30.94s/it]
 67%|######6   | 1.3333333333333306/2 [00:45<00:20, 31.15s/it]
 67%|######6   | 1.3386666666666638/2 [00:45<00:20, 31.16s/it]
 67%|######7   | 1.343999999999997/2 [00:45<00:20, 31.32s/it]
 67%|######7   | 1.3493333333333302/2 [00:45<00:20, 31.36s/it]
 68%|######7   | 1.3546666666666634/2 [00:45<00:20, 31.26s/it]
 68%|######7   | 1.3599999999999965/2 [00:45<00:19, 31.17s/it]
 68%|######8   | 1.3653333333333297/2 [00:46<00:19, 31.11s/it]
 69%|######8   | 1.370666666666663/2 [00:46<00:19, 31.05s/it] Train Epoch: 2 [35840/84843 (42%)]        Loss: 0.971906

 69%|######8   | 1.3759999999999961/2 [00:46<00:19, 31.32s/it]
 69%|######9   | 1.3813333333333293/2 [00:46<00:19, 31.34s/it]
 69%|######9   | 1.3866666666666625/2 [00:46<00:19, 31.50s/it]
 70%|######9   | 1.3919999999999957/2 [00:46<00:19, 31.31s/it]
 70%|######9   | 1.3973333333333289/2 [00:47<00:18, 31.31s/it]
 70%|#######   | 1.402666666666662/2 [00:47<00:18, 31.34s/it]
 70%|#######   | 1.4079999999999953/2 [00:47<00:18, 31.55s/it]
 71%|#######   | 1.4133333333333284/2 [00:47<00:18, 31.50s/it]
 71%|#######   | 1.4186666666666616/2 [00:47<00:18, 31.43s/it]
 71%|#######1  | 1.4239999999999948/2 [00:47<00:18, 31.36s/it]Train Epoch: 2 [40960/84843 (48%)]        Loss: 1.002216

 71%|#######1  | 1.429333333333328/2 [00:48<00:17, 31.47s/it]
 72%|#######1  | 1.4346666666666612/2 [00:48<00:17, 31.02s/it]
 72%|#######1  | 1.4399999999999944/2 [00:48<00:17, 30.94s/it]
 72%|#######2  | 1.4453333333333276/2 [00:48<00:17, 31.24s/it]
 73%|#######2  | 1.4506666666666608/2 [00:48<00:17, 31.25s/it]
 73%|#######2  | 1.455999999999994/2 [00:48<00:17, 31.42s/it]
 73%|#######3  | 1.4613333333333272/2 [00:49<00:16, 31.33s/it]
 73%|#######3  | 1.4666666666666603/2 [00:49<00:16, 31.19s/it]
 74%|#######3  | 1.4719999999999935/2 [00:49<00:16, 31.24s/it]
 74%|#######3  | 1.4773333333333267/2 [00:49<00:16, 31.23s/it]Train Epoch: 2 [46080/84843 (54%)]        Loss: 0.820399

 74%|#######4  | 1.48266666666666/2 [00:49<00:16, 31.30s/it]
 74%|#######4  | 1.487999999999993/2 [00:49<00:16, 31.30s/it]
 75%|#######4  | 1.4933333333333263/2 [00:50<00:15, 31.31s/it]
 75%|#######4  | 1.4986666666666595/2 [00:50<00:15, 31.14s/it]
 75%|#######5  | 1.5039999999999927/2 [00:50<00:15, 31.20s/it]
 75%|#######5  | 1.5093333333333259/2 [00:50<00:15, 31.04s/it]
 76%|#######5  | 1.514666666666659/2 [00:50<00:15, 31.13s/it]
 76%|#######5  | 1.5199999999999922/2 [00:50<00:14, 31.16s/it]
 76%|#######6  | 1.5253333333333254/2 [00:51<00:14, 31.59s/it]
 77%|#######6  | 1.5306666666666586/2 [00:51<00:14, 31.87s/it]Train Epoch: 2 [51200/84843 (60%)]        Loss: 0.811157

 77%|#######6  | 1.5359999999999918/2 [00:51<00:14, 32.27s/it]
 77%|#######7  | 1.541333333333325/2 [00:51<00:14, 31.63s/it]
 77%|#######7  | 1.5466666666666582/2 [00:51<00:14, 31.63s/it]
 78%|#######7  | 1.5519999999999914/2 [00:51<00:14, 31.85s/it]
 78%|#######7  | 1.5573333333333246/2 [00:52<00:13, 31.59s/it]
 78%|#######8  | 1.5626666666666578/2 [00:52<00:13, 31.77s/it]
 78%|#######8  | 1.567999999999991/2 [00:52<00:13, 31.89s/it]
 79%|#######8  | 1.5733333333333241/2 [00:52<00:13, 32.17s/it]
 79%|#######8  | 1.5786666666666573/2 [00:52<00:13, 32.04s/it]
 79%|#######9  | 1.5839999999999905/2 [00:52<00:13, 31.83s/it]Train Epoch: 2 [56320/84843 (66%)]        Loss: 0.875054

 79%|#######9  | 1.5893333333333237/2 [00:53<00:13, 32.36s/it]
 80%|#######9  | 1.594666666666657/2 [00:53<00:12, 32.01s/it]
 80%|#######9  | 1.59999999999999/2 [00:53<00:12, 31.93s/it]
 80%|########  | 1.6053333333333233/2 [00:53<00:12, 32.12s/it]
 81%|########  | 1.6106666666666565/2 [00:53<00:12, 31.99s/it]
 81%|########  | 1.6159999999999897/2 [00:53<00:12, 31.66s/it]
 81%|########1 | 1.6213333333333229/2 [00:54<00:11, 31.45s/it]
 81%|########1 | 1.626666666666656/2 [00:54<00:11, 31.64s/it]
 82%|########1 | 1.6319999999999892/2 [00:54<00:11, 31.58s/it]
 82%|########1 | 1.6373333333333224/2 [00:54<00:11, 31.76s/it]Train Epoch: 2 [61440/84843 (72%)]        Loss: 0.923683

 82%|########2 | 1.6426666666666556/2 [00:54<00:11, 31.94s/it]
 82%|########2 | 1.6479999999999888/2 [00:54<00:11, 31.99s/it]
 83%|########2 | 1.653333333333322/2 [00:55<00:11, 31.86s/it]
 83%|########2 | 1.6586666666666552/2 [00:55<00:10, 31.85s/it]
 83%|########3 | 1.6639999999999884/2 [00:55<00:10, 31.66s/it]
 83%|########3 | 1.6693333333333216/2 [00:55<00:10, 31.57s/it]
 84%|########3 | 1.6746666666666548/2 [00:55<00:10, 31.46s/it]
 84%|########3 | 1.679999999999988/2 [00:55<00:10, 31.60s/it]
 84%|########4 | 1.6853333333333211/2 [00:56<00:09, 31.72s/it]
 85%|########4 | 1.6906666666666543/2 [00:56<00:10, 33.01s/it]Train Epoch: 2 [66560/84843 (78%)]        Loss: 0.889358

 85%|########4 | 1.6959999999999875/2 [00:56<00:10, 33.06s/it]
 85%|########5 | 1.7013333333333207/2 [00:56<00:09, 32.27s/it]
 85%|########5 | 1.706666666666654/2 [00:56<00:09, 32.09s/it]
 86%|########5 | 1.711999999999987/2 [00:57<00:09, 31.90s/it]
 86%|########5 | 1.7173333333333203/2 [00:57<00:08, 31.69s/it]
 86%|########6 | 1.7226666666666535/2 [00:57<00:08, 31.81s/it]
 86%|########6 | 1.7279999999999867/2 [00:57<00:08, 31.87s/it]
 87%|########6 | 1.7333333333333198/2 [00:57<00:08, 31.85s/it]
 87%|########6 | 1.738666666666653/2 [00:57<00:08, 31.93s/it]
 87%|########7 | 1.7439999999999862/2 [00:58<00:08, 31.82s/it]Train Epoch: 2 [71680/84843 (84%)]        Loss: 0.826419

 87%|########7 | 1.7493333333333194/2 [00:58<00:08, 32.33s/it]
 88%|########7 | 1.7546666666666526/2 [00:58<00:07, 31.73s/it]
 88%|########7 | 1.7599999999999858/2 [00:58<00:07, 31.62s/it]
 88%|########8 | 1.765333333333319/2 [00:58<00:07, 31.65s/it]
 89%|########8 | 1.7706666666666522/2 [00:58<00:07, 31.88s/it]
 89%|########8 | 1.7759999999999854/2 [00:59<00:07, 31.85s/it]
 89%|########9 | 1.7813333333333186/2 [00:59<00:06, 31.80s/it]
 89%|########9 | 1.7866666666666517/2 [00:59<00:06, 31.78s/it]
 90%|########9 | 1.791999999999985/2 [00:59<00:06, 31.60s/it]
 90%|########9 | 1.7973333333333181/2 [00:59<00:06, 31.79s/it]Train Epoch: 2 [76800/84843 (90%)]        Loss: 1.038377

 90%|######### | 1.8026666666666513/2 [00:59<00:06, 31.90s/it]
 90%|######### | 1.8079999999999845/2 [01:00<00:06, 31.64s/it]
 91%|######### | 1.8133333333333177/2 [01:00<00:05, 31.96s/it]
 91%|######### | 1.8186666666666509/2 [01:00<00:05, 31.57s/it]
 91%|#########1| 1.823999999999984/2 [01:00<00:05, 31.51s/it]
 91%|#########1| 1.8293333333333173/2 [01:00<00:05, 31.72s/it]
 92%|#########1| 1.8346666666666505/2 [01:00<00:05, 31.34s/it]
 92%|#########1| 1.8399999999999836/2 [01:01<00:04, 31.20s/it]
 92%|#########2| 1.8453333333333168/2 [01:01<00:04, 31.33s/it]
 93%|#########2| 1.85066666666665/2 [01:01<00:04, 31.29s/it]  Train Epoch: 2 [81920/84843 (96%)]        Loss: 0.769750

 93%|#########2| 1.8559999999999832/2 [01:01<00:04, 31.47s/it]
 93%|#########3| 1.8613333333333164/2 [01:01<00:04, 31.12s/it]
 93%|#########3| 1.8666666666666496/2 [01:01<00:04, 31.24s/it]
 94%|#########3| 1.8719999999999828/2 [01:02<00:03, 31.21s/it]
 94%|#########3| 1.877333333333316/2 [01:02<00:03, 31.37s/it]
 94%|#########4| 1.8826666666666492/2 [01:02<00:03, 31.47s/it]
 94%|#########4| 1.8879999999999824/2 [01:02<00:03, 29.55s/it]
 95%|#########4| 1.8933333333333155/2 [01:02<00:03, 29.54s/it]
 95%|#########4| 1.8986666666666487/2 [01:02<00:02, 29.52s/it]
 95%|#########5| 1.903999999999982/2 [01:03<00:02, 29.42s/it]
 95%|#########5| 1.9093333333333151/2 [01:03<00:02, 29.38s/it]
 96%|#########5| 1.9146666666666483/2 [01:03<00:02, 29.33s/it]
 96%|#########5| 1.9199999999999815/2 [01:03<00:02, 29.38s/it]
 96%|#########6| 1.9253333333333147/2 [01:03<00:02, 29.27s/it]
 97%|#########6| 1.9306666666666479/2 [01:03<00:02, 29.16s/it]
 97%|#########6| 1.935999999999981/2 [01:03<00:01, 29.07s/it]
 97%|#########7| 1.9413333333333143/2 [01:04<00:01, 29.09s/it]
 97%|#########7| 1.9466666666666474/2 [01:04<00:01, 29.45s/it]
 98%|#########7| 1.9519999999999806/2 [01:04<00:01, 29.40s/it]
 98%|#########7| 1.9573333333333138/2 [01:04<00:01, 29.40s/it]
 98%|#########8| 1.962666666666647/2 [01:04<00:01, 29.27s/it]
 98%|#########8| 1.9679999999999802/2 [01:04<00:00, 29.33s/it]
 99%|#########8| 1.9733333333333134/2 [01:05<00:00, 29.50s/it]
 99%|#########8| 1.9786666666666466/2 [01:05<00:00, 29.39s/it]
 99%|#########9| 1.9839999999999798/2 [01:05<00:00, 29.41s/it]
 99%|#########9| 1.989333333333313/2 [01:05<00:00, 29.40s/it]
100%|#########9| 1.9946666666666462/2 [01:05<00:00, 29.43s/it]
100%|#########9| 1.9999999999999793/2 [01:05<00:00, 29.19s/it]
Test Epoch: 2   Accuracy: 7852/11005 (71%)


100%|#########9| 1.9999999999999793/2 [01:05<00:00, 32.91s/it]

After 2 epochs the network should be more than 65% accurate on the test set, and more than 85% accurate after 21 epochs. Let’s look at the last word in the training set and see how the model classifies it.

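def predict(tensor):
    # Use the model to predict the label of the waveform
    model.eval()  # inference mode: e.g. batch norm uses its running statistics
    with torch.no_grad():  # gradients are not needed for prediction
        tensor = tensor.to(device)
        tensor = transform(tensor)
        tensor = model(tensor.unsqueeze(0))  # add a batch dimension
        tensor = get_likely_index(tensor)
        tensor = index_to_label(tensor.squeeze())
    return tensor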


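waveform, sample_rate, utterance, *_ = train_set[-1]
# ipd.Audio only renders automatically as a cell's last expression, so display it explicitly
ipd.display(ipd.Audio(waveform.numpy(), rate=sample_rate))

print(f"Expected: {utterance}. Predicted: {predict(waveform)}.")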
Expected: zero. Predicted: zero.
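predict() returns only the single most likely label. A minimal sketch of how one might also inspect the runner-up labels and their probabilities, assuming the network ends in a log-softmax over the labels with the label dimension last (so exp() recovers probabilities). The helper top_k_predictions, the label list and the scores are hypothetical stand-ins, not real model output.

```python
import torch
import torch.nn.functional as F


def top_k_predictions(log_probs, labels, k=3):
    """Return the k most likely (label, probability) pairs for one example.

    Assumes `log_probs` holds log-softmax scores for a single example with the
    label dimension last, e.g. shape (1, 1, n_labels).
    """
    probs = log_probs.exp().flatten()  # undo the log and flatten to shape (n_labels,)
    values, indices = probs.topk(k)    # k largest probabilities, in descending order
    return [(labels[i], v.item()) for i, v in zip(indices.tolist(), values)]


# Hypothetical labels and raw scores standing in for real model output.
labels = ["go", "no", "right", "seven", "zero"]
scores = torch.tensor([[[0.1, 0.2, 2.5, 1.9, -1.0]]])
log_probs = F.log_softmax(scores, dim=-1)

for label, p in top_k_predictions(log_probs, labels):
    print(f"{label}: {p:.2f}")
```

With the tutorial's own objects, log_probs would presumably come from model(transform(waveform.to(device)).unsqueeze(0)) inside torch.no_grad(), and labels from the label list built earlier in the tutorial.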

Let’s find an example that isn’t classified correctly, if there is one.

for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set):
    output = predict(waveform)
    if output != utterance:
        # Inside a loop, ipd.Audio must be passed to display() to be rendered
        ipd.display(ipd.Audio(waveform.numpy(), rate=sample_rate))
        print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
        break
else:
    print("All examples in this dataset were correctly classified!")
    print("In this case, let's just look at the last data point")
    ipd.display(ipd.Audio(waveform.numpy(), rate=sample_rate))
    print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
Data point #1. Expected: right. Predicted: seven.
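Stopping at the first mistake shows only one example. To see which commands are confused most often, one could collect the expected and predicted labels for the whole test set and summarize them. A minimal standard-library sketch; the helper names and the sample label lists are hypothetical.

```python
from collections import Counter


def most_common_mistakes(expected, predicted, top=3):
    """Return the `top` most frequent (expected, predicted) pairs that disagree."""
    mistakes = Counter((e, p) for e, p in zip(expected, predicted) if e != p)
    return mistakes.most_common(top)


def per_label_accuracy(expected, predicted):
    """Fraction of examples of each expected label that were predicted correctly."""
    totals = Counter(expected)
    correct = Counter(e for e, p in zip(expected, predicted) if e == p)
    return {label: correct[label] / totals[label] for label in totals}


# Hypothetical labels; in the tutorial they could be gathered by iterating over
# test_set and calling predict() on each waveform (slow on the full test set).
expected = ["right", "right", "seven", "go", "go", "right"]
predicted = ["seven", "right", "seven", "no", "go", "seven"]

print(most_common_mistakes(expected, predicted))
print(per_label_accuracy(expected, predicted))
```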

Feel free to try a recording of your own voice saying one of the labels! For example, in Colab, say “Go” while the cell below is executing. It records one second of audio and tries to classify it.

def record(seconds=1):

    from google.colab import output as colab_output
    from base64 import b64decode
    from io import BytesIO
    from pydub import AudioSegment

    RECORD = (
        b"const sleep  = time => new Promise(resolve => setTimeout(resolve, time))\n"
        b"const b2text = blob => new Promise(resolve => {\n"
        b"  const reader = new FileReader()\n"
        b"  reader.onloadend = e => resolve(e.srcElement.result)\n"
        b"  reader.readAsDataURL(blob)\n"
        b"})\n"
        b"var record = time => new Promise(async resolve => {\n"
        b"  stream = await navigator.mediaDevices.getUserMedia({ audio: true })\n"
        b"  recorder = new MediaRecorder(stream)\n"
        b"  chunks = []\n"
        b"  recorder.ondataavailable = e => chunks.push(e.data)\n"
        b"  recorder.start()\n"
        b"  await sleep(time)\n"
        b"  recorder.onstop = async ()=>{\n"
        b"    blob = new Blob(chunks)\n"
        b"    text = await b2text(blob)\n"
        b"    resolve(text)\n"
        b"  }\n"
        b"  recorder.stop()\n"
        b"})"
    )
    RECORD = RECORD.decode("ascii")

    print(f"Recording started for {seconds} seconds.")
    ipd.display(ipd.Javascript(RECORD))
    s = colab_output.eval_js("record(%d)" % (seconds * 1000))
    print("Recording ended.")
    b = b64decode(s.split(",")[1])

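    fileformat = "wav"
    filename = f"_audio.{fileformat}"
    # The browser's recording is typically not at 16 kHz; convert it to 16 kHz mono,
    # the format of the SpeechCommands clips the model was trained on
    audio = AudioSegment.from_file(BytesIO(b)).set_frame_rate(16000).set_channels(1)
    audio.export(filename, format=fileformat)
    return torchaudio.load(filename)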


# Record and classify only when the notebook runs in Google Colab
if "google.colab" in sys.modules:
    waveform, sample_rate = record()
    print(f"Predicted: {predict(waveform)}.")
    ipd.display(ipd.Audio(waveform.numpy(), rate=sample_rate))

Conclusion

In this tutorial, we used torchaudio to load a dataset and resample the signal. We then defined a neural network and trained it to recognize a given command. Other preprocessing methods, such as computing the mel-frequency cepstral coefficients (MFCC), can also reduce the size of the dataset. This transform is available in torchaudio as torchaudio.transforms.MFCC.

Total running time of the script: ( 2 minutes 29.974 seconds)

Gallery generated by Sphinx-Gallery
