.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/_rendered_examples/dynamo/torch_compile_transformers_example.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials__rendered_examples_dynamo_torch_compile_transformers_example.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials__rendered_examples_dynamo_torch_compile_transformers_example.py:

.. _torch_compile_transformer:

Compiling a Transformer using torch.compile and TensorRT
========================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow
with `torch.compile` on a transformer-based model.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Imports and Model Definition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 12-17

.. code-block:: python

    import torch
    import torch_tensorrt
    from transformers import BertModel

.. GENERATED FROM PYTHON SOURCE LINES 18-27

.. code-block:: python

    # Initialize model with float precision and sample inputs
    model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
    inputs = [
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    ]

.. GENERATED FROM PYTHON SOURCE LINES 28-30

Optional Input Arguments to `torch_tensorrt.compile`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 30-47

.. code-block:: python

    # Enabled precision for TensorRT optimization
    enabled_precisions = {torch.float}

    # Whether to print verbose logs
    debug = True

    # Workspace size for TensorRT
    workspace_size = 20 << 30

    # Minimum number of operators per TRT engine block
    # (Lower values allow more graph segmentation)
    min_block_size = 7

    # Operations to run in Torch, regardless of converter support
    torch_executed_ops = {}

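As a side note, `workspace_size = 20 << 30` expresses the TensorRT workspace
size in bytes with a bit shift: shifting 20 left by 30 bits multiplies it by
2\ :sup:`30` (one GiB). The following sketch (plain Python, with stand-in
values for the Torch-specific entries so it runs without a GPU or
`torch_tensorrt` installed) verifies the arithmetic and shows the shape of the
options dictionary assembled in the next section:

.. code-block:: python

    # Shifting 20 left by 30 bits multiplies it by 2**30 (one GiB),
    # giving the workspace size in bytes.
    workspace_size = 20 << 30
    assert workspace_size == 20 * 2**30  # 20 GiB
    print(workspace_size)  # 21474836480

    # The options above are later collected into a plain dict and handed to
    # torch.compile(..., options=...). Stand-in values are used here for the
    # Torch-specific entries (e.g. "float32" in place of torch.float).
    compilation_options = {
        "enabled_precisions": {"float32"},  # stand-in for {torch.float}
        "debug": True,
        "workspace_size": workspace_size,
        "min_block_size": 7,
        "torch_executed_ops": set(),
    }
    assert compilation_options["workspace_size"] == 21474836480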
.. GENERATED FROM PYTHON SOURCE LINES 48-50

Compilation with `torch.compile`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 50-68

.. code-block:: python

    # Define backend compilation keyword arguments
    compilation_kwargs = {
        "enabled_precisions": enabled_precisions,
        "debug": debug,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        "torch_executed_ops": torch_executed_ops,
    }

    # Build and compile the model with torch.compile, using the Torch-TensorRT backend
    optimized_model = torch.compile(
        model,
        backend="torch_tensorrt",
        options=compilation_kwargs,
    )
    optimized_model(*inputs)

.. GENERATED FROM PYTHON SOURCE LINES 69-71

Equivalently, we could have run the above via the convenience frontend, as so:
`torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)`

.. GENERATED FROM PYTHON SOURCE LINES 73-75

Inference
^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 75-83

.. code-block:: python

    # Does not cause recompilation (same batch size as input)
    new_inputs = [
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    ]
    new_outputs = optimized_model(*new_inputs)

.. GENERATED FROM PYTHON SOURCE LINES 84-92

.. code-block:: python

    # Does cause recompilation (new batch size)
    new_inputs = [
        torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
        torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
    ]
    new_outputs = optimized_model(*new_inputs)

.. GENERATED FROM PYTHON SOURCE LINES 93-95

Cleanup
^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 95-99

.. code-block:: python

    # Finally, we use Torch utilities to clean up the workspace
    torch._dynamo.reset()

.. GENERATED FROM PYTHON SOURCE LINES 100-109

Cuda Driver Error Note
^^^^^^^^^^^^^^^^^^^^^^

Occasionally, upon exiting the Python runtime after Dynamo compilation with
`torch_tensorrt`, one may encounter a Cuda Driver Error.
This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052 and can be
resolved by wrapping the compilation/inference in a function and using a scoped call,
as in::

    if __name__ == '__main__':
        compile_engine_and_infer()

.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 0 minutes 0.000 seconds)

.. _sphx_glr_download_tutorials__rendered_examples_dynamo_torch_compile_transformers_example.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_compile_transformers_example.py <torch_compile_transformers_example.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_compile_transformers_example.ipynb <torch_compile_transformers_example.ipynb>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_