{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# SST-2 Binary text classification with XLM-RoBERTa model\n\n**Author**: [Parmeet Bhatia](parmeetbhatia@fb.com)_\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n\nThis tutorial demonstrates how to train a text classifier on SST-2 binary dataset using a pre-trained XLM-RoBERTa (XLM-R) model.\nWe will show how to use torchtext library to:\n\n1. build text pre-processing pipeline for XLM-R model\n2. read SST-2 dataset and transform it using text and label transformation\n3. instantiate classification model using pre-trained XLM-R encoder\n\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Common imports\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\n\nDEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Transformation\n\nModels like XLM-R cannot work directly with raw text. The first step in training\nthese models is to transform input text into tensor (numerical) form such that it\ncan then be processed by models to make predictions. A standard way to process text is:\n\n1. Tokenize text\n2. Convert tokens into (integer) IDs\n3. Add any special tokens IDs\n\nXLM-R uses sentencepiece model for text tokenization. Below, we use pre-trained sentencepiece\nmodel along with corresponding vocabulary to build text pre-processing pipeline using torchtext's transforms.\nThe transforms are pipelined using :py:func:`torchtext.transforms.Sequential` which is similar to :py:func:`torch.nn.Sequential`\nbut is torchscriptable. Note that the transforms support both batched and non-batched text inputs i.e, one\ncan either pass a single sentence or list of sentences.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torchtext.transforms as T\nfrom torch.hub import load_state_dict_from_url\n\npadding_idx = 1\nbos_idx = 0\neos_idx = 2\nmax_seq_len = 256\nxlmr_vocab_path = r\"https://download.pytorch.org/models/text/xlmr.vocab.pt\"\nxlmr_spm_model_path = r\"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model\"\n\ntext_transform = T.Sequential(\n T.SentencePieceTokenizer(xlmr_spm_model_path),\n T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),\n T.Truncate(max_seq_len - 2),\n T.AddToken(token=bos_idx, begin=True),\n T.AddToken(token=eos_idx, begin=False),\n)\n\n\nfrom torch.utils.data import DataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternately we can also use transform shipped with pre-trained model that does all of the above out-of-the-box\n\n::\n\n text_transform = XLMR_BASE_ENCODER.transform()\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\ntorchtext provides several standard NLP datasets. For complete list, refer to documentation\nat https://pytorch.org/text/stable/datasets.html. These datasets are build using composable torchdata\ndatapipes and hence support standard flow-control and mapping/transformation using user defined functions\nand transforms. Below, we demonstrate how to use text and label processing transforms to pre-process the\nSST-2 dataset.\n\n
{ "cell_type": "markdown", "metadata": {}, "source": [ "Note: Using datapipes is still currently subject to a few caveats. If you wish\nto extend this example to include shuffling, multi-processing, or\ndistributed learning, please see `this note