iree-jax fails when exporting bfloat16 parameters #72

Open
wangkuiyi opened this issue May 25, 2023 · 0 comments

@wangkuiyi (Contributor)

When I run the following program:

import jax
import jax.numpy as jnp
from iree.jax import Program, store_global
import flax

model = flax.linen.Dense(
    features=1, use_bias=False, dtype=jnp.bfloat16, param_dtype=jnp.bfloat16
)
rng = jax.random.PRNGKey(0)
model_state = model.init(rng, jnp.ones((1, 1)))


# The generated MLIR module name will be the prefix before Program.
class TryTrainStateProgram(Program):
    _params = Program.export_global(
        model_state["params"], initialize=True, mutable=True
    )

    def get_params(self):
        # Return the exported global; the class has no `_train_state` attribute.
        return self._params


with open("/tmp/a.mlir", "w") as f:
    f.write(str(Program.get_mlir_module(TryTrainStateProgram)))

I get the following error:

ValueError: cannot include dtype 'E' in a buffer

After changing jnp.bfloat16 to another type, such as jnp.float32, the error disappears.
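The message appears to match the error NumPy raises when an array whose dtype has type code 'E' (the code registered by the bfloat16 type from the ml_dtypes package, which JAX uses for jnp.bfloat16) is exported through the Python buffer protocol, for example via memoryview. The sketch below reproduces that error without iree-jax, assuming ml_dtypes is installed. Whether iree-jax hits exactly this path when materializing the global is an assumption, and the helper names `buffer_export_error` and `as_uint16_bits` are hypothetical.

```python
import numpy as np
import ml_dtypes  # assumption: provides the bfloat16 NumPy dtype used by JAX


def buffer_export_error(dtype):
    """Return the message of the ValueError raised when a small array of
    `dtype` is exported through the buffer protocol, or None on success."""
    arr = np.zeros((1, 1), dtype=dtype)
    try:
        memoryview(arr)  # requests a format string, which bfloat16 lacks
    except ValueError as exc:
        return str(exc)
    return None


def as_uint16_bits(arr):
    """Hypothetical workaround: reinterpret bfloat16 data as uint16 without
    copying, so the same 2-byte values can pass through the buffer protocol."""
    return arr.view(np.uint16)


bf16_error = buffer_export_error(ml_dtypes.bfloat16)
f32_error = buffer_export_error(np.float32)
print(bf16_error)
print(f32_error)
```

If the export does go through the buffer protocol, passing the raw bits as uint16 (and recording the element type separately) is one possible direction for a fix; it has not been tested against iree-jax.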
