When I run the following program
```python
import jax
import jax.numpy as jnp
from iree.jax import Program, store_global
import flax

# A single bfloat16 Dense layer with bfloat16 parameters.
model = flax.linen.Dense(
    features=1, use_bias=False, dtype=jnp.bfloat16, param_dtype=jnp.bfloat16
)
rng = jax.random.PRNGKey(0)
model_state = model.init(rng, jnp.ones((1, 1)))


# The generated MLIR module name will be the prefix before Program.
class TryTrainStateProgram(Program):
    _params = Program.export_global(
        model_state["params"], initialize=True, mutable=True
    )

    def get_params(self):
        # Return the exported parameter global.
        return self._params


with open("/tmp/a.mlir", "w") as f:
    f.write(str(Program.get_mlir_module(TryTrainStateProgram)))
```
I got
```
ValueError: cannot include dtype 'E' in a buffer
```
After changing `jnp.bfloat16` to another dtype such as `jnp.float32`, the error disappears.
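The message appears to come from NumPy's buffer protocol. In recent JAX versions `jnp.bfloat16` is the `ml_dtypes.bfloat16` NumPy extension type, whose type character is `'E'`, and NumPy has no buffer format code for it. The sketch below assumes NumPy and `ml_dtypes` (a JAX dependency) are installed. It reproduces the same `ValueError` without IREE, shows that a `float32` array exports fine, and illustrates one possible workaround: exporting the raw 16-bit patterns as `uint16` and reinterpreting them as `bfloat16` afterwards. The helper `buffer_error` is hypothetical. Whether `iree.jax` goes through exactly this code path is an assumption, not something the report confirms.

```python
import numpy as np
import ml_dtypes  # provides the bfloat16 NumPy dtype used by recent JAX versions

bf16 = np.asarray([1.0, -2.5, 0.15625], dtype=ml_dtypes.bfloat16)
print(bf16.dtype.char)  # expected: E


def buffer_error(arr):
    """Hypothetical helper: return the ValueError message raised when
    exporting `arr` through the buffer protocol, or None if it succeeds."""
    try:
        memoryview(arr)  # requests a PEP 3118 format string for the dtype
    except ValueError as exc:
        return str(exc)
    return None


print(buffer_error(bf16))                     # message should mention dtype 'E'
print(buffer_error(bf16.astype(np.float32)))  # float32 exports without error

# Possible workaround (an assumption, untested with iree.jax): reinterpret the
# 16-bit patterns as uint16, which NumPy can export, then view them as
# bfloat16 again on the receiving side. No values are converted, only viewed.
raw = memoryview(bf16.view(np.uint16))
restored = np.frombuffer(raw, dtype=np.uint16).view(ml_dtypes.bfloat16)
assert np.array_equal(restored.astype(np.float32), bf16.astype(np.float32))
```

Viewing rather than casting keeps the bit patterns intact, so the round trip is lossless. Casting to `float32` would instead change the stored dtype.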