diff --git a/dabench/data/sqgturb.py b/dabench/data/sqgturb.py index aaa659a..f16f4a7 100644 --- a/dabench/data/sqgturb.py +++ b/dabench/data/sqgturb.py @@ -36,7 +36,6 @@ import jax import jax.numpy as jnp from jax.numpy.fft import rfft2, irfft2 -from jax import config from functools import partial from importlib import resources @@ -44,7 +43,10 @@ from dabench import _suppl_data # Set to enable 64bit floats in Jax -config.update('jax_enable_x64', True) +# Following: +# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +import os +os.environ["JAX_ENABLE_X64"] = 'True' class SQGTurb(_data.Data):