Computing gradient of a nufft2 costs 13x more than the nufft2 alone on GPU #67
I don't have a GPU to test this on, but I can't reproduce this behavior on my CPU. But first, one issue here is that you'll want to replace the calls to `jax.value_and_grad(map_nufft_over_pupil_cube)` with a pre-compiled version:

```python
value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))
```

Updated code:

```python
import time

import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)


@jax.jit
def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(),
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)


value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))

for i in range(n_warmup):
    v, g = value_and_grad(pupil_cube)
    v.block_until_ready()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = value_and_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
```

Otherwise, I expect you'll be dominated by tracing/compilation time. See what happens if you do that! For reference, the timing on my machine gives:
So computing the value and grad costs about 2x the value alone, which is to be expected. Perhaps @lgarrison can run some benchmarks too if my suggestion doesn't solve your problem!
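As an editorial aside (not part of the original comment): one way to rule out tracing/compilation overhead entirely is JAX's ahead-of-time lowering API, available in recent JAX versions. A minimal sketch, reusing the variables from the script above:

```python
# Sketch: compile once up front, then time pure execution of the compiled function.
compiled = (
    jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))
    .lower(pupil_cube)
    .compile()
)
v, g = compiled(pupil_cube)  # no tracing or compilation happens on this call
jax.block_until_ready((v, g))
```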
Hmm. Well, I restructured the code to move the `@jax.jit` outside of the `value_and_grad` call, but I don't see any difference in timings. This is on the GPU in my FI workstation. I can reproduce the ratio you see on the CPU, though. Here's the modified code:

```python
import time

import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)


def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(),
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)


@jax.jit
def without_grad(pupil_cube):
    v = map_nufft_over_pupil_cube(pupil_cube)
    return v


@jax.jit
def with_grad(pupil_cube):
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    return v, g


for i in range(n_warmup):
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    v = without_grad(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = without_grad(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
```
Interesting! Maybe there are some memory-transfer or re-ordering issues that are specific to the GPU. I'll try to dig a little.
On my workstation with a crappy GPU I get slower runtimes for everything, but the extra factor for value and grad is 4x. To help diagnose the issue, I started looking at the jaxpr for these operations:

without grad:

with grad:

Some notes about this. As expected, computing the value requires one `nufft2`, and computing the value and grad requires the same `nufft2` plus a `nufft1` for the backward pass, which seems about right. That being said, there's a heck of a lot of reshaping, transposing, etc. that might be leading to issues!

Edited with slightly simpler jaxprs.
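For reference, the jaxprs mentioned above can be reproduced with `jax.make_jaxpr` (an editorial sketch, assuming the variables from the earlier scripts are in scope):

```python
# Sketch: inspect the jaxpr for the forward pass and for the value-and-grad computation.
print(jax.make_jaxpr(map_nufft_over_pupil_cube)(pupil_cube))
print(jax.make_jaxpr(jax.value_and_grad(map_nufft_over_pupil_cube))(pupil_cube))
```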
I was able to further isolate the issue by adding the following benchmark:

```python
from jax_finufft import nufft2, nufft1


@jax.jit
def func(x):
    return nufft2(x, UU, VV)


result = jax.block_until_ready(func(pupil_cube))
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(func(pupil_cube))
print(time.perf_counter() - start)


@jax.jit
def func(x):
    return nufft1((pupil_npix, pupil_npix), x, UU, VV)


jax.block_until_ready(func(result))
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(func(result))
print(time.perf_counter() - start)
```

The first number printed gives the cost of doing a single `nufft2` per iteration, and the second the cost of a single `nufft1`.
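Editorial aside (not part of the original comment): the reason the `nufft1` timing is the relevant one for the gradient is that the adjoint of a type-2 transform is a type-1 transform, so the backward pass has to perform a `nufft1`-like computation. A minimal sketch, reusing variables from the first script and glossing over complex-conjugation conventions:

```python
# Sketch: the VJP of nufft2 applies (roughly) a type-1 transform to the cotangent.
p0 = pupil_cube[0]  # a single complex128 pupil
out, vjp_fn = jax.vjp(lambda c: nufft2(c, UU.flatten(), VV.flatten()), p0)
cotangent = jnp.ones_like(out)
(grad_p0,) = vjp_fn(cotangent)  # this is where the type-1 (nufft1-like) work happens
# For comparison, the adjoint applied by hand (up to conjugation/sign conventions):
adjoint = nufft1((pupil_npix, pupil_npix), cotangent, UU.flatten(), VV.flatten())
```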
I've been looking into this, and what it seems to come down to is that the performance parameters just need to be tuned for the type 1 (`nufft1`) transform. Another question is whether cufinufft should be able to choose better tuning parameters automatically. I don't know the answer to this; we could ask in the finufft GH. Here's the script I was using to play around with this, using the native cufinufft interface and not jax-finufft:

Script:

```python
import time

import cufinufft
import cupy as cp

pupil_npix = 128
fov_npix = 128
eps = 1e-6  # 1e-5 is 5x faster with method 2, and 3x slower with method 1
plan_kwargs = {}  # dict(gpu_method=1)  # method 1 for type 1 is 15-20x faster

pupil = cp.ones((pupil_npix, pupil_npix), dtype=cp.complex128)

# this affects the filling of the UV plane, also affecting performance
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (cp.arange(fov_npix) - fov_npix / 2 + 0.5)
UVs = 2 * cp.pi * (idx_pix / pupil_npix)
UU, VV = cp.meshgrid(UVs, UVs)
UU = UU.flatten()
VV = VV.flatten()

plan2 = cufinufft.Plan(2, pupil.shape, n_trans=1, eps=eps, dtype='complex128', **plan_kwargs)
plan2.setpts(UU, VV)
result2 = plan2.execute(pupil)

t = -time.perf_counter()
for i in range(nrep := 1000):
    plan2.execute(pupil)
t += time.perf_counter()
print(f"nufft2d2: {t:.4g} s")

plan1 = cufinufft.Plan(1, pupil.shape, n_trans=1, eps=eps, dtype='complex128', **plan_kwargs)
plan1.setpts(UU, VV)
result1 = plan1.execute(result2)

t = -time.perf_counter()
for i in range(nrep):
    plan1.execute(result2)
t += time.perf_counter()
print(f"nufft2d1: {t:.4g} s")
```
Interesting - thanks for looking into this, @lgarrison! I'll try to take a first stab at #66 (and think about how that'll play with differentiation, a point that didn't come up over there yet!) later this week.
Oh, right; if
I'm using jax-finufft to specify the exact spatial frequencies I want to sample in a 2D Fourier transform. It appears that FINUFFT is faster than a matrix Fourier transform for my problem size, but computing the gradient is much slower than expected.
Output:
MRE: