
Type mismatch when running x64-bit environment but using float32 inputs #802

Open
adam-hartshorne opened this issue Dec 11, 2024 · 3 comments

Comments

@adam-hartshorne

adam-hartshorne commented Dec 11, 2024

Due to the way JAX handles typing, if you want just one function in a model to operate on float64 numbers, e.g. for numerical precision when taking a Cholesky inverse, you have to enable x64 mode globally and then be careful to specify the dtype of every array at the point where it is created.
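
For readers unfamiliar with this pattern, here is a minimal sketch of what "one function in float64" can look like. The helper name `solve_spd_in_float64` and the test matrix are made up for illustration; they are not part of ProbDiffeq or of the model discussed in this issue.

```python
import jax
import jax.numpy as jnp

# x64 has to be enabled globally, otherwise the float64 casts below
# would silently fall back to float32.
jax.config.update("jax_enable_x64", True)


def solve_spd_in_float64(a, b):
    """Solve a @ x = b via a Cholesky factorization carried out in float64.

    Hypothetical helper: inputs may be float32, the result is cast
    back to the dtype of `b`.
    """
    a64 = a.astype(jnp.float64)
    b64 = b.astype(jnp.float64)
    chol = jnp.linalg.cholesky(a64)  # lower-triangular factor
    x64 = jax.scipy.linalg.cho_solve((chol, True), b64)
    return x64.astype(b.dtype)


a = jnp.asarray([[4.0, 1.0], [1.0, 3.0]], dtype=jnp.float32)
b = jnp.asarray([1.0, 2.0], dtype=jnp.float32)
x = solve_spd_in_float64(a, b)
print(x.dtype)
```

The rest of the model stays in float32 only as long as every other array is created with an explicit dtype, which is the part that is easy to get wrong.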

Unfortunately, I have found that ProbDiffeq doesn't handle this scenario correctly:

import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl

jax.config.update("jax_enable_x64", True) # TOGGLE THIS

@jax.jit
def vf(y, t): 
    """Evaluate the vector field."""
    yt = 0.5 * t * (1 - y)
    return yt

dtype = jnp.float32

def scalar_float(x):
    return jnp.array(x, dtype=jnp.float64)

def scalar_float32(x):
    return jnp.array(x, dtype=jnp.float32)


if dtype == jnp.float64:
    scalar_type_fn = scalar_float
else:
    scalar_type_fn = scalar_float32

u0 = jnp.asarray([0.1], dtype=dtype)
t0, t1 = scalar_type_fn(0.0), scalar_type_fn(1.0)

impl.select("dense", ode_shape=(1,))
num_derivatives = 1
ibm = ivpsolvers.prior_ibm(num_derivatives=1)
correction = ivpsolvers.correction_ts0(ode_order=1)
strategy = ivpsolvers.strategy_smoother(ibm, correction)
solver = ivpsolvers.solver(strategy)

time_grid = jnp.linspace(0.0, 1.0, 10, dtype=dtype)

tcoeffs = taylor.odejet_padded_scan(
    lambda y: vf(y, time_grid[0]),
    (u0,),
    num=num_derivatives
)

output_scale = scalar_type_fn(1.0)  # or any other value with the same shape

# Within the init struct returned here, the conditional variables are float64 when jax_enable_x64 is True
init = solver.initial_condition(tcoeffs, output_scale)

sol = ivpsolve.solve_fixed_grid(
    vf,
    init,
    grid=time_grid,
    solver=solver,
)


print(sol.u.dtype)

@pnkraemer
Owner

Hi Adam, thanks for reaching out. Could you maybe elaborate on which behaviour you'd expect and which behaviour you see?

@adam-hartshorne
Author

As there is no way to set the dtype of the solver, if I provide the initial condition (u0) and time_grid in float32, I expect the solver to use only float32 variables internally. This is of course the case with jax.config.update("jax_enable_x64", False), as JAX then doesn't allow creating float64 arrays or casting to float64.

However, if you enable x64 in JAX, you will find that init contains a mix of float32 and float64 values, and that the internals of your solver also create variables with a mix of float32 and float64 types. What happens is that on the first evaluation of vf the input is float32 and the return value, yt, is float32, but when the solver then attempts to update the state, the dtypes no longer agree and you get a type-mismatch crash.
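
One way such a crash can arise is sketched below. This assumes the solver loop behaves like a `jax.lax.scan` whose carried state must keep the same dtype from step to step; whether ProbDiffeq's loop is structured this way is an assumption, and the `step` function is a toy stand-in rather than solver code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

u0 = jnp.asarray([0.1], dtype=jnp.float32)
grid = jnp.linspace(0.0, 1.0, 10, dtype=jnp.float32)


def step(state, t):
    # Toy stand-in for a solver step: the array below is created without
    # a dtype, so it is float64 under x64 and promotes the float32 state.
    increment = jnp.ones(state.shape)
    new_state = state + t * increment
    return new_state, new_state


mismatch_error = None
try:
    jax.lax.scan(step, u0, grid)
except TypeError as err:
    # scan requires the carry to have the same dtype going in and coming out
    mismatch_error = err

print(type(mismatch_error).__name__)
```

With x64 disabled, the same code runs without error because the untyped array is float32.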

The reason for the above is that when you enable x64 in JAX, any newly created array will be float64 unless you specify its dtype explicitly. My guess is that some of the new variables created inside your solver are being instantiated as float64.
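
A short sketch of that default behaviour, using only plain `jax.numpy` calls (nothing here is ProbDiffeq-specific):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x32 = jnp.asarray([0.1], dtype=jnp.float32)

default = jnp.ones((1,))   # no dtype given: float64 when x64 is enabled
promoted = x32 + default   # float32 array + float64 array promotes to float64
weak = x32 * 0.5           # Python scalars are weakly typed: result stays float32

print(default.dtype, promoted.dtype, weak.dtype)
```

So Python scalars mixed into float32 arithmetic are harmless, while any array created without a dtype pulls the result up to float64.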

@pnkraemer
Owner

I see, thanks for elaborating! So do I understand correctly that you'd expect the ODE solver state to have the same dtype as the initial condition? That sounds reasonable. Indeed, probdiffeq simply creates some of its arrays without worrying about data types, so they become double precision if x64 is enabled.

I'll look into it. But it might be a while until I find the time (likely January). Hope that isn't an issue.
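
A minimal sketch of the behaviour discussed above, where internally created arrays inherit the dtype of the initial condition. This is only an illustration of the general idea with a toy `step` function; it is not the actual ProbDiffeq code and not necessarily how a fix would be implemented there.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

u0 = jnp.asarray([0.1], dtype=jnp.float32)
grid = jnp.linspace(0.0, 1.0, 10, dtype=jnp.float32)


def step(state, t):
    # Toy stand-in for a solver step: every new array takes its dtype
    # from the incoming state instead of relying on the global default.
    increment = jnp.ones(state.shape, dtype=state.dtype)
    new_state = state + t * increment
    return new_state, new_state


final, trajectory = jax.lax.scan(step, u0, grid)
print(final.dtype, trajectory.dtype)
```

Helpers such as `jnp.zeros_like` and `jnp.ones_like` achieve the same thing, since they copy the dtype of their argument.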
