{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Language Modeling with ``nn.Transformer`` and torchtext\n\nThis is a tutorial on training a model to predict the next word in a sequence using the\n[nn.Transformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)_ module.\n\nThe PyTorch 1.2 release includes a standard transformer module based on the\npaper [Attention is All You Need](https://arxiv.org/pdf/1706.03762.pdf)_.\nCompared to Recurrent Neural Networks (RNNs), the transformer model has proven\nto be superior in quality for many sequence-to-sequence tasks while being more\nparallelizable. The ``nn.Transformer`` module relies entirely on an attention\nmechanism (implemented as\n[nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)_)\nto draw global dependencies between input and output. The ``nn.Transformer``\nmodule is highly modularized such that a single component (e.g.,\n[nn.TransformerEncoder](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html)_)\ncan be easily adapted/composed.\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the model\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we train a ``nn.TransformerEncoder`` model on a\ncausal language modeling task. Please note that this tutorial does not cover\nthe training of [nn.TransformerDecoder](https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html#torch.nn.TransformerDecoder)_, as depicted in\nthe right half of the diagram above. The language modeling task is to assign a\nprobability for the likelihood of a given word (or a sequence of words)\nto follow a sequence of words. A sequence of tokens are passed to the embedding\nlayer first, followed by a positional encoding layer to account for the order\nof the word (see the next paragraph for more details). The\n``nn.TransformerEncoder`` consists of multiple layers of\n[nn.TransformerEncoderLayer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html)_.\nAlong with the input sequence, a square attention mask is required because the\nself-attention layers in ``nn.TransformerDecoder`` are only allowed to attend\nthe earlier positions in the sequence. For the language modeling task, any\ntokens on the future positions should be masked. This masking, combined with fact that \nthe output embeddings are offset with later positions ensures that the\npredictions for position i can depend only on the known outputs at positions less than i.\nTo produce a probability distribution over output words, the output of the ``nn.TransformerEncoder``\nmodel is passed through a linear layer to output unnormalized logits.\nThe log-softmax function isn't applied here due to the later use of\n[CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)_,\nwhich requires the inputs to be unnormalized logits.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\nimport os\nfrom tempfile import TemporaryDirectory\nfrom typing import Tuple\n\nimport torch\nfrom torch import nn, Tensor\nfrom torch.nn import TransformerEncoder, TransformerEncoderLayer\nfrom torch.utils.data import dataset\n\nclass TransformerModel(nn.Module):\n\n def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n nlayers: int, dropout: float = 0.5):\n super().__init__()\n self.model_type = 'Transformer'\n self.pos_encoder = PositionalEncoding(d_model, dropout)\n encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)\n self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n self.embedding = nn.Embedding(ntoken, d_model)\n self.d_model = d_model\n self.linear = nn.Linear(d_model, ntoken)\n\n self.init_weights()\n\n def init_weights(self) -> None:\n initrange = 0.1\n self.embedding.weight.data.uniform_(-initrange, initrange)\n self.linear.bias.data.zero_()\n self.linear.weight.data.uniform_(-initrange, initrange)\n\n def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:\n \"\"\"\n Arguments:\n src: Tensor, shape ``[seq_len, batch_size]``\n src_mask: Tensor, shape ``[seq_len, seq_len]``\n\n Returns:\n output Tensor of shape ``[seq_len, batch_size, ntoken]``\n \"\"\"\n src = self.embedding(src) * math.sqrt(self.d_model)\n src = self.pos_encoder(src)\n if src_mask is None:\n \"\"\"Generate a square causal mask for the sequence. The masked positions are filled with float('-inf').\n Unmasked positions are filled with float(0.0).\n \"\"\"\n src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)\n output = self.transformer_encoder(src, src_mask)\n output = self.linear(output)\n return output"
]
},
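{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, illustrative sanity check (not part of the model), the cell below prints the causal mask that ``nn.Transformer.generate_square_subsequent_mask`` produces for a short sequence: entries above the diagonal are ``-inf``, so each position can attend only to itself and earlier positions.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative only: a 5x5 causal mask. Rows are query positions, columns are key\n# positions; -inf entries are ignored by the attention softmax.\nprint(nn.Transformer.generate_square_subsequent_mask(5))"
]
},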
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``PositionalEncoding`` module injects some information about the\nrelative or absolute position of the tokens in the sequence. The\npositional encodings have the same dimension as the embeddings so that\nthe two can be summed. Here, we use ``sine`` and ``cosine`` functions of\ndifferent frequencies.\n\n\n"
]
},
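{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concretely, following the formulation in [Attention is All You Need](https://arxiv.org/pdf/1706.03762.pdf)_, the encoding for position ``pos`` and embedding dimension index ``i`` is\n\n\\begin{align}PE_{(pos, 2i)} &= \\sin(pos / 10000^{2i/d_{\\text{model}}}) \\\\\n  PE_{(pos, 2i+1)} &= \\cos(pos / 10000^{2i/d_{\\text{model}}})\\end{align}\n\nwhich is what the ``div_term`` computation in the implementation below evaluates.\n\n\n"
]
},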
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n\n def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n super().__init__()\n self.dropout = nn.Dropout(p=dropout)\n\n position = torch.arange(max_len).unsqueeze(1)\n div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n pe = torch.zeros(max_len, 1, d_model)\n pe[:, 0, 0::2] = torch.sin(position * div_term)\n pe[:, 0, 1::2] = torch.cos(position * div_term)\n self.register_buffer('pe', pe)\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"\n Arguments:\n x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``\n \"\"\"\n x = x + self.pe[:x.size(0)]\n return self.dropout(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and batch data\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tutorial uses ``torchtext`` to generate Wikitext-2 dataset.\nTo access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.\n%%"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%%bash\npip install portalocker\npip install torchdata"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vocab object is built based on the train dataset and is used to numericalize\ntokens into tensors. Wikitext-2 represents rare tokens as ``.\n\nGiven a 1-D vector of sequential data, ``batchify()`` arranges the data\ninto ``batch_size`` columns. If the data does not divide evenly into\n``batch_size`` columns, then the data is trimmed to fit. For instance, with\nthe alphabet as the data (total length of 26) and ``batch_size=4``, we would\ndivide the alphabet into sequences of length 6, resulting in 4 of such sequences.\n\n\\begin{align}\\begin{bmatrix}\n \\text{A} & \\text{B} & \\text{C} & \\ldots & \\text{X} & \\text{Y} & \\text{Z}\n \\end{bmatrix}\n \\Rightarrow\n \\begin{bmatrix}\n \\begin{bmatrix}\\text{A} \\\\ \\text{B} \\\\ \\text{C} \\\\ \\text{D} \\\\ \\text{E} \\\\ \\text{F}\\end{bmatrix} &\n \\begin{bmatrix}\\text{G} \\\\ \\text{H} \\\\ \\text{I} \\\\ \\text{J} \\\\ \\text{K} \\\\ \\text{L}\\end{bmatrix} &\n \\begin{bmatrix}\\text{M} \\\\ \\text{N} \\\\ \\text{O} \\\\ \\text{P} \\\\ \\text{Q} \\\\ \\text{R}\\end{bmatrix} &\n \\begin{bmatrix}\\text{S} \\\\ \\text{T} \\\\ \\text{U} \\\\ \\text{V} \\\\ \\text{W} \\\\ \\text{X}\\end{bmatrix}\n \\end{bmatrix}\\end{align}\n\nBatching enables more parallelizable processing. However, batching means that\nthe model treats each column independently; for example, the dependence of\n``G`` and ``F`` can not be learned in the example above.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from torchtext.datasets import WikiText2\nfrom torchtext.data.utils import get_tokenizer\nfrom torchtext.vocab import build_vocab_from_iterator\n\ntrain_iter = WikiText2(split='train')\ntokenizer = get_tokenizer('basic_english')\nvocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=[''])\nvocab.set_default_index(vocab[''])\n\ndef data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:\n \"\"\"Converts raw text into a flat Tensor.\"\"\"\n data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]\n return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))\n\n# ``train_iter`` was \"consumed\" by the process of building the vocab,\n# so we have to create it again\ntrain_iter, val_iter, test_iter = WikiText2()\ntrain_data = data_process(train_iter)\nval_data = data_process(val_iter)\ntest_data = data_process(test_iter)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\ndef batchify(data: Tensor, bsz: int) -> Tensor:\n \"\"\"Divides the data into ``bsz`` separate sequences, removing extra elements\n that wouldn't cleanly fit.\n\n Arguments:\n data: Tensor, shape ``[N]``\n bsz: int, batch size\n\n Returns:\n Tensor of shape ``[N // bsz, bsz]``\n \"\"\"\n seq_len = data.size(0) // bsz\n data = data[:seq_len * bsz]\n data = data.view(bsz, seq_len).t().contiguous()\n return data.to(device)\n\nbatch_size = 20\neval_batch_size = 10\ntrain_data = batchify(train_data, batch_size) # shape ``[seq_len, batch_size]``\nval_data = batchify(val_data, eval_batch_size)\ntest_data = batchify(test_data, eval_batch_size)"
]
},
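{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this concrete, the illustrative check below replays the alphabet example: a toy sequence of 26 token ids is arranged by ``batchify`` into 4 columns of length 6 (the two leftover elements are trimmed), and ``vocab`` numericalizes a short tokenized string. Neither step is needed for training.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative only: batchify a toy 26-element sequence (stands in for A..Z).\nletters = torch.arange(26)\nprint(batchify(letters, 4))  # shape [6, 4]; the last 2 elements are trimmed\n\n# Numericalize a short tokenized string with the vocab built above.\nprint(vocab(tokenizer('the quick brown fox')))"
]
},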
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Functions to generate input and target sequence\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``get_batch()`` generates a pair of input-target sequences for\nthe transformer model. It subdivides the source data into chunks of\nlength ``bptt``. For the language modeling task, the model needs the\nfollowing words as ``Target``. For example, with a ``bptt`` value of 2,\nwe\u2019d get the following two Variables for ``i`` = 0:\n\n\n\nIt should be noted that the chunks are along dimension 0, consistent\nwith the ``S`` dimension in the Transformer model. The batch dimension\n``N`` is along dimension 1.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"bptt = 35\ndef get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:\n \"\"\"\n Args:\n source: Tensor, shape ``[full_seq_len, batch_size]``\n i: int\n\n Returns:\n tuple (data, target), where data has shape ``[seq_len, batch_size]`` and\n target has shape ``[seq_len * batch_size]``\n \"\"\"\n seq_len = min(bptt, len(source) - 1 - i)\n data = source[i:i+seq_len]\n target = source[i+1:i+1+seq_len].reshape(-1)\n return data, target"
]
},
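{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustrative check, calling ``get_batch`` on the training data returns a ``[bptt, batch_size]`` input chunk and a flattened target with ``bptt * batch_size`` elements.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative only: inspect one training batch and its shapes.\nexample_data, example_target = get_batch(train_data, 0)\nprint(example_data.shape)    # [bptt, batch_size] -> torch.Size([35, 20])\nprint(example_target.shape)  # [bptt * batch_size] -> torch.Size([700])"
]
},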
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initiate an instance\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model hyperparameters are defined below. The ``vocab`` size is\nequal to the length of the vocab object.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"ntokens = len(vocab) # size of vocabulary\nemsize = 200 # embedding dimension\nd_hid = 200 # dimension of the feedforward network model in ``nn.TransformerEncoder``\nnlayers = 2 # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``\nnhead = 2 # number of heads in ``nn.MultiheadAttention``\ndropout = 0.2 # dropout probability\nmodel = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the model\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)_\nwith the [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)_\n(stochastic gradient descent) optimizer. The learning rate is initially set to\n5.0 and follows a [StepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html)_\nschedule. During training, we use [nn.utils.clip_grad_norm\\_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html)_\nto prevent gradients from exploding.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import time\n\ncriterion = nn.CrossEntropyLoss()\nlr = 5.0 # learning rate\noptimizer = torch.optim.SGD(model.parameters(), lr=lr)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)\n\ndef train(model: nn.Module) -> None:\n model.train() # turn on train mode\n total_loss = 0.\n log_interval = 200\n start_time = time.time()\n\n num_batches = len(train_data) // bptt\n for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n data, targets = get_batch(train_data, i)\n output = model(data)\n output_flat = output.view(-1, ntokens)\n loss = criterion(output_flat, targets)\n\n optimizer.zero_grad()\n loss.backward()\n torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n optimizer.step()\n\n total_loss += loss.item()\n if batch % log_interval == 0 and batch > 0:\n lr = scheduler.get_last_lr()[0]\n ms_per_batch = (time.time() - start_time) * 1000 / log_interval\n cur_loss = total_loss / log_interval\n ppl = math.exp(cur_loss)\n print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '\n f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '\n f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')\n total_loss = 0\n start_time = time.time()\n\ndef evaluate(model: nn.Module, eval_data: Tensor) -> float:\n model.eval() # turn on evaluation mode\n total_loss = 0.\n with torch.no_grad():\n for i in range(0, eval_data.size(0) - 1, bptt):\n data, targets = get_batch(eval_data, i)\n seq_len = data.size(0)\n output = model(data)\n output_flat = output.view(-1, ntokens)\n total_loss += seq_len * criterion(output_flat, targets).item()\n return total_loss / (len(eval_data) - 1)"
]
},
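{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate the ``StepLR`` schedule described above, here is a minimal, self-contained sketch. It uses a throwaway parameter and optimizer so the real ``optimizer`` and ``scheduler`` are left untouched, and assumes the same ``lr=5.0`` and ``gamma=0.95``.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative only: how StepLR decays the learning rate by gamma after each epoch.\n_dummy = nn.Parameter(torch.zeros(1))\n_opt = torch.optim.SGD([_dummy], lr=5.0)\n_sched = torch.optim.lr_scheduler.StepLR(_opt, step_size=1, gamma=0.95)\nfor _epoch in range(1, 4):\n    print(f'epoch {_epoch}: lr = {_sched.get_last_lr()[0]:.4f}')\n    _opt.step()    # step the optimizer before the scheduler\n    _sched.step()"
]
},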
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loop over epochs. Save the model if the validation loss is the best\nwe've seen so far. Adjust the learning rate after each epoch.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"best_val_loss = float('inf')\nepochs = 3\n\nwith TemporaryDirectory() as tempdir:\n best_model_params_path = os.path.join(tempdir, \"best_model_params.pt\")\n\n for epoch in range(1, epochs + 1):\n epoch_start_time = time.time()\n train(model)\n val_loss = evaluate(model, val_data)\n val_ppl = math.exp(val_loss)\n elapsed = time.time() - epoch_start_time\n print('-' * 89)\n print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '\n f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')\n print('-' * 89)\n\n if val_loss < best_val_loss:\n best_val_loss = val_loss\n torch.save(model.state_dict(), best_model_params_path)\n\n scheduler.step()\n model.load_state_dict(torch.load(best_model_params_path)) # load best model states"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate the best model on the test dataset\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"test_loss = evaluate(model, test_data)\ntest_ppl = math.exp(test_loss)\nprint('=' * 89)\nprint(f'| End of training | test loss {test_loss:5.2f} | '\n f'test ppl {test_ppl:8.2f}')\nprint('=' * 89)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}