diff --git a/lib/kernels.cc.cu b/lib/kernels.cc.cu index fe68d7a..cc90eb1 100644 --- a/lib/kernels.cc.cu +++ b/lib/kernels.cc.cu @@ -32,6 +32,9 @@ void run_nufft(int type, const NufftDescriptor* descriptor, T *x, T *y, T *z, } destroy(plan); delete opts; + + // Since JAX 0.4.9, the JAX CUDA stream is non-blocking. Need to wait for results. + cudaStreamSynchronize(stream); } diff --git a/tests/ops_test.py b/tests/ops_test.py index 96ac28f..c2583bb 100644 --- a/tests/ops_test.py +++ b/tests/ops_test.py @@ -2,6 +2,7 @@ from itertools import product import jax +import jax.experimental import jax.numpy as jnp import numpy as np import pytest