
docs: initial prototype of exporting Lux models to Jax #1088

Merged · 4 commits · Nov 17, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -43,3 +43,4 @@ benchmarks/results

# Generated by tutorials
pinn_nested_ad.gif
*.mlir
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -306,6 +306,10 @@ export default defineConfig({
text: "Compiling Lux Models",
link: "/manual/compiling_lux_models",
},
{
text: "Exporting Lux Models to Jax",
link: "/manual/exporting_to_jax",
},
],
},
{
144 changes: 144 additions & 0 deletions docs/src/manual/exporting_to_jax.md
@@ -0,0 +1,144 @@
# Exporting Lux Models to Jax (via StableHLO)

!!! danger "Experimental"

This feature is experimental and is subject to change without notice. Additionally,
this feature currently requires some manual setup for interacting with Jax, which we are
working on improving.

In this manual, we will go over how to export Lux models to StableHLO and use
[EnzymeJAX](https://github.com/EnzymeAD/Enzyme-JAX) to integrate the exported models with
JAX. We assume that users are familiar with
[Reactant compilation of Lux models](@ref reactant-compilation).

```@example exporting_to_stablehlo
using Lux, Reactant, Random

const dev = reactant_device()
```

We simply define a Lux model and generate the StableHLO code using `Reactant.@code_hlo`.

```@example exporting_to_stablehlo
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev;
```

Generate an example input.

```@example exporting_to_stablehlo
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev;
```

Now, instead of compiling the model, we will use `Reactant.@code_hlo` to generate the
StableHLO code.

```@example exporting_to_stablehlo
hlo_code = @code_hlo model(x, ps, st)
```

Now we just save this into an `.mlir` file. Remember that we only need to save the `@main`
function and not the entire module.

```@example exporting_to_stablehlo
hlo_string = string(hlo_code)
# Keep only the `@main` function: slice from `func.func @main` through the `}` that
# closes it (the last `}` before the module's own closing brace).
hlo_string = hlo_string[findfirst("func.func @main", hlo_string)[1]:end]
hlo_string = hlo_string[1:findlast("}\n}", hlo_string)[1]]

open("exported_lux_model.mlir", "w") do io
    write(io, hlo_string)
end
```
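
We only saved the MLIR here; the weights used on the JAX side later in this manual are
random. If you instead want to reuse the parameters from this Julia session (for example,
for a trained model), one possible approach, sketched below assuming the
[NPZ.jl](https://github.com/fhs/NPZ.jl) package is installed, is to write them to an
`.npz` file. The `to_jax` helper and the `exported_lux_params.npz` filename are
illustrative, not part of Lux; the `layer_1` field names follow `Chain`'s usual parameter
naming.

```julia
using NPZ

# Convert to a plain `Array`, then reverse the dimensions so that JAX (row-major)
# sees the same data that Julia (column-major) does.
to_jax(x) = permutedims(Array(x), reverse(1:ndims(x)))

npzwrite("exported_lux_params.npz", Dict(
    "weight1" => to_jax(ps.layer_1.weight),
    "bias1"   => to_jax(ps.layer_1.bias),
    # ... and similarly for the remaining layers
))
```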

Now we define a Python script to run the model using EnzymeJAX.

```python
from enzyme_ad.jax import primitives

import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
code = file.read()


def run_lux_model(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
):
return primitives.ffi_call(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
out_shapes=[
jax.core.ShapedArray([4, 10], jnp.float32),
],
fn="main",
source=code,
lang=primitives.LANG_MHLO,
pipeline_options=primitives.JaXPipeline(""),
)


```

> **Contributor:** Incidentally, we can probably add a flag to `compile`/`jit`/`@code_hlo` to not transpose the inputs. But maybe regardless it's worth explaining why: Julia defaults to column-major order, whereas JAX defaults to row-major.
>
> **avik-pal (author):** I feel that will be very confusing for people who use Python exclusively. For people who use both Julia and Python, switching is not that hard when converting the model. I think we can add a function that serializes all the inputs from Julia together with the MLIR code, plus a Python function to deserialize it. That would also be handy for pre-trained models and such. Agreed that the "why" is worth explaining.

> **Contributor:** Extra fun fact: hypothetically, if the model weights themselves weren't traced, the output MLIR would contain the weights, and thus one could export inference of a pre-trained model too!
>
> **avik-pal (author):** This is a neat way to do it, but it will need some additional plumbing on the Lux end. Currently, if LuxLib sees a device mismatch (`CPUDevice` vs `ReactantDevice`), it throws an error. This is very handy when users forget to move something to the GPU and would previously hit a cryptic error deep in the stack. With Reactant it might be worth reworking some of those things.

```python
# Note that all the inputs must be transposed: if the Lux model takes an input of shape
# (28, 28, 1, 4), then the input to the exported Lux model must be of shape (4, 1, 28, 28).
# This is because Julia arrays are column-major while JAX arrays are row-major by default.

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))

# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)
```
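
To make the layout point above concrete, here is a small sketch of the conversion;
`x_julia` is a hypothetical stand-in for an array produced on the Julia side (for
instance, one loaded from the `exported_lux_params.npz` file sketched earlier):

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical stand-in for data with Julia's (column-major) WHCN layout.
x_julia = np.zeros((28, 28, 1, 4), dtype=np.float32)

# Reversing the axes converts between the two defaults, turning the
# (28, 28, 1, 4) WHCN array into the (4, 1, 28, 28) NCHW array that the
# exported `@main` function expects.
x = jnp.transpose(x_julia, tuple(reversed(range(x_julia.ndim))))
assert x.shape == (4, 1, 28, 28)
```

If the parameters were written with their dimensions already reversed, as in the Julia
sketch above, the arrays returned by `numpy.load("exported_lux_params.npz")` can be
passed to `run_lux_model` as-is.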