Custom Hardware Plugins¶
PyTorch/XLA supports custom hardware through OpenXLA's PJRT C API. The PyTorch/XLA team directly supports plugins for Cloud TPU (libtpu) and GPU (OpenXLA). The same plugins may also be used by JAX and TF.
Implementing a PJRT Plugin¶
PJRT C API plugins may be closed-source or open-source. They contain two parts:

1. A binary exposing a PJRT C API implementation. This part can be shared with JAX and TensorFlow.
2. A Python package containing the above binary, as well as an implementation of our DevicePlugin Python interface, which handles additional setup.
PJRT C API Implementation¶
In short, you must implement a PjRtClient containing an XLA compiler and runtime for your device. The PJRT C++ interface is mirrored in C in the PJRT_Api. The most straightforward option is to implement your plugin in C++ and wrap it as a C API implementation. This process is explained in detail in OpenXLA’s documentation.
For a concrete example, see the example implementation.
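Before wiring the binary into PyTorch/XLA, it can be useful to verify that it actually exports the PJRT C API entry point, GetPjrtApi. A minimal sketch using ctypes (the helper name is illustrative, not part of torch_xla):

```python
import ctypes


def has_pjrt_entry_point(path: str) -> bool:
    """Return True if the shared library at `path` exports GetPjrtApi,
    the symbol a PJRT C API plugin must expose."""
    try:
        lib = ctypes.CDLL(path)
    except OSError:
        # Library missing or not loadable on this platform.
        return False
    # Looking up a missing symbol raises AttributeError, so hasattr
    # doubles as a symbol check.
    return hasattr(lib, "GetPjrtApi")
```

If this check fails, PJRT will not be able to load your plugin.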
PyTorch/XLA Plugin Package¶
At this point, you should have a functional PJRT plugin binary, which you can test with the placeholder LIBRARY device type. For example:
$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
To register your device type automatically for users, as well as to handle extra setup for, e.g., multiprocessing, you may implement the DevicePlugin Python API. PyTorch/XLA plugin packages contain two key components:
An implementation of DevicePlugin that (at the very least) provides the path to your plugin binary. For example:

import os

from torch_xla.experimental import plugins


class CpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
A torch_xla.plugins entry point that identifies your DevicePlugin. For example, to register the EXAMPLE device type in a pyproject.toml:
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"
With your package installed, you may then use your EXAMPLE device directly:
$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
DevicePlugin provides additional extension points for multiprocess initialization and client options. The API is currently in an experimental state, but it is expected to become stable in a future release.
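As a sketch of those extension points, a plugin might override physical_chip_count (used to decide how many worker processes to spawn) and client_create_options (options forwarded when the PJRT client is created). The method names reflect the current experimental API and may change; the stand-in base class below exists only so the example runs without torch_xla installed, and a real plugin subclasses torch_xla.experimental.plugins.DevicePlugin instead:

```python
class DevicePlugin:
  """Stand-in for torch_xla.experimental.plugins.DevicePlugin, with
  the defaults a base plugin would provide. Illustrative only."""

  def library_path(self) -> str:
    raise NotImplementedError

  def physical_chip_count(self) -> int:
    return 1

  def client_create_options(self) -> dict:
    return {}


class ExamplePlugin(DevicePlugin):

  def library_path(self) -> str:
    # Path is illustrative.
    return '/path/to/pjrt_c_api_example_plugin.so'

  def physical_chip_count(self) -> int:
    # Hypothetical hardware with 4 local chips; used to determine
    # how many worker processes to spawn.
    return 4

  def client_create_options(self) -> dict:
    # Hypothetical option passed through to PJRT client creation.
    return {'max_inflight_computations': 4}
```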