.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/char_rnn_classification_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_char_rnn_classification_tutorial.py:


NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson `_

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with two other
Natural Language Processing (NLP) "from scratch" tutorials
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials do not
use many of the convenience functions of `torchtext`, so you can see how
preprocessing for NLP modeling works at a low level.

A character-level RNN reads words as a series of characters, outputting
a prediction and a "hidden state" at each step and feeding its previous
hidden state into each next step. We take the final prediction to be the
output, i.e. which class the word belongs to.

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling:

.. code-block:: sh

    $ python predict.py Hinton
    (-0.47) Scottish
    (-1.52) English
    (-3.57) Irish

    $ python predict.py Schmidhuber
    (-0.19) German
    (-2.48) Czech
    (-2.68) Dutch


Recommended Preparation
=======================

Before starting this tutorial, it is recommended that you have installed
PyTorch and have a basic understanding of the Python programming language
and Tensors:

-  https://pytorch.org/ for installation instructions
-  :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
   and learn the basics of Tensors
-  :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
-  :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

-  `The Unreasonable Effectiveness of Recurrent Neural Networks `__
   shows a bunch of real-life examples
-  `Understanding LSTM Networks `__
   is about LSTMs specifically but also informative about RNNs in
   general

Preparing the Data
==================

.. note::
   Download the data from
   `here `_
   and extract it to the current directory.

Included in the ``data/names`` directory are 18 text files named as
``[Language].txt``. Each file contains a bunch of names, one name per
line, mostly romanized (but we still need to convert from Unicode to
ASCII).

We'll end up with a dictionary of lists of names per language,
``{language: [names ...]}``. The generic variables "category" and "line"
(for language and name in our case) are used for later extensibility.

.. GENERATED FROM PYTHON SOURCE LINES 77-119
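
The Unicode-to-ASCII step relies on NFD normalization: an accented letter
such as ``Ś`` is split into a base letter plus a combining mark (Unicode
category ``Mn``), so dropping the ``Mn`` characters leaves the plain base
letters. A small standard-library sketch of that decomposition, where
``strip_accents`` is a hypothetical helper used only for illustration:

.. code-block:: python

    import unicodedata


    def strip_accents(s):
        """Decompose with NFD, then drop combining marks (category 'Mn')."""
        decomposed = unicodedata.normalize('NFD', s)
        return ''.join(c for c in decomposed if unicodedata.category(c) != 'Mn')


    # 'Ś' decomposes into 'S' followed by COMBINING ACUTE ACCENT (U+0301)
    print([unicodedata.name(c) for c in unicodedata.normalize('NFD', 'Ś')])
    print(strip_accents('Ślusàrski'))

    # Letters with no decomposition (e.g. Polish 'ł') are left unchanged here
    print(strip_accents('ł'))

Note that a letter without a canonical decomposition, such as ``ł``, is
not mapped to a base letter by NFD; because it is also not in
``all_letters``, the tutorial's ``unicodeToAscii`` simply drops it.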

.. code-block:: default


    from io import open
    import glob
    import os

    def findFiles(path):
        return glob.glob(path)

    print(findFiles('data/names/*.txt'))

    import unicodedata
    import string

    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)

    # Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
    def unicodeToAscii(s):
        return ''.join(
            c for c in unicodedata.normalize('NFD', s)
            if unicodedata.category(c) != 'Mn'
            and c in all_letters
        )

    print(unicodeToAscii('Ślusàrski'))

    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []

    # Read a file and split into lines (the ``with`` block closes the file)
    def readLines(filename):
        with open(filename, encoding='utf-8') as f:
            lines = f.read().strip().split('\n')
        return [unicodeToAscii(line) for line in lines]

    for filename in findFiles('data/names/*.txt'):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = readLines(filename)
        category_lines[category] = lines

    n_categories = len(all_categories)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ['data/names/Arabic.txt', 'data/names/Chinese.txt', 'data/names/Czech.txt', 'data/names/Dutch.txt', 'data/names/English.txt', 'data/names/French.txt', 'data/names/German.txt', 'data/names/Greek.txt', 'data/names/Irish.txt', 'data/names/Italian.txt', 'data/names/Japanese.txt', 'data/names/Korean.txt', 'data/names/Polish.txt', 'data/names/Portuguese.txt', 'data/names/Russian.txt', 'data/names/Scottish.txt', 'data/names/Spanish.txt', 'data/names/Vietnamese.txt']
    Slusarski

.. GENERATED FROM PYTHON SOURCE LINES 120-125

Now we have ``category_lines``, a dictionary mapping each category
(language) to a list of lines (names). We also kept track of
``all_categories`` (just a list of languages) and ``n_categories`` for
later reference.

.. GENERATED FROM PYTHON SOURCE LINES 125-129

.. code-block:: default


    print(category_lines['Italian'][:5])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ['Abandonato', 'Abatangelo', 'Abatantuono', 'Abate', 'Abategiovanni']

.. GENERATED FROM PYTHON SOURCE LINES 130-146

Turning Names into Tensors
--------------------------

Now that we have all the names organized, we need to turn them into
Tensors to make any use of them.

To represent a single letter, we use a "one-hot vector" of size
``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.

To make a word, we join a bunch of those into a 3D tensor of size
``<line_length x 1 x n_letters>``.

That extra 1 dimension is because PyTorch assumes everything is in
batches - we're just using a batch size of 1 here.

.. GENERATED FROM PYTHON SOURCE LINES 146-172

.. code-block:: default


    import torch

    # Find letter index from all_letters, e.g. "a" = 0
    def letterToIndex(letter):
        return all_letters.find(letter)

    # Just for demonstration, turn a letter into a <1 x n_letters> Tensor
    def letterToTensor(letter):
        tensor = torch.zeros(1, n_letters)
        tensor[0][letterToIndex(letter)] = 1
        return tensor

    # Turn a line into a <line_length x 1 x n_letters>,
    # or an array of one-hot letter vectors
    def lineToTensor(line):
        tensor = torch.zeros(len(line), 1, n_letters)
        for li, letter in enumerate(line):
            tensor[li][0][letterToIndex(letter)] = 1
        return tensor

    print(letterToTensor('J'))

    print(lineToTensor('Jones').size())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
    torch.Size([5, 1, 57])

.. GENERATED FROM PYTHON SOURCE LINES 173-186

Creating the Network
====================

Before autograd, creating a recurrent neural network in Torch involved
cloning the parameters of a layer over several timesteps. The layers
held hidden state and gradients, which are now entirely handled by the
graph itself.
This means you can implement an RNN in a very "pure" way, as regular
feed-forward layers.

This RNN module implements a "vanilla RNN" and is just three linear
layers (``i2h``, ``h2h`` and ``h2o``) that operate on the input and the
hidden state, with a ``LogSoftmax`` layer after the output.

.. GENERATED FROM PYTHON SOURCE LINES 186-214

.. code-block:: default


    import torch.nn as nn

    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, output_size):
            super(RNN, self).__init__()

            self.hidden_size = hidden_size

            self.i2h = nn.Linear(input_size, hidden_size)
            self.h2h = nn.Linear(hidden_size, hidden_size)
            self.h2o = nn.Linear(hidden_size, output_size)
            self.softmax = nn.LogSoftmax(dim=1)

        def forward(self, input, hidden):
            # New hidden state from the current input and the previous hidden state
            hidden = torch.tanh(self.i2h(input) + self.h2h(hidden))
            output = self.h2o(hidden)
            output = self.softmax(output)
            return output, hidden

        def initHidden(self):
            return torch.zeros(1, self.hidden_size)

    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)

.. GENERATED FROM PYTHON SOURCE LINES 215-221

To run a step of this network we need to pass an input (in our case, the
Tensor for the current letter) and a previous hidden state (which we
initialize as zeros at first). We'll get back the output (the
log-probability of each language) and a next hidden state (which we keep
for the next step).

.. GENERATED FROM PYTHON SOURCE LINES 221-228

.. code-block:: default


    input = letterToTensor('A')
    hidden = torch.zeros(1, n_hidden)

    output, next_hidden = rnn(input, hidden)

.. GENERATED FROM PYTHON SOURCE LINES 229-234

For the sake of efficiency we don't want to be creating a new Tensor for
every step, so we will use ``lineToTensor`` instead of
``letterToTensor`` and use slices. This could be further optimized by
precomputing batches of Tensors.

.. GENERATED FROM PYTHON SOURCE LINES 234-242

.. code-block:: default


    input = lineToTensor('Albert')
    hidden = torch.zeros(1, n_hidden)

    output, next_hidden = rnn(input[0], hidden)
    print(output)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[-3.0014, -2.8677, -2.9758, -2.9196, -3.1387, -2.8728, -2.8886, -2.8754, -2.5694, -2.8957, -2.8363, -2.9602, -2.9206, -2.8656, -2.8350, -2.7372, -3.0470, -2.9479]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 243-246

As you can see, the output is a ``<1 x n_categories>`` Tensor, where
every item is the log-likelihood of that category (higher is more
likely).

.. GENERATED FROM PYTHON SOURCE LINES 249-259

Training
========

Preparing for Training
----------------------

Before going into training we should make a few helper functions. The
first is to interpret the output of the network, which we know to be a
log-likelihood of each category. We can use ``Tensor.topk`` to get the
index of the greatest value:

.. GENERATED FROM PYTHON SOURCE LINES 260-269

.. code-block:: default


    def categoryFromOutput(output):
        top_n, top_i = output.topk(1)
        category_i = top_i[0].item()
        return all_categories[category_i], category_i

    print(categoryFromOutput(output))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ('Irish', 8)

.. GENERATED FROM PYTHON SOURCE LINES 270-273

We will also want a quick way to get a training example (a name and its
language):

.. GENERATED FROM PYTHON SOURCE LINES 273-291

.. code-block:: default


    import random

    def randomChoice(l):
        return l[random.randint(0, len(l) - 1)]

    def randomTrainingExample():
        category = randomChoice(all_categories)
        line = randomChoice(category_lines[category])
        category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
        line_tensor = lineToTensor(line)
        return category, line, category_tensor, line_tensor

    for i in range(10):
        category, line, category_tensor, line_tensor = randomTrainingExample()
        print('category =', category, '/ line =', line)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    category = Chinese / line = Hou
    category = Scottish / line = Mckay
    category = Arabic / line = Cham
    category = Russian / line = V'Yurkov
    category = Irish / line = O'Keeffe
    category = French / line = Belrose
    category = Spanish / line = Silva
    category = Japanese / line = Fuchida
    category = Greek / line = Tsahalis
    category = Korean / line = Chang

.. GENERATED FROM PYTHON SOURCE LINES 292-301

Training the Network
--------------------

Now all it takes to train this network is to show it a bunch of
examples, have it make guesses, and tell it if it's wrong.

For the loss function, ``nn.NLLLoss`` is appropriate, since the last
layer of the RNN is ``nn.LogSoftmax``.

.. GENERATED FROM PYTHON SOURCE LINES 301-305

.. code-block:: default


    criterion = nn.NLLLoss()

.. GENERATED FROM PYTHON SOURCE LINES 306-318

Each loop of training will:

-  Create input and target tensors
-  Create a zeroed initial hidden state
-  Read each letter in and

   -  Keep hidden state for next letter

-  Compare final output to target
-  Back-propagate
-  Return the output and loss

.. GENERATED FROM PYTHON SOURCE LINES 318-339

.. code-block:: default


    learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn

    def train(category_tensor, line_tensor):
        hidden = rnn.initHidden()

        rnn.zero_grad()

        for i in range(line_tensor.size()[0]):
            output, hidden = rnn(line_tensor[i], hidden)

        loss = criterion(output, category_tensor)
        loss.backward()

        # Gradient-descent step: move each parameter by -learning_rate * gradient.
        # no_grad() keeps this in-place update out of the autograd graph.
        with torch.no_grad():
            for p in rnn.parameters():
                p.add_(p.grad, alpha=-learning_rate)

        return output, loss.item()

.. GENERATED FROM PYTHON SOURCE LINES 340-346

Now we just have to run that with a bunch of examples. Since the
``train`` function returns both the output and loss, we can print its
guesses and also keep track of loss for plotting. Since there are
thousands of examples, we print only every ``print_every`` examples and
take an average of the loss.
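
The manual update in ``train()`` is plain gradient descent: every
parameter ``p`` moves by ``-learning_rate * p.grad``, and the comment next
to ``learning_rate`` warns that too large a step can make training
explode. As a toy illustration only, the sketch below applies the same
rule to the one-dimensional function f(x) = x², whose gradient is 2x. This
function and the ``gradient_descent`` helper are assumptions for
illustration and have nothing to do with the RNN's actual loss:

.. code-block:: python

    def gradient_descent(learning_rate, x0=1.0, steps=20):
        """Minimize f(x) = x**2 with the update x <- x - learning_rate * f'(x)."""
        x = x0
        for _ in range(steps):
            grad = 2 * x                 # f'(x) for f(x) = x**2
            x -= learning_rate * grad    # same form as p.add_(p.grad, alpha=-lr)
        return x


    # Too small: slow progress; moderate: converges; too large: diverges
    for lr in (0.005, 0.1, 1.1):
        print('learning_rate=%g -> x after 20 steps = %.4f' % (lr, gradient_descent(lr)))

Each step multiplies x by (1 - 2 * learning_rate), so x shrinks slowly for
0.005, quickly for 0.1, and grows without bound once the learning rate
exceeds 1.0. The RNN's loss surface is far more complicated, so a workable
learning rate for it has to be found by experiment.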

.. GENERATED FROM PYTHON SOURCE LINES 346-386

.. code-block:: default


    import time
    import math

    n_iters = 100000
    print_every = 5000
    plot_every = 1000



    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []

    def timeSince(since):
        now = time.time()
        s = now - since
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)

    start = time.time()

    for iter in range(1, n_iters + 1):
        category, line, category_tensor, line_tensor = randomTrainingExample()
        output, loss = train(category_tensor, line_tensor)
        current_loss += loss

        # Print iter number, loss, name and guess
        if iter % print_every == 0:
            guess, guess_i = categoryFromOutput(output)
            correct = '✓' if guess == category else '✗ (%s)' % category
            print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

        # Add current loss avg to list of losses
        if iter % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    5000 5% (0m 33s) 2.2208 Horigome / Japanese ✓
    10000 10% (1m 8s) 1.6752 Miazga / Japanese ✗ (Polish)
    15000 15% (1m 43s) 0.1778 Yukhvidov / Russian ✓
    20000 20% (2m 17s) 1.5856 Mclaughlin / Irish ✗ (Scottish)
    25000 25% (2m 52s) 0.6552 Banh / Vietnamese ✓
    30000 30% (3m 27s) 1.5547 Machado / Japanese ✗ (Portuguese)
    35000 35% (4m 2s) 0.0168 Fotopoulos / Greek ✓
    40000 40% (4m 37s) 1.1464 Quirke / Irish ✓
    45000 45% (5m 11s) 1.7532 Reier / French ✗ (German)
    50000 50% (5m 46s) 0.8413 Hou / Chinese ✓
    55000 55% (6m 21s) 0.8587 Duan / Vietnamese ✗ (Chinese)
    60000 60% (6m 56s) 0.2047 Giang / Vietnamese ✓
    65000 65% (7m 30s) 2.5534 Cober / French ✗ (Czech)
    70000 70% (8m 5s) 1.5163 Mateus / Arabic ✗ (Portuguese)
    75000 75% (8m 39s) 0.2217 Hamilton / Scottish ✓
    80000 80% (9m 14s) 0.4456 Maessen / Dutch ✓
    85000 85% (9m 48s) 0.0239 Gan / Chinese ✓
    90000 90% (10m 23s) 0.0521 Bellomi / Italian ✓
    95000 95% (10m 57s) 0.0867 Vozgov / Russian ✓
    100000 100% (11m 32s) 0.2730 Tong / Vietnamese ✓
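
With its default settings, ``nn.NLLLoss`` applied to a single example's
``LogSoftmax`` output is the negative log-probability of the correct
language, so each loss printed in the training log corresponds to a
probability of ``exp(-loss)``, and ``all_losses`` holds averages of these
values. A standard-library sketch of that relationship, where
``log_softmax`` and ``nll_loss`` are hypothetical pure-Python helpers (not
the PyTorch functions) and the scores are made up:

.. code-block:: python

    import math


    def log_softmax(logits):
        """Numerically stable log-softmax: x_i - log(sum_j exp(x_j))."""
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        return [x - log_sum_exp for x in logits]


    def nll_loss(log_probs, target):
        """Negative log-probability of the target class for a single example."""
        return -log_probs[target]


    # Made-up scores for three categories; the correct category is index 0
    log_probs = log_softmax([2.0, 0.5, -1.0])
    loss = nll_loss(log_probs, target=0)
    print('loss = %.4f, p(correct) = %.4f' % (loss, math.exp(-loss)))

    # Reading two losses from the training log:
    # exp(-0.0168) is about 0.98, exp(-2.5534) is about 0.08
    for logged_loss in (0.0168, 2.5534):
        print('loss %.4f -> p(correct) = %.3f' % (logged_loss, math.exp(-logged_loss)))

Together, ``LogSoftmax`` followed by ``NLLLoss`` computes the cross-entropy
loss; PyTorch's ``nn.CrossEntropyLoss`` combines both steps and expects
raw scores instead of log-probabilities.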

.. GENERATED FROM PYTHON SOURCE LINES 387-393

Plotting the Results
--------------------

Plotting the historical loss from ``all_losses`` shows the network
learning:

.. GENERATED FROM PYTHON SOURCE LINES 393-401

.. code-block:: default


    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker

    plt.figure()
    plt.plot(all_losses)

.. image-sg:: /intermediate/images/sphx_glr_char_rnn_classification_tutorial_001.png
   :alt: char rnn classification tutorial
   :srcset: /intermediate/images/sphx_glr_char_rnn_classification_tutorial_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    []

.. GENERATED FROM PYTHON SOURCE LINES 402-411

Evaluating the Results
======================

To see how well the network performs on different categories, we will
create a confusion matrix, indicating for every actual language (rows)
which language the network guesses (columns). To calculate the
confusion matrix, a bunch of samples are run through the network with
``evaluate()``, which is the same as ``train()`` minus the backprop.

.. GENERATED FROM PYTHON SOURCE LINES 411-455
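
Row-normalizing the confusion matrix turns raw counts into fractions:
entry (i, j) becomes the share of samples from actual language i that the
network labeled as language j, so each row sums to 1 and the diagonal
holds per-language accuracy. A pure-Python sketch of that bookkeeping,
with made-up labels and a hypothetical ``confusion_matrix`` helper (the
tutorial itself uses a ``torch.zeros`` tensor):

.. code-block:: python

    def confusion_matrix(actual, guessed, categories):
        """Row-normalized confusion matrix: rows = actual category, columns = guess."""
        index = {c: i for i, c in enumerate(categories)}
        n = len(categories)
        counts = [[0] * n for _ in range(n)]
        for a, g in zip(actual, guessed):
            counts[index[a]][index[g]] += 1

        # Divide each row by its sum; rows with no samples stay all zeros
        result = []
        for row in counts:
            total = sum(row)
            result.append([c / total if total else 0.0 for c in row])
        return result


    categories = ['Chinese', 'Korean', 'Greek']
    actual = ['Korean', 'Korean', 'Korean', 'Korean', 'Greek', 'Chinese']
    guessed = ['Korean', 'Chinese', 'Korean', 'Korean', 'Greek', 'Chinese']
    for name, row in zip(categories, confusion_matrix(actual, guessed, categories)):
        print(name, row)

The guard for rows with no samples is an addition of this sketch. The
tutorial divides directly, which would give NaN for a language that was
never sampled, although with 10,000 random samples over 18 languages that
is very unlikely.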

.. code-block:: default


    # Keep track of correct guesses in a confusion matrix
    confusion = torch.zeros(n_categories, n_categories)
    n_confusion = 10000

    # Just return an output given a line
    def evaluate(line_tensor):
        hidden = rnn.initHidden()

        for i in range(line_tensor.size()[0]):
            output, hidden = rnn(line_tensor[i], hidden)

        return output

    # Go through a bunch of examples and record which are correctly guessed
    # (no_grad: only the outputs are needed here, not gradients)
    with torch.no_grad():
        for i in range(n_confusion):
            category, line, category_tensor, line_tensor = randomTrainingExample()
            output = evaluate(line_tensor)
            guess, guess_i = categoryFromOutput(output)
            category_i = all_categories.index(category)
            confusion[category_i][guess_i] += 1

    # Normalize by dividing every row by its sum
    for i in range(n_categories):
        confusion[i] = confusion[i] / confusion[i].sum()

    # Set up plot
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion.numpy())
    fig.colorbar(cax)

    # Set up axes: one fixed tick per category, so every label is shown
    # and lines up with its row/column
    ax.set_xticks(range(n_categories))
    ax.set_yticks(range(n_categories))
    ax.set_xticklabels(all_categories, rotation=90)
    ax.set_yticklabels(all_categories)

    # sphinx_gallery_thumbnail_number = 2
    plt.show()

.. image-sg:: /intermediate/images/sphx_glr_char_rnn_classification_tutorial_002.png
   :alt: char rnn classification tutorial
   :srcset: /intermediate/images/sphx_glr_char_rnn_classification_tutorial_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 456-461

You can pick out bright spots off the main axis that show which
languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
for Italian. It seems to do very well with Greek, and very poorly with
English (perhaps because of overlap with other languages).

.. GENERATED FROM PYTHON SOURCE LINES 464-467

Running on User Input
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 467-488

.. code-block:: default


    def predict(input_line, n_predictions=3):
        print('\n> %s' % input_line)
        with torch.no_grad():
            output = evaluate(lineToTensor(input_line))

            # Get top N categories
            topv, topi = output.topk(n_predictions, 1, True)
            predictions = []

            for i in range(n_predictions):
                value = topv[0][i].item()
                category_index = topi[0][i].item()
                print('(%.2f) %s' % (value, all_categories[category_index]))
                predictions.append([value, all_categories[category_index]])

        return predictions

    predict('Dovesky')
    predict('Jackson')
    predict('Satoshi')

.. rst-class:: sphx-glr-script-out

.. code-block:: none


    > Dovesky
    (-0.23) Czech
    (-2.02) Russian
    (-3.35) English

    > Jackson
    (-0.20) Scottish
    (-2.51) Russian
    (-3.05) Greek

    > Satoshi
    (-0.91) Italian
    (-1.26) Japanese
    (-1.57) Polish

.. GENERATED FROM PYTHON SOURCE LINES 489-513

The final versions of the scripts `in the Practical PyTorch repo `__
split the above code into a few files:

-  ``data.py`` (loads files)
-  ``model.py`` (defines the RNN)
-  ``train.py`` (runs training)
-  ``predict.py`` (runs ``predict()`` with command line arguments)
-  ``server.py`` (serves predictions as a JSON API with ``bottle.py``)

Run ``train.py`` to train and save the network.

Run ``predict.py`` with a name to view predictions:

.. code-block:: sh

    $ python predict.py Hazaki
    (-0.42) Japanese
    (-1.39) Polish
    (-3.51) Czech

Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
output of predictions.
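
The ``server.py`` script itself is not shown in this tutorial. As a hedged
sketch of the same idea using only the standard library's ``wsgiref``
(not ``bottle.py``), the app below answers ``GET /<name>`` with JSON. The
names ``make_app``, ``fake_predict`` and ``serve`` and the
``{"name": ..., "predictions": ...}`` layout are assumptions for
illustration; ``fake_predict`` returns fixed scores and would be replaced
by a function that runs the trained model.

.. code-block:: python

    import json
    from wsgiref.simple_server import make_server


    def make_app(predict_fn):
        """Return a WSGI app that serves predict_fn(name) as JSON for GET /<name>."""
        def app(environ, start_response):
            # PEP 3333: PATH_INFO is already URL-decoded, with raw bytes carried as latin-1
            path = environ.get('PATH_INFO', '/')
            name = path.encode('latin-1').decode('utf-8', errors='replace').lstrip('/')
            if name:
                status = '200 OK'
                payload = {'name': name, 'predictions': predict_fn(name)}
            else:
                status = '404 Not Found'
                payload = {'error': 'no name given'}
            body = json.dumps(payload).encode('utf-8')
            start_response(status, [('Content-Type', 'application/json'),
                                    ('Content-Length', str(len(body)))])
            return [body]
        return app


    def fake_predict(name):
        """Stand-in for the trained model: fixed [log-probability, language] pairs."""
        return [[-0.42, 'Japanese'], [-1.39, 'Polish'], [-3.51, 'Czech']]


    def serve(port=5533):
        """Serve until interrupted; then visit http://localhost:5533/Yourname."""
        with make_server('localhost', port, make_app(fake_predict)) as server:
            server.serve_forever()

Calling ``serve()`` blocks and listens on port 5533. Because ``app`` is a
plain WSGI callable, it can also be exercised without a network by
passing it an ``environ`` dict and a ``start_response`` function.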

.. GENERATED FROM PYTHON SOURCE LINES 516-532

Exercises
=========

-  Try with a different dataset of line -> category, for example:

   -  Any word -> language
   -  First name -> gender
   -  Character name -> writer
   -  Page title -> blog or subreddit

-  Get better results with a bigger and/or better shaped network

   -  Add more linear layers
   -  Try the ``nn.LSTM`` and ``nn.GRU`` layers
   -  Combine multiple of these RNNs as a higher level network


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 11 minutes 45.055 seconds)


.. _sphx_glr_download_intermediate_char_rnn_classification_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: char_rnn_classification_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: char_rnn_classification_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_