# Custom Hardware Plugins

PyTorch/XLA supports custom hardware through OpenXLA's PJRT C API. The PyTorch/XLA team directly supports plugins for Cloud TPU (`libtpu`) and GPU ([OpenXLA](https://github.com/openxla/xla/tree/main/xla/pjrt/gpu)). The same plugins may also be used by JAX and TensorFlow.

## Implementing a PJRT Plugin

PJRT C API plugins may be closed-source or open-source. They contain two parts:

1. A binary exposing a PJRT C API implementation. This part can be shared with JAX and TensorFlow.
2. A Python package containing the above binary, as well as an implementation of our `DevicePlugin` Python interface, which handles additional setup.

### PJRT C API Implementation

In short, you must implement a [PjRtClient](https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h) containing an XLA compiler and runtime for your device. The PJRT C++ interface is mirrored in C in the [PJRT_Api](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api.h). The most straightforward option is to implement your plugin in C++ and [wrap it](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_wrapper_impl.h) as a C API implementation. This process is explained in detail in [OpenXLA's documentation](https://openxla.org/xla/pjrt_integration#how_to_integrate_with_pjrt).

For a concrete example, see the [example implementation](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_cpu_internal.cc).

### PyTorch/XLA Plugin Package

At this point, you should have a functional PJRT plugin binary, which you can test with the placeholder `LIBRARY` device type. For example:

    $ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
    >>> import torch_xla
    >>> torch_xla.devices()
    # Assuming there are 4 devices. Your hardware may differ.
    [device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

To register your device type automatically for users, as well as to handle extra setup such as multiprocessing, you may implement the `DevicePlugin` Python API. PyTorch/XLA plugin packages contain two key components:

1. An implementation of `DevicePlugin` that (at the very least) provides the path to your plugin binary. For example:

   ``` python
   import os

   from torch_xla.experimental import plugins


   class CpuPlugin(plugins.DevicePlugin):

     def library_path(self) -> str:
       # The plugin binary ships inside this package, under lib/.
       return os.path.join(
           os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
   ```

2. A `torch_xla.plugins` [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html) that identifies your `DevicePlugin`. For example, to register the `EXAMPLE` device type in a `pyproject.toml`:

       [project.entry-points."torch_xla.plugins"]
       example = "torch_xla_cpu_plugin:CpuPlugin"

With your package installed, you may then use your `EXAMPLE` device directly:

    $ PJRT_DEVICE=EXAMPLE python
    >>> import torch_xla
    >>> torch_xla.devices()
    [device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

[DevicePlugin](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/plugins.py) provides additional extension points for multiprocess initialization and client options. The API is currently in an experimental state, but it is expected to become stable in a future release.
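
If `PJRT_DEVICE=EXAMPLE` does not pick up your plugin, it can help to confirm that the `torch_xla.plugins` entry point was actually registered when your package was installed. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper name `list_registered_plugins` is hypothetical and not part of PyTorch/XLA, and the sketch illustrates the generic entry-point mechanism rather than how PyTorch/XLA loads plugins internally.

```python
from importlib.metadata import entry_points

# Entry-point group named in the packaging section above.
PLUGIN_GROUP = "torch_xla.plugins"


def list_registered_plugins(group=PLUGIN_GROUP):
    """Map each entry-point name in `group` to its "module:attr" target.

    Hypothetical helper for debugging; it only reads package metadata and
    does not import torch_xla or the plugin module itself.
    """
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: entry points are filtered with .select().
        selected = eps.select(group=group)
    else:
        # Older Pythons return a plain dict keyed by group name.
        selected = eps.get(group, [])
    return {ep.name: ep.value for ep in selected}


if __name__ == "__main__":
    for name, target in list_registered_plugins().items():
        print(f"{name} -> {target}")
```

With the `pyproject.toml` entry from the example above installed, the returned mapping would be expected to contain an `example` key pointing at `torch_xla_cpu_plugin:CpuPlugin`. An empty result suggests that the package is not installed in the current environment or that the entry-point group name is misspelled.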