Automatic turning up x64 #33

Closed
anvvalade opened this issue Oct 20, 2024 · 4 comments

Comments

@anvvalade

The global configuration of jax is altered in the configuration.py file: jax.config.update('jax_enable_x64', True)

This may clash with user-defined x64 policies and is a nightmare to debug!
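To make the effect concrete, here is a minimal sketch (not pmwd code) of how flipping the global flag changes the dtype of every newly created float array in the process. The helper name `default_float_dtype` is made up for illustration.

```python
import jax
import jax.numpy as jnp


def default_float_dtype():
    """Return the dtype JAX picks for an array built from a Python float."""
    return jnp.array([1.0]).dtype


# With x64 disabled (JAX's own default), new float arrays are float32.
jax.config.update("jax_enable_x64", False)
dtype_x64_off = default_float_dtype()

# After any code in the process enables x64, the same call yields float64.
jax.config.update("jax_enable_x64", True)
dtype_x64_on = default_float_dtype()

print(dtype_x64_off, dtype_x64_on)
```

Because the flag is process-wide, an `update` call that runs at import time in one package affects arrays created by every other package as well.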

Otherwise: thanks again for the great piece of code!

@eelregit
Owner

That's on by default because:

  • We want to make double precision available for numerics, so pmwd aims to be explicit about float and int dtypes wherever it can.
  • By default double precision is used for parameters and cached 1D results, which is important for keeping precision during backprop.
  • I remember that jax/experimental/ode.py is not very stable in single precision (but I forget whether we have fixed that completely or only partially, @Yucheng-Zhang?)
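As a generic illustration of why double precision can matter for parameters (this is not taken from pmwd), the sketch below adds a small increment to 1.0 in both precisions. The value `1e-8` is chosen arbitrarily to sit below float32 resolution near 1.0.

```python
import jax
import jax.numpy as jnp

# float64 arrays are only available when x64 is enabled.
jax.config.update("jax_enable_x64", True)

small = 1e-8  # arbitrary increment, below float32 resolution near 1.0

one32 = jnp.asarray(1.0, dtype=jnp.float32)
one64 = jnp.asarray(1.0, dtype=jnp.float64)

# In single precision the increment is rounded away entirely.
lost_in_single = bool(one32 + jnp.asarray(small, dtype=jnp.float32) == one32)

# In double precision the increment survives.
kept_in_double = bool(one64 + jnp.asarray(small, dtype=jnp.float64) != one64)

print(lost_in_single, kept_in_double)
```

Being explicit about `dtype` at each array creation, as in the sketch, keeps the result independent of whatever default the global flag implies.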

I am curious what use case you had trouble with?

@anvvalade
Author

I understand you have good reasons to do so!

I wrapped pmwd in a custom log-probability function set up to work with float32. I then passed this function to the blackjax inference module, which internally defines constants, e.g. foo = jnp.array([1.]). The problem is that these default to float64 when jax_enable_x64 is on, which ends up clashing in jax.lax, where type promotion is not automatic.

In a sense it's blackjax's fault for not being more careful, yet the direct fix is to set jax_enable_x64=False at the beginning of the script. As pmwd was imported deep down in my model, it took me a while to figure out why newly created float arrays were still float64!
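A sketch of that workaround, using a hypothetical stand-in function in place of the real import so the snippet runs on its own. The assumption here is that the reset has to run after the import that enables x64, since an import-time `update` would otherwise overwrite it.

```python
import jax
import jax.numpy as jnp


def import_dependency_that_enables_x64():
    """Hypothetical stand-in for importing a package that flips the flag."""
    jax.config.update("jax_enable_x64", True)


import_dependency_that_enables_x64()
dtype_before_reset = jnp.array([1.0]).dtype  # float64: the import changed the default

# Restore the float32 default once all such imports have run.
jax.config.update("jax_enable_x64", False)
dtype_after_reset = jnp.array([1.0]).dtype

print(dtype_before_reset, dtype_after_reset)
```

Checking the dtype of a freshly created array, as done here, is a quick way to find out whether something in the import chain has changed the flag.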

@eelregit
Owner

The good news is that I am working on the conf_refac branch and will drop the auto-enabled x64 once it's merged.

@anvvalade
Author

Excited about the next developments of pmwd! (And closing, btw, now that it's been solved.)
