From 545db67947da03051a168244cd82b649d463fc83 Mon Sep 17 00:00:00 2001 From: Steve Penny Date: Fri, 24 May 2024 18:39:09 -0600 Subject: [PATCH] fixing 64-bit option in sqgturb --- dabench/data/sqgturb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dabench/data/sqgturb.py b/dabench/data/sqgturb.py index aaa659a..f16f4a7 100644 --- a/dabench/data/sqgturb.py +++ b/dabench/data/sqgturb.py @@ -36,7 +36,6 @@ import jax import jax.numpy as jnp from jax.numpy.fft import rfft2, irfft2 -from jax import config from functools import partial from importlib import resources @@ -44,7 +43,10 @@ from dabench import _suppl_data # Set to enable 64bit floats in Jax -config.update('jax_enable_x64', True) +# Following: +# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +import os +os.environ["JAX_ENABLE_X64"] = 'True' class SQGTurb(_data.Data):