PyTorch/XLA documentation
torch_xla is a Python package that implements XLA as a backend for PyTorch.
Familiar APIs: Create and train PyTorch models on TPUs with only minimal code changes.
High Performance: Scale training jobs across thousands of TPU cores while maintaining high model FLOPs utilization (MFU).
Cost Efficient: TPU hardware and the XLA compiler are optimized for cost-efficient training and inference.
Getting Started
Install with pip.
pip install torch 'torch_xla[tpu]'
Verify the installation:
python -c "import torch_xla; print(torch_xla.__version__)"
python -c "import torch; import torch_xla; print(torch.tensor(1.0, device='xla').device)"
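Once the installation checks pass, the "Familiar APIs" point can be made concrete. The following is a minimal sketch, not the official tutorial: an ordinary PyTorch training loop where the XLA-specific parts are moving the model and data to the `xla` device and calling `xm.mark_step()` (from `torch_xla.core.xla_model`) at the end of each step, so the lazily traced graph is compiled and executed. The toy model, data, learning rate, and the `end_of_step` helper are illustrative assumptions. The CPU fallback is also an addition, included only so the snippet runs on a machine without torch_xla.

```python
import torch
import torch.nn as nn

# Pick the XLA device when torch_xla is installed; otherwise fall back to CPU
# (the fallback is an addition for portability, not part of torch_xla).
try:
    import torch_xla.core.xla_model as xm

    device = torch.device("xla")

    def end_of_step():
        # Cut the lazily traced graph so XLA compiles and executes this step.
        xm.mark_step()
except ImportError:
    device = torch.device("cpu")

    def end_of_step():
        pass  # Eager CPU execution needs no step barrier.

torch.manual_seed(0)

# Toy regression problem (illustrative): targets are the sum of the inputs.
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)

# Standard PyTorch model/optimizer; only the .to(device) calls are XLA-specific.
# The optimizer is created after moving the model so it holds the device params.
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

losses = []
for _ in range(50):
    inputs, targets = x.to(device), y.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    end_of_step()
    # .item() copies the scalar back to the host (a device sync on XLA).
    losses.append(loss.item())

print(f"device={device} first_loss={losses[0]:.4f} last_loss={losses[-1]:.4f}")
```

Calling `loss.item()` on every iteration forces a host sync on XLA. A real training loop would typically read the loss back less often.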
Tutorials
Learn the Basics
Distributed Training on TPU
Advanced Techniques
Troubleshooting
Training on GPU