Torch Export to StableHLO
This document describes how to use torch export + torch xla to export to StableHLO format.
There are two ways to accomplish this:
1. First use torch.export to create an ExportedProgram, which contains the program as a torch.fx graph. Then use exported_program_to_stablehlo to convert it into an object that contains StableHLO MLIR code.
2. First convert the PyTorch model to a JAX function, then use JAX utilities to convert it to StableHLO.
Using torch.export
from torch.export import export
import torchvision
import torch
import torch_xla2 as tx
import torch_xla2.export
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
weights, stablehlo = tx.export.exported_program_to_stablehlo(exported)
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like
The stablehlo object is of type jax.export.Exported.
Feel free to explore https://openxla.org/stablehlo/tutorials/jax-export for more details on how to use the MLIR code generated from it.
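As a rough sketch of what can be done with a jax.export.Exported object, the snippet below exports a toy JAX function and inspects the result. The function `toy_model` and its weights are hypothetical stand-ins for the converted ResNet, chosen so the example runs with only JAX and NumPy installed. It imports the export module explicitly with `from jax import export`, on the assumption that this form is accepted by every JAX release that ships jax.export.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Hypothetical stand-in for a converted model: takes (weights, input).
def toy_model(weights, x):
    return jnp.tanh(x @ weights["w"] + weights["b"])

weights = {
    "w": np.ones((3, 2), dtype=np.float32),
    "b": np.zeros((2,), dtype=np.float32),
}
x = np.ones((4, 3), dtype=np.float32)

# Same pattern as above: export a jitted function for concrete input shapes.
exported = export.export(jax.jit(toy_model))(weights, x)

mlir_text = exported.mlir_module()  # StableHLO program as MLIR text
input_avals = exported.in_avals     # flattened abstract inputs (shape and dtype)
output_avals = exported.out_avals   # abstract outputs
result = exported.call(weights, x)  # run the exported program from JAX

print(input_avals)
print(output_avals)
```

The abstract values record only shapes and dtypes, which is why the weights can be stored separately from the program text.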
Using extract_jax
from torch.export import export
import torchvision
import torch
import torch_xla2 as tx
import torch_xla2.export
import jax
import jax.numpy as jnp
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
weights, jfunc = tx.extract_jax(resnet18)
# Below are APIs from jax
# The second argument is the tuple of model inputs, given here as shape and dtype only
stablehlo = jax.export.export(jax.jit(jfunc))(
    weights, (jax.ShapeDtypeStruct((4, 3, 224, 224), jnp.float32),))
print(stablehlo.mlir_module())
# Can store weights and/or stablehlo object however you like
In the jax.export.export call above, we use jax.ShapeDtypeStruct to specify the input shape and dtype without allocating any data.
You can also pass a numpy array here.
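To illustrate that a shape/dtype specification and a concrete numpy array lead to the same exported signature, here is a minimal sketch using a toy function. The function `scale` is hypothetical and only stands in for the extracted model function; the input shape matches the ResNet example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Hypothetical stand-in for the function returned by extract_jax.
def scale(x):
    return x * 2.0

shape = (4, 3, 224, 224)

# Option 1: describe the input abstractly (no data is allocated).
spec = jax.ShapeDtypeStruct(shape, jnp.float32)
exported_from_spec = export.export(jax.jit(scale))(spec)

# Option 2: pass a concrete numpy array; only its shape and dtype are used.
array = np.zeros(shape, dtype=np.float32)
exported_from_array = export.export(jax.jit(scale))(array)

spec_aval = exported_from_spec.in_avals[0]
array_aval = exported_from_array.in_avals[0]
print(spec_aval.shape, spec_aval.dtype)
print(array_aval.shape, array_aval.dtype)
```

The abstract form is convenient when the real input would be large or is not available at export time.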
Preserving High-Level PyTorch Operations in StableHLO by generating stablehlo.composite
High-level PyTorch ops (e.g. F.scaled_dot_product_attention) are decomposed into low-level ops during PyTorch -> StableHLO lowering.
Capturing the high-level op in downstream ML compilers can be crucial for generating performant, efficient specialized kernels. Because pattern matching a group of low-level ops in the ML compiler can be challenging and error-prone, we offer a more robust method to outline the high-level PyTorch op in the StableHLO program: generating stablehlo.composite for the high-level PyTorch ops.
The following example shows a practical use case: capturing scaled_dot_product_attention.
To use composite, we currently need the jax-centric export (i.e. no torch.export).
We are working on adding support for torch.export.
import unittest
import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
import torch_xla2
import torch_xla2.export
from torch_xla2.ops import jaten
from torch_xla2.ops import jlibrary
# Create a `mylib` library which has a basic SDPA op.
m = Library("mylib", "DEF")
m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")
@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
def _mylib_scaled_dot_product_attention(q, k, v):
    """Basic scaled dot product attention without all the flags/features."""
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    y = F.scaled_dot_product_attention(
        q,
        k,
        v,
        dropout_p=0,
        is_causal=False,
        scale=None,
    )
    return y.transpose(1, 2)

@impl_abstract("mylib::scaled_dot_product_attention")
def _mylib_scaled_dot_product_attention_meta(q, k, v):
    return torch.empty_like(q)

# Register library op as a composite for export using the `@impl` method
# for a torch decomposition.
jlibrary.register_torch_composite(
    "mylib.scaled_dot_product_attention",
    _mylib_scaled_dot_product_attention,
    torch.ops.mylib.scaled_dot_product_attention,
    torch.ops.mylib.scaled_dot_product_attention.default
)

# Also register ATen softmax as a composite for export in the `mylib` library
# using the JAX ATen decomposition from `jaten`.
jlibrary.register_jax_composite(
    "mylib.softmax",
    jaten._aten_softmax,
    torch.ops.aten._softmax,
    static_argnums=1  # Required by JAX jit
)
class LibraryTest(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(0)
        torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False

    def test_basic_sdpa_library(self):

        class CustomOpExample(torch.nn.Module):
            def forward(self, q, k, v):
                x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
                x = x + 1
                return x

        # Export and check for composite operations
        model = CustomOpExample()
        arg = torch.rand(32, 8, 128, 64)
        args = (arg, arg, arg, )
        exported = torch.export.export(model, args=args)
        # Unpack (weights, stablehlo), matching the torch.export example above
        weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
        module_str = str(stablehlo.mlir_module())

        ## TODO Update this machinery from producing function calls to producing
        ## stablehlo.composite ops.
        self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
        self.assertIn("call @mylib.softmax", module_str)

if __name__ == '__main__':
    unittest.main()
As we can see, to emit a StableHLO function as a composite, we first write a Python function representing the region of code that we want to call, then register it so that PyTorch and jlibrary understand it is a custom region. The emitted StableHLO will then contain mylib.scaled_dot_product_attention and mylib.softmax as outlined StableHLO functions.
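When checking the emitted module by hand, it can help to list every function that is invoked through a `call @name` operation. The helper below is a small sketch that uses only the standard library; `outlined_calls` is a hypothetical name, and `sample_module` is a hand-written, abbreviated stand-in for MLIR text, not real exporter output.

```python
import re

def outlined_calls(module_str: str) -> list[str]:
    """Return the sorted, de-duplicated callee names of `call @name` ops."""
    # MLIR bare symbol names may contain letters, digits, '_', '$' and '.'.
    names = re.findall(r"call @([A-Za-z_][\w.$]*)", module_str)
    return sorted(set(names))

# Hand-written stand-in for str(stablehlo.mlir_module()); not real output.
sample_module = """
func.func public @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = call @mylib.scaled_dot_product_attention(%arg0, %arg0, %arg0)
  return %0 : tensor<2x2xf32>
}
func.func private @mylib.scaled_dot_product_attention(%arg0: tensor<2x2xf32>) {
  %0 = call @mylib.softmax(%arg0)
}
"""

print(outlined_calls(sample_module))
```

Applied to the real module string, the returned list should include the names registered through jlibrary if the outlining worked as described.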