.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/_rendered_examples/dynamo/auto_generate_plugins.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials__rendered_examples_dynamo_auto_generate_plugins.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials__rendered_examples_dynamo_auto_generate_plugins.py:


.. _auto_generate_plugins:

Automatically Generate a Plugin for a Custom Kernel
===================================================================

We are going to demonstrate how to automatically generate a plugin for a custom kernel with Torch-TensorRT, using
the new Python-based plugin system introduced in TensorRT 10.7.

Torch-TensorRT supports falling back to PyTorch implementations of operations that it does not know how
to compile to TensorRT. However, each fallback comes at the cost of a graph break and reduces the performance of the model.
The easiest way to address an unsupported op is to add a decomposition (see:
`Writing lowering passes for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html>`_), which defines the operator
in terms of PyTorch ops that Torch-TensorRT supports, or a converter (see:
`Writing converters for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/dynamo_converters.html>`_), which defines the operator in terms of TensorRT operators.
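
For example, a decomposition can be as small as re-expressing an op with ATen ops that are already supported. The following is a hypothetical sketch, assuming the ``register_torch_trt_decomposition`` helper described in the lowering guide linked above; ``aten.hardswish`` is only an illustrative target:

.. code-block:: python

    # Hypothetical sketch of the decomposition route (helper name assumed from
    # the lowering guide linked above): express an op in terms of ATen ops
    # that Torch-TensorRT can already convert.
    import torch
    from torch_tensorrt.dynamo.lowering import register_torch_trt_decomposition


    @register_torch_trt_decomposition(torch.ops.aten.hardswish)
    def hardswish_decomposition(x: torch.Tensor) -> torch.Tensor:
        # hardswish(x) = x * relu6(x + 3) / 6
        return x * torch.clamp(x + 3, 0, 6) / 6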

In some cases there is no good way to do either, perhaps because the operator is a custom kernel that is not part of standard PyTorch, or
because TensorRT cannot support it natively.

For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding
the performance and resource overhead of a graph break.

Previously, this involved a complex process of not only building a performant kernel but also setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_).
As of TensorRT 10.7, there is a new Python-native plugin system which greatly streamlines this process. This
plugin system also allows Torch-TensorRT to automatically generate the conversion code needed to map the
PyTorch operation to TensorRT.

.. GENERATED FROM PYTHON SOURCE LINES 30-39

Writing Custom Operators in PyTorch
-----------------------------------------

Previous tutorials already cover creating custom operators in PyTorch that are later used with Torch-TensorRT.
Here we define a simple elementwise multiplication operator in Triton. The operator is then registered as a custom op in PyTorch,
together with its host launch code and a "meta-kernel". A meta-kernel is a function that describes the shape and data type
transformations that the operator performs. Dynamo and Torch-TensorRT both rely on this meta-kernel, so it
must be defined.


.. GENERATED FROM PYTHON SOURCE LINES 39-88

.. code-block:: python


    import tensorrt_bindings.plugin as trtp
    import torch
    import torch_tensorrt
    import triton
    import triton.language as tl


    @triton.jit
    def elementwise_scale_mul_kernel(X, Y, Z, a, b, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        # Compute the range of elements that this program instance will work on
        block_start = pid * BLOCK_SIZE
        # Range of indices this program instance will handle
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask off out-of-bounds accesses so the input size need not be a
        # multiple of BLOCK_SIZE
        mask = offsets < n_elements
        # Load elements from the X and Y tensors
        x_vals = tl.load(X + offsets, mask=mask)
        y_vals = tl.load(Y + offsets, mask=mask)
        # Perform the element-wise scale-multiply: Z = X * Y * a + b
        z_vals = x_vals * y_vals * a + b
        # Store the result in Z
        tl.store(Z + offsets, z_vals, mask=mask)


    @torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
    def elementwise_scale_mul(
        X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
    ) -> torch.Tensor:
        # Ensure the tensors are on the GPU
        assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
        assert X.shape == Y.shape, "Tensors must have the same shape."

        # Create output tensor
        Z = torch.empty_like(X)

        # Define block size
        BLOCK_SIZE = 1024

        # Grid of programs: round up so that every element is covered
        grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

        # Launch the kernel with scale parameters a and b
        elementwise_scale_mul_kernel[grid](
            X, Y, Z, a, b, X.numel(), BLOCK_SIZE=BLOCK_SIZE
        )

        return Z



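Before moving on, we can sanity-check the registered op in eager mode; this quick check is illustrative and not part of the plugin workflow itself.

.. code-block:: python

    # Eager-mode check: the op should match the pure-PyTorch expression
    # X * Y * a + b (here a=2, b=0.2 are the defaults).
    x = torch.randn(64, 64, device="cuda")
    y = torch.randn(64, 64, device="cuda")
    out = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, y, b=0.2, a=2)
    assert torch.allclose(out, x * y * 2 + 0.2)
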
.. GENERATED FROM PYTHON SOURCE LINES 89-91

The meta-kernel for an elementwise operation simply produces a tensor with the shape and dtype of one of the inputs, since the
operation does not change the shape.

.. GENERATED FROM PYTHON SOURCE LINES 91-98

.. code-block:: python



    @torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
    def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
        # Return a fresh tensor with the input's shape and dtype; the real op
        # does not alias its inputs, so the fake kernel should not either.
        return torch.empty_like(x)



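Optionally, the registration (including the meta-kernel) can be validated with ``torch.library.opcheck``. This is an extra sanity check rather than a required step of the plugin workflow.

.. code-block:: python

    # Optional: opcheck runs the op under FakeTensor and related test utilities
    # and verifies that the fake kernel's metadata matches the real kernel.
    sample_x = torch.randn(32, 32, device="cuda")
    sample_y = torch.randn(32, 32, device="cuda")
    torch.library.opcheck(
        torch.ops.torchtrt_ex.elementwise_scale_mul.default, (sample_x, sample_y)
    )
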
.. GENERATED FROM PYTHON SOURCE LINES 99-101

Here we use the automatic plugin creation feature in Torch-TensorRT, which registers the plugin using
TensorRT's quickly deployable plugin (QDP) APIs.

.. GENERATED FROM PYTHON SOURCE LINES 101-124

.. code-block:: python

    torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
        "torchtrt_ex::elementwise_scale_mul"
    )

Generating the Converter
-------------------------------------------------------------------

Given that we have defined the custom operator in PyTorch and registered it as a TensorRT plugin, we can now generate the converter for the operation.
As long as the namespace and names match, the following function will automatically generate the converter for the operation.
If a plugin requires an output allocator to dynamically allocate output buffers (as data-dependent operators do), set ``requires_output_allocator`` to ``True``.

.. code-block:: python

    torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
        "torchtrt_ex::elementwise_scale_mul",
        supports_dynamic_shapes=True,
        requires_output_allocator=False,
    )

The two calls above can be replaced with a single call:

.. code-block:: python

    torch_tensorrt.dynamo.conversion.plugins.custom_op(
        "torchtrt_ex::elementwise_scale_mul",
        supports_dynamic_shapes=True,
        requires_output_allocator=False,
    )



.. GENERATED FROM PYTHON SOURCE LINES 125-131

Using our converter with a model
-------------------------------------------------------------------

Now we can use our custom operator in a model and compile it with Torch-TensorRT.
We can see that the custom operator is used as one of the operations in the forward pass of the model.
The process of compiling the model at this point is identical to standard Torch-TensorRT usage.

.. GENERATED FROM PYTHON SOURCE LINES 131-155

.. code-block:: python

    class MyModel(torch.nn.Module):  # type: ignore[misc]
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.add(x, y)
            res = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, z, b=0.5)

            return res


    my_model = MyModel().to("cuda")
    m = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)
    n = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)

    with torch_tensorrt.logging.errors():
        model_trt = torch_tensorrt.compile(
            my_model, inputs=[m, n], debug=True, min_block_size=1
        )
        for i in range(300):
            res = model_trt(m, n)
            assert torch.allclose(res, my_model(m, n))

    print("Ran with custom plugin!")

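Since the converter was generated with ``supports_dynamic_shapes=True``, the same model can also be compiled for a range of input shapes. A minimal sketch, with illustrative shape ranges:

.. code-block:: python

    # Illustrative follow-up: compile with a dynamic first dimension.
    dyn_input = torch_tensorrt.Input(
        min_shape=(16, 64),
        opt_shape=(64, 64),
        max_shape=(256, 64),
        dtype=torch.float,
    )
    model_trt_dyn = torch_tensorrt.compile(
        my_model, inputs=[dyn_input, dyn_input], min_block_size=1
    )
    assert torch.allclose(model_trt_dyn(m, n), my_model(m, n))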

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_tutorials__rendered_examples_dynamo_auto_generate_plugins.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: auto_generate_plugins.py <auto_generate_plugins.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: auto_generate_plugins.ipynb <auto_generate_plugins.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_