• Tutorials >
  • (experimental) Dynamic Quantization on an LSTM Word Language Model
Shortcuts

(experimental) Dynamic Quantization on an LSTM Word Language Model

Author: James Reed

Edited by: Seth Weidman

Introduction

Quantization involves converting the weights and activations of your model from float to int, which can result in smaller model size and faster inference with only a small hit to accuracy.

In this tutorial, we’ll apply the easiest form of quantization - dynamic quantization - to an LSTM-based next word-prediction model, closely following the word language model from the PyTorch examples.

# imports
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

1. Define the model

Here we define the LSTM model architecture, following the model from the word language model example.

class LSTMModel(nn.Module):
    """Word-level language model: embedding encoder -> stacked LSTM -> linear decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        # The same dropout probability is applied to the embeddings, between
        # LSTM layers, and to the LSTM output.
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for encoder/decoder weights; zero decoder bias."""
        bound = 0.1
        for weight in (self.encoder.weight, self.decoder.weight):
            weight.data.uniform_(-bound, bound)
        self.decoder.bias.data.zero_()

    def forward(self, input, hidden):
        # (seq_len, batch) token ids -> dropped embeddings -> LSTM -> logits.
        dropped_emb = self.drop(self.encoder(input))
        rnn_out, hidden = self.rnn(dropped_emb, hidden)
        decoded = self.decoder(self.drop(rnn_out))
        return decoded, hidden

    def init_hidden(self, bsz):
        """Return zeroed (h0, c0), matching the dtype/device of the parameters."""
        param = next(self.parameters())
        shape = (self.nlayers, bsz, self.nhid)
        return (param.new_zeros(*shape), param.new_zeros(*shape))

2. Load in the text data

Next, we load the Wikitext-2 dataset into a Corpus, again following the preprocessing from the word language model example.

class Dictionary(object):
    """Bidirectional word <-> integer-id mapping, built incrementally."""

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = []  # id -> word

    def add_word(self, word):
        """Register `word` if unseen and return its id."""
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def __len__(self):
        # Vocabulary size.
        return len(self.idx2word)


class Corpus(object):
    """Wikitext-style corpus: builds a Dictionary, then tokenizes the three splits."""

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenize a text file into one flat 1-D int64 tensor of word ids."""
        assert os.path.exists(path)
        # First pass: register every word (each line gets a trailing <eos>).
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                for word in line.split() + ['<eos>']:
                    self.dictionary.add_word(word)

        # Second pass: map each line to a tensor of ids, then concatenate.
        with open(path, 'r', encoding="utf8") as f:
            per_line = []
            for line in f:
                token_ids = [self.dictionary.word2idx[w]
                             for w in line.split() + ['<eos>']]
                per_line.append(torch.tensor(token_ids).type(torch.int64))

        return torch.cat(per_line)

# Root directory holding the pre-downloaded dataset and model weights.
model_data_filepath = 'data/'

# Builds the vocabulary and tokenizes train/valid/test from data/wikitext-2/*.txt.
corpus = Corpus(model_data_filepath + 'wikitext-2')

3. Load the pre-trained model

This is a tutorial on dynamic quantization, a quantization technique that is applied after a model has been trained. Therefore, we’ll simply load some pre-trained weights into this model architecture; these weights were obtained by training for five epochs using the default settings in the word language model example.

# Vocabulary size fixes both the embedding rows and the decoder output width.
ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,   # embedding dimension
    nhid = 256,   # LSTM hidden size
    nlayers = 5,  # stacked LSTM layers
)

# Load pre-trained weights (per the text above, obtained by training for five
# epochs in the word language model example); map to CPU so no GPU is needed.
model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
        )
    )

# Evaluation mode disables dropout.
model.eval()
print(model)

Out:

LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)

Now let’s generate some text to ensure that the pre-trained model is working properly — similarly to before, we follow the text generation script from the word language model example here.

# Seed generation with one random word id; shape (seq_len=1, batch=1).
input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0  # logits are divided by this before sampling
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            # Turn logits into (unnormalized) sampling weights.
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            # Feed the sampled word back in as the next input.
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            # NOTE(review): str(word.encode('utf-8')) writes the bytes repr
            # (e.g. b'the'), which is what the sample output below shows —
            # confirm this is intended rather than writing `word` directly.
            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, 1000))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)

Out:

| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'the' b'show' b'as' b'fact' b';' b'from' b'example' b'.' b'In' b'Washington' b',' b'leg' b'models' b'formed' b'from' b'America' b'of' b'total' b'of' b'1649'
b'.' b'They' b'associated' b'their' b'tribute' b'to' b'<unk>' b'the' b'enemy' b',' b'with' b'the' b'wettest' b'forming' b'star' b'or' b'back' b'(' b'modern' b'243'
b'million' b'N' b'and' b'other' b'astronomers' b')' b',' b'in' b'January' b'2001' b'.' b'The' b'largest' b'global' b'number' b'of' b'Michigan' b'are' b'Arabic' b','
b'with' b'only' b'40' b'%' b',' b'with' b'various' b'native' b'representative' b'with' b'broken' b'work' b'.' b'Despite' b'less' b'than' b'5' b'%' b'of' b'larger'
b'school' b',' b'eight' b'females' b'remained' b'backstage' b'damage' b'(' b'each' b'presently' b'greater' b'than' b'1' b'@.@' b'2' b'years' b',' b'and' b'most' b'other'
b'discoveries' b'exhibit' b'a' b'Global' b'villagers' b'All' b')' b'.' b'It' b'was' b'often' b'first' b'until' b'repaired' b"'" b'"' b'Institution' b'for' b'which' b'they'
b'vary' b'in' b'the' b'random' b'fight' b'of' b'criminal' b'properties' b'"' b',' b'in' b'order' b'to' b'have' b'been' b'a' b'thousands' b'of' b'expansion' b'that'
b'are' b'goats' b'known' b'as' b'Brienne' b',' b'and' b'their' b'main' b'failures' b'of' b'officials' b'pose' b'in' b'224' b'areas' b'.' b'<eos>' b'Major' b'expansions'
b'permitted' b'numbers' b'from' b'pre' b'@-@' b'stripping' b'examples' b',' b'with' b'them' b'to' b'be' b'kept' b'.' b'Although' b'they' b'occur' b'from' b'later' b','
b'news' b'from' b'the' b'cultivation' b'of' b'particularly' b'150' b'road' b'individuals' b'under' b'Venus' b',' b'after' b'particularly' b'allowing' b'Linus' b'before' b'rapid' b'sites' b'.'
b'This' b'claim' b'it' b'is' b'a' b'composite' b'.' b'The' b'<unk>' b'leaves' b'willows' b'to' b'access' b'for' b'soil' b',' b'so' b'it' b'makes' b'an'
b'handedness' b'or' b'provide' b'obstacles' b'control' b'of' b'symptomatic' b'.' b'It' b'Waterfall' b'makes' b'that' b'he' b'should' b'be' b'three' b'separate' b'to' b'earn' b'forest'
b'so' b'shaped' b',' b'or' b'so' b'seems' b'to' b'be' b'better' b'at' b'the' b'back' b'of' b'some' b'types' b'of' b'quantities' b'.' b'<eos>' b'suburbs'
b'for' b'<unk>' b'chicks' b',' b'when' b'they' b'established' b'land' b',' b'jaw' b'or' b'the' b'Spike' b'or' b'wicket' b'XeF' b'are' b'elegant' b'periodically' b'.'
b'The' b'head' b'beads' b'(' b'most' b'of' b'which' b'to' b'honor' b'of' b'Loyola' b'include' b'<unk>' b',' b'and' b'is' b'hard' b'whether' b'they' b'were'
b'originally' b'defined' b'.' b'<unk>' b')' b',' b'also' b'check' b'seeks' b'to' b'score' b'that' b'they' b'would' b'be' b'represented' b'on' b'events' b'.' b'However'
b',' b'there' b'can' b'have' b'been' b'able' b'to' b'have' b'ages' b'@.@' b'92' b'in' b'Bobcats' b'.' b'Even' b'thus' b'they' b'survive' b'they' b'are'
b'certain' b'when' b'that' b',' b'they' b'have' b'high' b'Irish' b'defenses' b',' b'some' b'or' b'simply' b'primitive' b'needs' b'to' b'find' b'those' b'or' b'moving'
b'Adventure' b'State' b'.' b'There' b'were' b'8' b'lb' b',' b'and' b'may' b'<unk>' b'be' b'actually' b'.' b'It' b'is' b'always' b'possible' b'by' b'Concerned'
b'and' b'inferred' b'enough' b'much' b'away' b'them' b',' b'though' b'the' b'pair' b'do' b'not' b'secure' b'to' b'have' b'reduced' b'.' b'Only' b'eggs' b'makes'
b'high' b'debate' b'within' b'only' b'750' b'hurricanes' b',' b'so' b'whether' b'they' b'do' b'not' b'seem' b'to' b'communicate' b'there' b'back' b'to' b'Atlas' b','
b'so' b'males' b'however' b'are' b'still' b'valued' b'.' b'Ancient' b'Spalato' b'are' b'lying' b'to' b'indentured' b'predators' b'such' b'as' b'<unk>' b',' b'noises' b'or'
b'rivers' b'.' b'They' b'are' b'known' b'in' b'the' b'same' b'category' b'until' b'the' b'latter' b'has' b'anything' b'cell' b',' b'map' b',' b'contract' b'"'
b'or' b'a' b'just' b'sufficient' b'earthstar' b',' b'to' b'be' b'active' b'by' b'it' b'.' b'"' b'is' b',' b'from' b'present' b'collapse' b'now' b'less'
b'recent' b'and' b'iconography' b'if' b'Ceres' b'sounded' b'some' b'theories' b'and' b'were' b'repelled' b'by' b'some' b'bubble' b'systems' b'.' b'These' b'birds' b'describes' b'Chinese'
b'belt' b'railway' b'birds' b'which' b'winter' b'chromatin' b'continue' b'to' b'mimic' b'for' b'their' b'abbreviation' b'.' b'<eos>' b'Common' b'starlings' b'with' b'from' b'Jude' b'stars'
b'of' b'a' b'variety' b'is' b'limited' b'to' b'a' b'average' b'production' b'.' b'Such' b'bodies' b'may' b'be' b'<unk>' b'dead' b',' b'and' b'they' b'transfer'
b'their' b'workforce' b'and' b'carrying' b'them' b'three' b'.' b'In' b'ways' b'they' b'reflected' b'their' b'beak' b'all' b'of' b'the' b'division' b"'s" b'seemingly' b'Flag'
b'out' b'of' b'constitute' b'multiple' b'animals' b'to' b'hierarchical' b'properties' b'.' b'Wintory' b'victim' b'mastery' b',' b'but' b'continue' b'varies' b'are' b'rare' b'.' b'Perhaps'
b'after' b'their' b'parent' b'crown' b'star' b',' b'they' b'arises' b'southeast' b'off' b'the' b'nest' b'or' b'<unk>' b'once' b'sort' b'of' b'coffee' b',' b'caused'
b'that' b'they' b'have' b'deliberately' b'largely' b'demonstrated' b'Ceres' b',' b'a' b'<unk>' b'<unk>' b'resulting' b'on' b'its' b'diet' b'regarding' b'them' b'instead' b'for' b'nest'
b'them' b'.' b'These' b'writers' b'needs' b'to' b'be' b'spelled' b'way' b',' b'their' b'attempts' b'even' b'looks' b'to' b'Australia' b',' b'they' b'became' b'both'
b'chicks' b'.' b'In' b'his' b'dwarf' b',' b'<unk>' b'rudder' b'do' b'not' b'be' b'got' b'when' b'people' b'would' b'be' b'able' b'to' b'pay' b'.'
b'<unk>' b'else' b',' b'computational' b',' b'overly' b'flocks' b'regulate' b'its' b'relative' b'traveling' b'and' b'also' b'nest' b'offer' b',' b'and' b'their' b'male' b'techniques'
b',' b'fir' b',' b'<unk>' b',' b'brownish' b',' b'and' b'access' b'long' b'.' b'<eos>' b'Common' b'starlings' b'can' b'<unk>' b'the' b'increases' b'on' b'rotation'
b'feathers' b',' b'even' b'other' b'range' b'is' b'genetically' b'(' b'12' b'and' b'distances' b')' b',' b'or' b'loud' b'depth' b'.' b'<eos>' b'Common' b'starlings'
b'Orilla' b'flocks' b'often' b'feed' b'in' b'fresh' b'areas' b'.' b'<eos>' b'Unlike' b'first' b'value' b'Forks' b'can' b'be' b'found' b'to' b'find' b'limited' b'<unk>'
b'down' b'fewer' b'minutes' b',' b'Jifna' b'could' b'be' b'given' b'<unk>' b'24' b'\xe2\x80\x93' b'21' b'.' b'It' b'includes' b'produces' b'natural' b'explorers' b'that' b'could'
b'be' b'important' b'to' b'the' b'island' b'.' b'This' b'travel' b'certain' b'times' b'on' b'98' b'October' b'in' b'<unk>' b'.' b'<eos>' b'All' b'are' b'bluebunch'
b'(' b'e.g.' b'<unk>' b'may' b'be' b'composed' b')' b'but' b'usually' b'together' b',' b'with' b'experience' b'the' b'arrangement' b'promote' b'or' b'<unk>' b'over' b'them'
b'of' b'some' b'or' b'Usta\xc5\xa1e' b'takes' b'by' b'other' b'slaves' b'.' b'Such' b'of' b'175' b'birds' b'will' b'be' b'preferred' b'by' b'Heinkel' b'that' b'they'
b'feed' b'with' b'the' b'purpose' b'for' b'users' b',' b'so' b'when' b'it' b'will' b'interact' b',' b'they' b'were' b'only' b'brownish' b'.' b'Mhalsa' b'such'
b'as' b'symbolize' b'food' b'and' b'foraging' b'on' b'Destiny' b"'s" b'observation' b';' b'they' b'heard' b'Iberia' b',' b'<unk>' b'however' b',' b'agents' b',' b'pedestrians'
b',' b'<unk>' b',' b'which' b'raised' b',' b'devices' b'or' b'female' b'gallop' b'to' b'sites' b',' b'as' b'well' b'as' b'west' b'race' b',' b'neighbouring'
b'supply' b'<unk>' b',' b'Puerto' b'M.' b'Ragnaill' b',' b'and' b'outer' b'starling' b',' b'cores' b',' b'and' b'some' b'cook' b'.' b'If' b'females' b'may'
b'be' b'tied' b'<unk>' b',' b'kill' b'Lemieux' b'<unk>' b'are' b'open' b'by' b'their' b'relative' b'materials' b'.' b'In' b'the' b'upper' b',' b'Irish' b'bodies'
b'are' b'valid' b',' b'showing' b'cannot' b'white' b';' b'it' b'likes' b'emit' b'<unk>' b'eggs' b'throughout' b'Plateau' b'.' b'A' b'small' b'parasites' b'713' b'in'
b'various' b'@-@' b'gear' b'areas' b',' b'contracting' b'out' b'of' b'[' b'they' b'produces' b'their' b'eggs' b'from' b'Celtic' b'colors' b',' b'reagent' b'rate' b','
b'TBSA' b',' b'and' b'<unk>' b'nests' b'.' b'Members' b'activity' b'early' b'to' b'Developers' b',' b'Timor' b'or' b'Emma' b'can' b'begin' b'ill' b',' b'animal'
b'complications' b'each' b'entirely' b'or' b'some' b'chicks' b'have' b'experienced' b'sides' b'to' b'enhance' b'.' b'There' b'were' b'no' b'half' b'rare' b'motion' b'differences' b'to'

It’s no GPT-2, but it looks like the model has started to learn the structure of language!

We’re almost ready to demonstrate dynamic quantization. We just need to define a few more helper functions:

# Backprop-through-time window: number of rows per chunk in get_batch().
bptt = 25
criterion = nn.CrossEntropyLoss()
# Batch size used when evaluating on the test split.
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    """Reshape a flat token stream into bsz parallel columns, dropping the remainder."""
    # How many full rows of bsz tokens fit in the data.
    nbatch = data.size(0) // bsz
    # Keep only the prefix that divides evenly into bsz parts.
    trimmed = data.narrow(0, 0, nbatch * bsz)
    # (bsz, nbatch) -> transpose -> (nbatch, bsz); contiguous for later slicing.
    return trimmed.view(bsz, -1).t().contiguous()

# Test split arranged as (nbatch, eval_batch_size) columns for evaluation.
test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    """Slice an input/target pair of up to `bptt` rows starting at row i.

    Targets are the inputs shifted one step ahead, flattened to 1-D for the loss.
    """
    seq_len = min(bptt, len(source) - 1 - i)
    batch = source[i:i + seq_len]
    labels = source[i + 1:i + 1 + seq_len].view(-1)
    return batch, labels

def repackage_hidden(h):
    """Wrap hidden states in new Tensors, detached from their history.

    Accepts a single Tensor or an arbitrarily nested tuple of Tensors and
    returns the same structure with every Tensor detached from autograd.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(map(repackage_hidden, h))

def evaluate(model_, data_source):
    """Return the average per-token cross-entropy of model_ over data_source."""
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for start in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, start)
            output, hidden = model_(data, hidden)
            # Detach so one chunk's graph never links to the next chunk's.
            hidden = repackage_hidden(hidden)
            flat_logits = output.view(-1, ntokens)
            # Weight each chunk's mean loss by its length.
            total_loss += len(data) * criterion(flat_logits, targets).item()
    return total_loss / (len(data_source) - 1)

4. Test dynamic quantization

Finally, we can call torch.quantization.quantize_dynamic on the model! Specifically,

  • We specify that we want the nn.LSTM and nn.Linear modules in our model to be quantized
  • We specify that we want weights to be converted to int8 values
import torch.quantization

# Replace every nn.LSTM / nn.Linear in the model with a dynamically quantized
# version whose weights are int8; all other modules are left untouched.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

Out:

LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, scale=1.0, zero_point=0)
)

The model looks the same; how has this benefited us? First, we see a significant reduction in model size:

def print_size_of_model(model):
    """Print the serialized size (MB) of model's state_dict via a temp file."""
    tmp_path = "temp.p"
    torch.save(model.state_dict(), tmp_path)
    size_mb = os.path.getsize(tmp_path) / 1e6
    print('Size (MB):', size_mb)
    # Clean up the scratch file.
    os.remove(tmp_path)

print_size_of_model(model)
print_size_of_model(quantized_model)

Out:

Size (MB): 113.941574
Size (MB): 76.80671

Second, we see faster inference time, with no difference in evaluation loss:

Note: we set the number of threads to one for a single-threaded comparison, since quantized models run single-threaded.

# Per the note above: pin to one thread so the float baseline is compared
# fairly against the single-threaded quantized model.
torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    """Run evaluate() on (model, test_data) and print the loss and wall time."""
    started = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - started
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

# Float baseline first, then the quantized model, on the same test split.
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

Out:

loss: 5.167
elapsed time (seconds): 245.7
loss: 5.168
elapsed time (seconds): 172.9

Running this locally on a MacBook Pro, without quantization, inference takes about 200 seconds, and with quantization it takes just about 100 seconds.

Conclusion

Dynamic quantization can be an easy way to reduce model size while only having a limited effect on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue here if you have any.

Total running time of the script: ( 7 minutes 3.136 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources