Skip to content

Commit

Permalink
Fixes for modern JAX: block until CUDA operations complete. Import ja…
Browse files Browse the repository at this point in the history
…x.experimental.
  • Loading branch information
lgarrison committed Oct 24, 2023
1 parent 1ca71ee commit aa1a634
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/kernels.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ void run_nufft(int type, const NufftDescriptor<T>* descriptor, T *x, T *y, T *z,
}
destroy<T>(plan);
delete opts;

// Since JAX 0.4.9, the JAX CUDA stream is non-blocking. Need to wait for results.
cudaStreamSynchronize(stream);
}


Expand Down
1 change: 1 addition & 0 deletions tests/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import product

import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down

0 comments on commit aa1a634

Please sign in to comment.