[1]:
# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


Masked Language Modeling (MLM) with Hugging Face BERT Transformer

Learning objectives

This notebook demonstrates the steps for compiling a TorchScript module with Torch-TensorRT on a pretrained BERT transformer from Hugging Face, and running it to test the speedup obtained.

Contents

  1. Requirements

  2. BERT Overview

  3. Creating TorchScript modules

  4. Compiling with Torch-TensorRT

  5. Benchmarking

  6. Conclusion

## 1. Requirements

NVIDIA’s NGC provides a PyTorch Docker container that includes both PyTorch and Torch-TensorRT. Any release of that container from 22.05-py3 onward can be used to run this notebook.

Otherwise, you can follow the steps in notebooks/README to prepare a Docker container yourself, within which you can run this demo notebook.

[2]:
!pip install transformers
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: transformers in /opt/conda/lib/python3.8/site-packages (4.18.0)
Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.8/site-packages (from transformers) (4.63.0)
Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.8/site-packages (from transformers) (2022.3.15)
Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /opt/conda/lib/python3.8/site-packages (from transformers) (0.5.1)
Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /opt/conda/lib/python3.8/site-packages (from transformers) (0.12.1)
Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.8/site-packages (from transformers) (1.22.3)
Requirement already satisfied: sacremoses in /opt/conda/lib/python3.8/site-packages (from transformers) (0.0.49)
Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from transformers) (2.27.1)
Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.8/site-packages (from transformers) (6.0)
Requirement already satisfied: filelock in /opt/conda/lib/python3.8/site-packages (from transformers) (3.6.0)
Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.8/site-packages (from transformers) (21.3)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.1.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging>=20.0->transformers) (3.0.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->transformers) (1.26.8)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->transformers) (2.0.12)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->transformers) (2021.10.8)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->transformers) (3.3)
Requirement already satisfied: six in /opt/conda/lib/python3.8/site-packages (from sacremoses->transformers) (1.16.0)
Requirement already satisfied: click in /opt/conda/lib/python3.8/site-packages (from sacremoses->transformers) (8.0.4)
Requirement already satisfied: joblib in /opt/conda/lib/python3.8/site-packages (from sacremoses->transformers) (1.1.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[3]:
from transformers import BertTokenizer, BertForMaskedLM
import torch
import timeit
import numpy as np
import torch_tensorrt
import torch.backends.cudnn as cudnn

## 2. BERT Overview

Transformers comprise a class of deep learning algorithms employing self-attention; broadly speaking, the models learn large matrices of numbers, each element of which denotes how important one component of input data is to another. Since their introduction in 2017, transformers have enjoyed widespread adoption, particularly in natural language processing, but also in computer vision problems. This is largely because they are easier to parallelize than the sequence models which attention mechanisms were originally designed to augment.
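
To make the idea of attention weights concrete, here is a minimal pure-Python sketch of scaled dot-product self-attention. It is a toy illustration only: the three 2-dimensional token vectors are made up, and the learned query, key, and value projections that real transformers apply are omitted.

```python
import math

def softmax(scores):
    """Turn raw scores into non-negative weights that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(queries, keys, values):
    """Scaled dot-product attention over lists of equal-length vectors."""
    dim = len(keys[0])
    weights, outputs = [], []
    for q in queries:
        # How strongly this token attends to every token in the sequence
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim) for k in keys]
        row = softmax(scores)
        weights.append(row)
        # Weighted average of the value vectors
        outputs.append([sum(w * v[c] for w, v in zip(row, values))
                        for c in range(len(values[0]))])
    return outputs, weights

# Toy "embeddings" for a three-token sequence (made-up numbers)
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(tokens, tokens, tokens)
for row in weights:
    print([round(w, 3) for w in row])
```

Each printed row shows how much one token attends to every token in the sequence; the entries of a row sum to 1.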

Hugging Face is a company that maintains a huge repository of pre-trained transformer models. The company also provides tools for integrating those models into PyTorch code and running inference with them.

One of the most popular transformer models is BERT (Bidirectional Encoder Representations from Transformers). First developed at Google and released in 2018, it has been deployed in Google’s search engine and has become a standard baseline for NLP experiments. BERT was originally trained on two tasks: next sentence prediction and masked language modeling (MLM), which aims to predict hidden words in sentences. In this notebook, we will use Hugging Face’s bert-base-uncased model (the smaller of the two model sizes released with the original BERT, trained on lowercased text) for MLM.

## 3. Creating TorchScript modules

First, create a pretrained BERT tokenizer from the bert-base-uncased model

[4]:
enc = BertTokenizer.from_pretrained('bert-base-uncased')

Create dummy inputs to generate a traced TorchScript model later

[5]:
batch_size = 4

batched_indexed_tokens = [[101, 64]*64]*batch_size
batched_segment_ids = [[0, 1]*64]*batch_size
batched_attention_masks = [[1, 1]*64]*batch_size

tokens_tensor = torch.tensor(batched_indexed_tokens)
segments_tensor = torch.tensor(batched_segment_ids)
attention_masks_tensor = torch.tensor(batched_attention_masks)

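Load a BERT masked language model from Hugging Face with torchscript=True, which prepares the model for TorchScript export (this eager-mode model is what the notebook calls the "(scripted) TorchScript" model), then use the dummy inputs to trace it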

[6]:
mlm_model_ts = BertForMaskedLM.from_pretrained('bert-base-uncased', torchscript=True)
# Positional order follows BertForMaskedLM.forward: input_ids, attention_mask, token_type_ids
traced_mlm_model = torch.jit.trace(mlm_model_ts, [tokens_tensor, attention_masks_tensor, segments_tensor])
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Define 4 masked sentences, with 1 word in each sentence hidden from the model. Fluent English speakers will probably be able to guess the masked words, but just in case, they are 'capital', 'language', 'innings', and 'mathematics'.

Also create a list containing the position of the masked word within each sentence. Given Python’s 0-based indexing convention, the numbers are each higher by 1 than might be expected. This is because the token at index 0 in each sentence is a beginning-of-sentence token, denoted [CLS] when entered explicitly.

[7]:
masked_sentences = ['Paris is the [MASK] of France.',
                    'The primary [MASK] of the United States is English.',
                    'A baseball game consists of at least nine [MASK].',
                    'Topology is a branch of [MASK] concerned with the properties of geometric objects that remain unchanged under continuous transformations.']
pos_masks = [4, 3, 9, 6]
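
The values in pos_masks can be sanity-checked without the real tokenizer. The sketch below uses a hypothetical toy_tokenize helper (not part of any library) that splits on words and punctuation and adds [CLS] and [SEP]. It only approximates BERT's WordPiece tokenizer, which may split rarer words into several sub-word tokens and thereby shift the positions. With the real tokenizer, one would instead look up the position of enc.mask_token_id in the encoded input_ids.

```python
import re

def toy_tokenize(sentence):
    """Hypothetical stand-in for the BERT tokenizer: words, punctuation, [MASK]."""
    pieces = re.findall(r"\[MASK\]|\w+|[^\w\s]", sentence)
    return ["[CLS]"] + pieces + ["[SEP]"]

masked_sentences = [
    'Paris is the [MASK] of France.',
    'The primary [MASK] of the United States is English.',
    'A baseball game consists of at least nine [MASK].',
    'Topology is a branch of [MASK] concerned with the properties of geometric '
    'objects that remain unchanged under continuous transformations.',
]

# Index 0 is [CLS], so every position is one higher than a plain word count suggests
pos_masks = [toy_tokenize(s).index("[MASK]") for s in masked_sentences]
print(pos_masks)  # [4, 3, 9, 6]
```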

Pass the masked sentences into the (scripted) TorchScript MLM model and verify that the unmasked sentences yield the expected results.

Because the sentences are of different lengths, we must specify the padding argument in calling our encoder/tokenizer. There are several possible padding strategies, but we’ll use 'max_length' padding with max_length=128. Later, when we compile an optimized version of the model with Torch-TensorRT, the optimized model will expect inputs of length 128, hence our choice of padding strategy and length here.
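
As a rough illustration of what 'max_length' padding does, the sketch below pads a list of token ids to a fixed length and builds the matching attention mask. The helper pad_to_max_length is hypothetical and the word ids are made up for illustration; only the special ids (0 for [PAD], 101 for [CLS], 102 for [SEP], 103 for [MASK]) are taken from the bert-base-uncased vocabulary.

```python
PAD_ID = 0  # id of [PAD] in bert-base-uncased

def pad_to_max_length(token_ids, max_length=128, pad_id=PAD_ID):
    """Right-pad token_ids to max_length and return (input_ids, attention_mask)."""
    if len(token_ids) > max_length:
        raise ValueError("sequence is longer than max_length")
    n_pad = max_length - len(token_ids)
    input_ids = token_ids + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding that the model should ignore
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    return input_ids, attention_mask

# [CLS] w w w [MASK] w w w [SEP], with made-up ids for the ordinary words
example_ids = [101, 2001, 2002, 2003, 103, 2004, 2005, 2006, 102]
input_ids, attention_mask = pad_to_max_length(example_ids)
print(len(input_ids), sum(attention_mask))  # 128 9
```

Padding every sentence to the same length is what lets a batch of differently sized sentences be stacked into a single [batch_size, 128] tensor.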

[8]:
encoded_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)
outputs = mlm_model_ts(**encoded_inputs)
most_likely_token_ids = [torch.argmax(outputs[0][i, pos, :]) for i, pos in enumerate(pos_masks)]
unmasked_tokens = enc.decode(most_likely_token_ids).split(' ')
unmasked_sentences = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens)]
for sentence in unmasked_sentences:
    print(sentence)
Paris is the capital of France.
The primary language of the United States is English.
A baseball game consists of at least nine innings.
Topology is a branch of mathematics concerned with the properties of geometric objects that remain unchanged under continuous transformations.
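
The decoding step in the cell above picks, for each sentence, the vocabulary entry with the highest score at the masked position. The toy sketch below mirrors that indexing with plain Python lists; the four-word vocabulary and all scores are invented for illustration and have nothing to do with the real model output.

```python
toy_vocab = ["[PAD]", "capital", "language", "river"]

# toy_logits[i][t][v]: score of vocabulary entry v at position t of sentence i
toy_logits = [
    [[0.1, 0.2, 0.0, 0.3], [0.0, 3.5, 1.2, 0.4], [0.2, 0.1, 0.0, 0.1]],
    [[0.0, 0.1, 0.2, 0.1], [0.3, 0.2, 0.1, 0.0], [0.1, 0.9, 4.0, 0.5]],
]
toy_pos_masks = [1, 2]  # masked position in each toy sentence

def argmax(scores):
    """Index of the largest score."""
    return max(range(len(scores)), key=lambda v: scores[v])

predicted = [toy_vocab[argmax(toy_logits[i][pos])]
             for i, pos in enumerate(toy_pos_masks)]
print(predicted)  # ['capital', 'language']
```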

Pass the masked sentences into the traced MLM model and verify that the unmasked sentences yield the expected results.

Note the difference in how the encoded_inputs are passed into the model in the following cell compared to the previous one. If you examine encoded_inputs, you’ll find that it’s a dictionary with 3 keys, 'input_ids', 'token_type_ids', and 'attention_mask', each with a PyTorch tensor as an associated value. The traced model will accept **encoded_inputs as an input, but the Torch-TensorRT-optimized model (to be defined later) will not.

[9]:
encoded_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)
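# Pass tensors in the positional order of BertForMaskedLM.forward: input_ids, attention_mask, token_type_ids
outputs = traced_mlm_model(encoded_inputs['input_ids'], encoded_inputs['attention_mask'], encoded_inputs['token_type_ids'])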
most_likely_token_ids = [torch.argmax(outputs[0][i, pos, :]) for i, pos in enumerate(pos_masks)]
unmasked_tokens = enc.decode(most_likely_token_ids).split(' ')
unmasked_sentences = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens)]
for sentence in unmasked_sentences:
    print(sentence)
Paris is the capital of France.
The primary language of the United States is English.
A baseball game consists of at least nine innings.
Topology is a branch of mathematics concerned with the properties of geometric objects that remain unchanged under continuous transformations.

## 4. Compiling with Torch-TensorRT

Change the logging level to avoid long printouts

[10]:
new_level = torch_tensorrt.logging.Level.Error
torch_tensorrt.logging.set_reportable_log_level(new_level)

Compile the model

[11]:
trt_model = torch_tensorrt.compile(traced_mlm_model,
    inputs= [torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # input_ids
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # attention_mask
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32)], # token_type_ids
    enabled_precisions= {torch.float32}, # Run with 32-bit precision
    workspace_size=2000000000,
    truncate_long_and_double=True
)
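
The inputs above are declared as int32, whereas the tokenizer returns 64-bit integer tensors, so the ids have to be cast down before inference. The sketch below is a simple range check showing why that cast is safe here; fits_in_int32 is a hypothetical helper, and the vocabulary size of 30522 is that of bert-base-uncased.

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1
BERT_BASE_UNCASED_VOCAB_SIZE = 30522

def fits_in_int32(values):
    """True if every value can be stored in a 32-bit signed integer."""
    return all(INT32_MIN <= v <= INT32_MAX for v in values)

largest_possible_id = BERT_BASE_UNCASED_VOCAB_SIZE - 1
print(fits_in_int32([0, 101, 103, largest_possible_id]))  # True
print(fits_in_int32([2**31]))                             # False
```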

Pass the masked sentences into the compiled model and verify that the unmasked sentences yield the expected results.

[12]:
enc_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)
enc_inputs = {k: v.type(torch.int32).cuda() for k, v in enc_inputs.items()}
# Positional order matches the traced model: input_ids, attention_mask, token_type_ids
output_trt = trt_model(enc_inputs['input_ids'], enc_inputs['attention_mask'], enc_inputs['token_type_ids'])
most_likely_token_ids_trt = [torch.argmax(output_trt[i, pos, :]) for i, pos in enumerate(pos_masks)]
unmasked_tokens_trt = enc.decode(most_likely_token_ids_trt).split(' ')
unmasked_sentences_trt = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens_trt)]
for sentence in unmasked_sentences_trt:
    print(sentence)
Paris is the capital of France.
The primary language of the United States is English.
A baseball game consists of at least nine innings.
Topology is a branch of mathematics concerned with the properties of geometric objects that remain unchanged under continuous transformations.

Compile the model again, this time with 16-bit precision

[13]:
trt_model_fp16 = torch_tensorrt.compile(traced_mlm_model,
    inputs= [torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # input_ids
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # attention_mask
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32)], # token_type_ids
    enabled_precisions= {torch.half}, # Run with 16-bit precision
    workspace_size=2000000000,
    truncate_long_and_double=True
)
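
Half precision buys speed at the cost of numerical resolution. As a small illustration that is independent of Torch-TensorRT, the sketch below rounds a few made-up scores to IEEE 754 half precision using Python's struct module; to_fp16 is a hypothetical helper name.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

# Made-up scores; the first two differ only in the fourth decimal place
scores = [2.7183, 2.7184, -1.25, 0.1]
scores_fp16 = [to_fp16(s) for s in scores]

for original, rounded in zip(scores, scores_fp16):
    print(f"{original!r} -> {rounded!r} (error {abs(original - rounded):.2e})")

# Two nearby values can collapse to the same half-precision number
print(scores_fp16[0] == scores_fp16[1])
```

Because closely spaced scores can become identical after rounding, it is worth re-checking the predictions of an FP16 model against the FP32 results rather than assuming they match.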

## 5. Benchmarking

In developing this notebook, we conducted our benchmarking on a single NVIDIA A100 GPU. Your results may differ from those shown, particularly on a different GPU.

This function passes the inputs into the model for 20 warm-up runs, then runs inference num_loops times. It returns a list of length num_loops containing the time in seconds that each inference run took.

[14]:
def timeGraph(model, input_tensor1, input_tensor2, input_tensor3, num_loops=50):
    print("Warm up ...")
    with torch.no_grad():
        for _ in range(20):
            features = model(input_tensor1, input_tensor2, input_tensor3)

    torch.cuda.synchronize()

    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(num_loops):
            start_time = timeit.default_timer()
            features = model(input_tensor1, input_tensor2, input_tensor3)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            timings.append(end_time - start_time)
            # print("Iteration {}: {:.6f} s".format(i, end_time - start_time))

    return timings

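This function prints the model’s throughput and summary statistics of its latency. Throughput is computed as batch_size divided by the latency of one inference run, so despite the "text batches/second" label in the output, the figure is the number of input sentences processed per second.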

[15]:
def printStats(graphName, timings, batch_size):
    times = np.array(timings)
    steps = len(times)
    speeds = batch_size / times
    time_mean = np.mean(times)
    time_med = np.median(times)
    time_99th = np.percentile(times, 99)
    time_std = np.std(times, ddof=0)
    speed_mean = np.mean(speeds)
    speed_med = np.median(speeds)

    msg = ("\n%s =================================\n"
            "batch size=%d, num iterations=%d\n"
            "  Median text batches/second: %.1f, mean: %.1f\n"
            "  Median latency: %.6f, mean: %.6f, 99th_p: %.6f, std_dev: %.6f\n"
            ) % (graphName,
                batch_size, steps,
                speed_med, speed_mean,
                time_med, time_mean, time_99th, time_std)
    print(msg)
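
As a quick worked example of the throughput calculation in printStats, the sketch below divides a batch size by a single latency value. The numbers are taken from the first benchmark printout in this notebook (batch size 4, median latency 0.006677 s); the throughput helper is a hypothetical name.

```python
def throughput(batch_size, latency_seconds):
    """Inputs processed per second for one inference call."""
    return batch_size / latency_seconds

print(round(throughput(4, 0.006677), 1))  # 599.1
```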
[16]:
cudnn.benchmark = True

Benchmark the (scripted) TorchScript model on GPU

[17]:
timings = timeGraph(mlm_model_ts.cuda(), enc_inputs['input_ids'], enc_inputs['attention_mask'], enc_inputs['token_type_ids'])

printStats("BERT", timings, batch_size)
Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 599.1, mean: 597.6
  Median latency: 0.006677, mean: 0.006693, 99th_p: 0.006943, std_dev: 0.000059

Benchmark the traced model on GPU

[18]:
timings = timeGraph(traced_mlm_model.cuda(), enc_inputs['input_ids'], enc_inputs['attention_mask'], enc_inputs['token_type_ids'])

printStats("BERT", timings, batch_size)
Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 951.2, mean: 951.0
  Median latency: 0.004205, mean: 0.004206, 99th_p: 0.004256, std_dev: 0.000015

Benchmark the compiled FP32 model on GPU

[19]:
timings = timeGraph(trt_model, enc_inputs['input_ids'], enc_inputs['attention_mask'], enc_inputs['token_type_ids'])

printStats("BERT", timings, batch_size)
Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 1216.9, mean: 1216.4
  Median latency: 0.003287, mean: 0.003289, 99th_p: 0.003317, std_dev: 0.000007

Benchmark the compiled FP16 model on GPU

[20]:
timings = timeGraph(trt_model_fp16, enc_inputs['input_ids'], enc_inputs['attention_mask'], enc_inputs['token_type_ids'])

printStats("BERT", timings, batch_size)
Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 1776.7, mean: 1771.1
  Median latency: 0.002251, mean: 0.002259, 99th_p: 0.002305, std_dev: 0.000015

## 6. Conclusion

In this notebook, we have walked through the complete process of compiling TorchScript models with Torch-TensorRT for Masked Language Modeling with Hugging Face’s bert-base-uncased transformer and testing the performance impact of the optimization. With Torch-TensorRT on an NVIDIA A100 GPU, we observe the speedups indicated below. These numbers will vary from GPU to GPU (and from implementation to implementation, depending on the ops used), and we encourage you to try the latest generation of data center GPUs for maximum acceleration.

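Speedups relative to the scripted model, computed from the median throughputs printed in section 5:

Scripted (GPU): 1.00x
Traced (GPU): 1.59x
Torch-TensorRT (FP32): 2.03x
Torch-TensorRT (FP16): 2.97x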
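
For reference, here is a short sketch of how speedup ratios can be derived from benchmark results. The latencies are the median values copied from the section 5 printouts of this particular run; ratios from any other run or GPU will differ.

```python
# Median latency in seconds, copied from the benchmark printouts in section 5
median_latency_s = {
    "Scripted (GPU)": 0.006677,
    "Traced (GPU)": 0.004205,
    "Torch-TensorRT (FP32)": 0.003287,
    "Torch-TensorRT (FP16)": 0.002251,
}

baseline = median_latency_s["Scripted (GPU)"]
# Speedup = baseline latency divided by the latency of the model being compared
speedups = {name: baseline / latency for name, latency in median_latency_s.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup:.2f}x")
```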

What’s next

Now it’s time to try Torch-TensorRT on your own model. If you run into any issues, you can file them at https://github.com/NVIDIA/Torch-TensorRT. Your involvement will help future development of Torch-TensorRT.
