by Milad Mohammadi, Jack Cao, Shauheen Zahirazami, Joe Spisak, and Jiewen Tan

As we celebrate the release of OpenXLA, PyTorch 2.0, and PyTorch/XLA 2.0, it’s worth taking a step back and sharing where we see it all going in the short to medium term. With PyTorch adoption leading in the AI space and XLA supporting best-in-class compiler features, PyTorch/XLA is well positioned to provide a cutting-edge development stack for both model training and inference. To achieve this, we see investments in three main areas:

  • Training Large Models - Large language models (LLMs) and diffusion models have quickly risen in popularity, and many cutting-edge applications today are built on them. Training these models requires scale, and more specifically the ability to train across thousands of accelerators. To achieve this, we are investing in features such as AMP for mixed precision training, PjRt for increased runtime performance, SPMD / FSDP for efficient model sharding, Dynamic Shapes to enable new research approaches, faster data loading through Ray, and a toolchain that packages all of these features together into a seamless workflow. Some of these features are already available in experimental or beta stages, and others are coming this year, with many heavily leveraging the underlying OpenXLA compiler stack.
  • Model Inference - With large models continuing to grow in size and computational cost, deployment becomes the next challenge as these models find their way into applications. With the introduction of Dynamo in the PyTorch 2.0 release, PyTorch/XLA delivers performance-competitive inference. We are, however, incorporating additional inference-oriented features, including model serving support, Dynamo for sharded large models, and quantization via Torch.Export and StableHLO.
  • Ecosystem Integration - We are expanding integration with Hugging Face and PyTorch Lightning so users can take advantage of upcoming cutting-edge PyTorch/XLA features (e.g. FSDP support in Hugging Face) and downstream OpenXLA features (e.g. quantization) through familiar APIs.

Additionally, PyTorch/XLA is set to migrate to the open source OpenXLA as its default downstream compiler, allowing the PyTorch community to gain access to a leading, framework-agnostic compiler stack that enjoys industry-wide contribution and innovation. To achieve this, we will begin supporting StableHLO. As a result, OpenXLA will replace the existing TF:XLA dependency, streamlining the dependencies overall and creating leverage from the broader compiler ecosystem. PyTorch/XLA will also sunset the XRT runtime after migration. You can see the resulting high-level stack below, with the TensorFlow dependency struck out:

Figure: the upcoming PyTorch/XLA features and integrations are illustrated here

We could not be more excited about what’s ahead for PyTorch/XLA and invite the community to join us. PyTorch/XLA is developed fully in open source, so please file issues, submit pull requests, and send RFCs to GitHub so that we can openly collaborate. You can also try out PyTorch/XLA for yourself on various XLA devices including TPUs and GPUs.

The PyTorch/XLA Team at Google