Due to the way JAX handles typing, if you want just one function in a model to operate on float64 numbers (e.g. for numerical precision when taking a Cholesky inverse), you have to enable x64 mode globally and be careful to cast types wherever arrays are created.
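A minimal sketch of that pattern, assuming the goal is a float32 model with a single float64 Cholesky solve. `solve_spd_in_float64` is a hypothetical helper, not part of JAX or probdiffeq; it only uses `jnp.linalg.cholesky` and `jax.scipy.linalg.cho_solve`. The last lines show the pitfall: once x64 is on, arrays created without an explicit dtype default to float64 and promote any float32 operand they touch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# x64 is a global switch; it must be on before any float64 array can exist.
jax.config.update("jax_enable_x64", True)


def solve_spd_in_float64(A, b):
    """Hypothetical helper: solve A x = b (A symmetric positive-definite) in float64."""
    # Upcast explicitly so that only this function runs in double precision.
    A64 = A.astype(jnp.float64)
    b64 = b.astype(jnp.float64)
    L = jnp.linalg.cholesky(A64)      # lower-triangular Cholesky factor
    x64 = cho_solve((L, True), b64)   # (factor, lower=True)
    # Downcast explicitly so the rest of the float32 model is unaffected.
    return x64.astype(A.dtype)


A = jnp.array([[4.0, 2.0], [2.0, 3.0]], dtype=jnp.float32)
b = jnp.array([2.0, 1.0], dtype=jnp.float32)
x = solve_spd_in_float64(A, b)
print(x.dtype)  # float32

# The pitfall: an array created without a dtype is float64 under x64,
# and mixing it with float32 promotes the result to float64.
print(jnp.zeros(2).dtype)        # float64
print((b + jnp.zeros(2)).dtype)  # float64
```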
Unfortunately, I have found that ProbDiffeq doesn't handle this scenario correctly:
```python
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl

jax.config.update("jax_enable_x64", True)  # TOGGLE THIS


@jax.jit
def vf(y, t):
    """Evaluate the vector field."""
    yt = 0.5 * t * (1 - y)
    return yt


dtype = jnp.float32


def scalar_float64(x):
    return jnp.array(x, dtype=jnp.float64)


def scalar_float32(x):
    return jnp.array(x, dtype=jnp.float32)


if dtype == jnp.float64:
    scalar_type_fn = scalar_float64
else:
    scalar_type_fn = scalar_float32

u0 = jnp.asarray([0.1], dtype=dtype)
t0, t1 = scalar_type_fn(0.0), scalar_type_fn(1.0)

impl.select("dense", ode_shape=(1,))

num_derivatives = 1
ibm = ivpsolvers.prior_ibm(num_derivatives=num_derivatives)
correction = ivpsolvers.correction_ts0(ode_order=1)
strategy = ivpsolvers.strategy_smoother(ibm, correction)
solver = ivpsolvers.solver(strategy)

time_grid = jnp.linspace(0.0, 1.0, 10, dtype=dtype)
tcoeffs = taylor.odejet_padded_scan(
    lambda y: vf(y, time_grid[0]),
    (u0,),
    num=num_derivatives,
)
output_scale = scalar_type_fn(1.0)  # or any other value with the same shape

# Within the init struct returned, the conditional variables are float64 when jax_enable_x64 is True
init = solver.initial_condition(tcoeffs, output_scale)

sol = ivpsolve.solve_fixed_grid(
    vf,
    init,
    grid=time_grid,
    solver=solver,
)
print(sol.u.dtype)
```
As there is no way to set the dtype for the solver, if I provide the initial condition (`u0`) and `time_grid` in float32, I expect the solver to use only float32 variables internally. This is of course true with `jax.config.update("jax_enable_x64", False)`, as JAX then won't allow the creation of, or casting to, float64.
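For comparison, a small sketch of the x64-disabled case: an explicit float64 request is truncated to float32 (JAX emits a warning, silenced here), so no array in the computation can end up in double precision.

```python
import warnings

import jax
import jax.numpy as jnp

# JAX's default setting: x64 disabled.
jax.config.update("jax_enable_x64", False)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # JAX warns that float64 is not available
    x = jnp.zeros(3, dtype=jnp.float64)

print(x.dtype)                  # float32: the float64 request was truncated
print((x + jnp.ones(3)).dtype)  # float32: new arrays default to float32 too
```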
However, if you enable x64 in JAX, you will find that `init` contains a mix of float32 and float64 arrays, and that the solver internally also creates variables of mixed float32 and float64 types. What happens is that on the first evaluation of `vf`, the input is float32 and the return value `yt` is float32, but the state it is meant to update is partly float64, so you get a type-mismatch crash.
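The crash can be illustrated without probdiffeq, assuming the solver's time loop is a `jax.lax.scan` whose carry is the solver state (a common design for fixed-grid loops in JAX, but an assumption here). The hypothetical `step` below stands in for one solver step that combines a float32 state with an array created without a dtype.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def step(state, _):
    # jnp.ones(2) has no explicit dtype, so under x64 it is float64, and
    # float32 + float64 promotes the outgoing carry to float64.
    new_state = state + 0.1 * jnp.ones(2)
    return new_state, None


init = jnp.zeros(2, dtype=jnp.float32)
try:
    jax.lax.scan(step, init, xs=None, length=3)
    raised = False
except TypeError as err:
    # scan requires the carry to keep the same shape and dtype across steps.
    raised = True
    print(type(err).__name__, err)
```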
The reason is that when x64 is enabled in JAX, any newly created array will be float64 unless you give it an explicit dtype. My guess is that some of the variables your solver creates internally are being instantiated as float64.
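A sketch of the suspected cause and the usual remedy. `init_state_naive` and `init_state_dtype_aware` are hypothetical stand-ins for a solver's initialization, not probdiffeq functions: the first builds new arrays with JAX's default float dtype, the second passes `dtype=u0.dtype` so every new array follows the initial condition.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def init_state_naive(u0):
    """Hypothetical init: new arrays take JAX's default float dtype."""
    n = u0.shape[0]
    mean = jnp.concatenate([u0, jnp.zeros(n)])  # zeros are float64, so the result is promoted
    cov = jnp.eye(2 * n)                        # float64 under x64
    return mean, cov


def init_state_dtype_aware(u0):
    """Same structure, but every new array inherits u0's dtype."""
    n = u0.shape[0]
    mean = jnp.concatenate([u0, jnp.zeros(n, dtype=u0.dtype)])
    cov = jnp.eye(2 * n, dtype=u0.dtype)
    return mean, cov


u0 = jnp.asarray([0.1], dtype=jnp.float32)
print([a.dtype for a in init_state_naive(u0)])        # both float64
print([a.dtype for a in init_state_dtype_aware(u0)])  # both float32
```

Passing the dtype explicitly (or using `*_like` constructors such as `jnp.zeros_like`) ties the state's precision to the user's inputs regardless of the global x64 flag.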
I see, thanks for elaborating! So do I understand correctly that you'd expect that the ODE solver state has the same dtype as the initial condition? That sounds reasonable. Indeed, probdiffeq simply creates some of the arrays without worrying about data types so they become double precision if x64 is enabled.
I'll look into it. But it might be a while until I find the time (likely January). Hope that isn't an issue.
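To check the expectation that the solver state shares the initial condition's dtype, a hypothetical helper such as `mismatched_dtype_leaves` can walk any pytree and list the floating-point leaves with a different dtype. It assumes the object being inspected (for example `init` from the reproduction, via `mismatched_dtype_leaves(init, u0.dtype)`) is a registered JAX pytree, which solver states passed through `jax.jit` or `lax.scan` generally must be.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def mismatched_dtype_leaves(tree, expected_dtype):
    """Hypothetical helper: return (path, dtype) for float leaves whose dtype differs."""
    expected = jnp.dtype(expected_dtype)
    bad = []
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
        dtype = getattr(leaf, "dtype", None)
        if dtype is None:
            continue  # skip non-array leaves such as Python scalars
        # Integer leaves (e.g. step counters) are not floating point and are skipped.
        if jnp.issubdtype(dtype, jnp.floating) and dtype != expected:
            bad.append((jax.tree_util.keystr(path), dtype))
    return bad


state = {
    "mean": jnp.zeros(2, dtype=jnp.float32),
    "cov": jnp.eye(2),    # no dtype given: float64 under x64
    "step": jnp.array(0), # integer leaf, ignored
}
print(mismatched_dtype_leaves(state, jnp.float32))  # reports only the "cov" entry
```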