Custom Hardware Plugins

PyTorch/XLA supports custom hardware through OpenXLA’s PJRT C API. The PyTorch/XLA team directly supports plugins for Cloud TPU (libtpu) and GPU (OpenXLA). The same plugins may also be used by JAX and TensorFlow.

Implementing a PJRT Plugin

PJRT C API plugins may be closed-source or open-source. They contain two parts:

  1. Binary exposing a PJRT C API implementation. This part can be shared with JAX and TensorFlow.

  2. Python package containing the above binary, as well as an implementation of our DevicePlugin Python interface, which handles additional setup.

PJRT C API Implementation

In short, you must implement a PjRtClient containing an XLA compiler and runtime for your device. The PJRT C++ interface is mirrored in C in the PJRT_Api. The most straightforward option is to implement your plugin in C++ and wrap it as a C API implementation. This process is explained in detail in OpenXLA’s documentation.

For a concrete example, see the example implementation.

PyTorch/XLA Plugin Package

At this point, you should have a functional PJRT plugin binary, which you can test with the placeholder LIBRARY device type. For example:

$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
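
The same check can be scripted, for instance in a CI job. The following is a minimal sketch, not part of PyTorch/XLA: the helper names library_env and smoke_test are hypothetical, and the script only relies on the PJRT_DEVICE and PJRT_LIBRARY_PATH variables and the torch_xla.devices() call shown above.

```python
import os
import subprocess
import sys


def library_env(plugin_path, base_env=None):
    """Return a copy of the environment set up for the LIBRARY device type.

    `library_env` is a hypothetical helper; only the two variable names
    come from the shell command shown above.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["PJRT_DEVICE"] = "LIBRARY"
    env["PJRT_LIBRARY_PATH"] = os.path.abspath(plugin_path)
    return env


def smoke_test(plugin_path):
    """Run `torch_xla.devices()` in a fresh interpreter against the plugin."""
    if not os.path.isfile(plugin_path):
        raise FileNotFoundError(f"plugin binary not found: {plugin_path}")
    code = "import torch_xla; print(torch_xla.devices())"
    # A subprocess keeps the plugin out of the current interpreter, so a
    # crash in the binary does not take down the calling script.
    return subprocess.run(
        [sys.executable, "-c", code],
        env=library_env(plugin_path),
        capture_output=True,
        text=True,
        check=False,
    )


if __name__ == "__main__":
    result = smoke_test(sys.argv[1])
    print(result.stdout or result.stderr)
    sys.exit(result.returncode)
```

Running the check in a subprocess is a design choice rather than a requirement: it guarantees the environment variables are set before torch_xla is imported.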

To register your device type automatically for users, and to handle extra setup (for example, for multiprocessing), you may implement the DevicePlugin Python API. PyTorch/XLA plugin packages contain two key components:

  1. An implementation of DevicePlugin that (at the very least) provides the path to your plugin binary. For example:

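import os

# Module path assumed from the current PyTorch/XLA source tree.
from torch_xla.experimental import plugins


class CpuPlugin(plugins.DevicePlugin):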

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')

  2. A torch_xla.plugins entry point that identifies your DevicePlugin. For example, to register the EXAMPLE device type in a pyproject.toml:

[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"

With your package installed, you may then use your EXAMPLE device directly:

$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
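
If the device type is not picked up, it can help to confirm that the entry point is visible to Python at all. The sketch below uses only the standard library's importlib.metadata and does not call into PyTorch/XLA. The helper names are hypothetical, and the upper-casing of the entry point name is an assumption inferred from the example above, where the example entry registers the EXAMPLE device type.

```python
from importlib.metadata import entry_points


def list_plugin_entry_points(group="torch_xla.plugins"):
    """Map entry point names in `group` to their 'module:attr' targets."""
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+ returns an object with a select() method.
        selected = eps.select(group=group)
    else:
        # Python 3.8 and 3.9 return a plain dict keyed by group name.
        selected = eps.get(group, [])
    return {ep.name: ep.value for ep in selected}


def device_type_for(entry_point_name):
    """Assumed mapping from entry point name to PJRT_DEVICE value."""
    return entry_point_name.upper()


if __name__ == "__main__":
    for name, target in list_plugin_entry_points().items():
        print(f"{device_type_for(name)} -> {target}")
```

With the example package installed, the output should include a line for the example entry point; an empty result suggests the package metadata was not installed or the group name is misspelled.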

DevicePlugin provides additional extension points for multiprocess initialization and client options. The API is currently in an experimental state, but it is expected to become stable in a future release.
