Computing gradient of a nufft2 costs 13x more than the nufft2 alone on GPU #67

Closed
joseph-long opened this issue Mar 7, 2024 · 8 comments · Fixed by #68

Comments

@joseph-long

I'm using jax-finufft to specify the exact spatial frequencies I want to sample in a 2D Fourier transform. FINUFFT appears to be faster than a matrix Fourier transform for my problem size, but computing the gradient is much slower than expected.
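For reference, the matrix Fourier transform I'm comparing against is essentially an explicit DFT evaluated as two small matrix products, roughly like this (a minimal sketch; the helper name and arguments are just for illustration, not my actual code):

import numpy as np

def matrix_fourier_transform(pupil, us, vs, xs, ys):
    # Explicit DFT on a separable grid: E[v, u] = sum_{y, x} pupil[y, x] * exp(-1j * (u*x + v*y))
    left = np.exp(-1j * np.outer(vs, ys))   # (n_v, n_y)
    right = np.exp(-1j * np.outer(xs, us))  # (n_x, n_u)
    return left @ pupil @ right             # (n_v, n_u)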

Output:

$ python tiny_jax_finufft_comparison.py 
nufft_time=0.005215979181230068 nufft_with_grad_time=0.06915004085749388 nufft_with_grad_time/nufft_time=13.257346023606337
nufft_time=0.004781747702509165 nufft_with_grad_time=0.06479286774992943 nufft_with_grad_time/nufft_time=13.550038977575113
nufft_time=0.004827893804758787 nufft_with_grad_time=0.06621506763622165 nufft_with_grad_time/nufft_time=13.715104414880539
nufft_time=0.004811902996152639 nufft_with_grad_time=0.06634514313191175 nufft_with_grad_time/nufft_time=13.787714171494741
nufft_time=0.004738332703709602 nufft_with_grad_time=0.06613345304504037 nufft_with_grad_time/nufft_time=13.957114702660078
nufft_time=0.005355316214263439 nufft_with_grad_time=0.06647736439481378 nufft_with_grad_time/nufft_time=12.413340638552182
nufft_time=0.004664996173232794 nufft_with_grad_time=0.06625903397798538 nufft_with_grad_time/nufft_time=14.203448731249132
nufft_time=0.004705913830548525 nufft_with_grad_time=0.06623794976621866 nufft_with_grad_time/nufft_time=14.075470174620243
nufft_time=0.004704104270786047 nufft_with_grad_time=0.06453194282948971 nufft_with_grad_time/nufft_time=13.718221177675243
nufft_time=0.004726390819996595 nufft_with_grad_time=0.06554532889276743 nufft_with_grad_time/nufft_time=13.867945201538506

MRE:

import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)

@jax.jit
def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(), 
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

for i in range(n_warmup):
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    v.block_until_ready()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")
@dfm
Collaborator

dfm commented Mar 7, 2024

I don't have a GPU to test this on, but I can't reproduce this behavior on my CPU.

But first, one issue here is that you'll want to replace the calls to:

jax.value_and_grad(map_nufft_over_pupil_cube)

with a pre-compiled version:

value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))
Updated code
import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)

@jax.jit
def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(), 
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

value_and_grad = jax.jit(jax.value_and_grad(map_nufft_over_pupil_cube))

for i in range(n_warmup):
    v, g = value_and_grad(pupil_cube)
    v.block_until_ready()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = map_nufft_over_pupil_cube(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = value_and_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")

Otherwise, I expect you'll be dominated by tracing/compilation time. See what happens if you do that!

For reference, the timing on my machine gives:

nufft_time=0.13856758398469537 nufft_with_grad_time=0.24668104195734486 nufft_with_grad_time/nufft_time=1.78022185899258
nufft_time=0.13608274998841807 nufft_with_grad_time=0.2466957090073265 nufft_with_grad_time/nufft_time=1.8128360062412219
nufft_time=0.13956649997271597 nufft_with_grad_time=0.24427195801399648 nufft_with_grad_time/nufft_time=1.750219128958235
nufft_time=0.13765895902179182 nufft_with_grad_time=0.24193362501682714 nufft_with_grad_time/nufft_time=1.757485504292738
nufft_time=0.13943391700740904 nufft_with_grad_time=0.24902837502304465 nufft_with_grad_time/nufft_time=1.785995691491706
nufft_time=0.13630358397495002 nufft_with_grad_time=0.2867940000141971 nufft_with_grad_time/nufft_time=2.104082604804466
nufft_time=0.14015749999089167 nufft_with_grad_time=0.2769365839776583 nufft_with_grad_time/nufft_time=1.9758955745904103
nufft_time=0.13771983300102875 nufft_with_grad_time=0.2544636250240728 nufft_with_grad_time/nufft_time=1.8476904849439661
nufft_time=0.15944166702684015 nufft_with_grad_time=0.30837995803449303 nufft_with_grad_time/nufft_time=1.9341240203075702
nufft_time=0.16091104096267372 nufft_with_grad_time=0.2543251250172034 nufft_with_grad_time/nufft_time=1.5805324699639398

So the value and grad cost about 2x the value alone, which is to be expected.

Perhaps @lgarrison can run some benchmarks too if my suggestion doesn't solve your problem!

@joseph-long
Author

Hmm. Well, I restructured the code to move the @jax.jit outside of the value_and_grad call, but I don't see any difference in timings. This is on the GPU in my FI workstation. I can reproduce the ratio you see on the CPU, though.

Here's the modified code:

import time
import jax
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft2

jax.config.update("jax_enable_x64", True)

n_warmup = 2
n_trials = 10
n_pupils = 50
pupil_npix = 128
fov_pix = 128

# construct pupil
xs = np.linspace(-0.5, 0.5, num=pupil_npix)
xx, yy = np.meshgrid(xs, xs)
rr = np.hypot(xx, yy)
pupil = (rr < 0.5).astype(np.complex64)
pupil_cube = jnp.repeat(pupil[jnp.newaxis, :, :], n_pupils, axis=0).astype(jnp.complex128)

# construct spatial frequency evaluation points
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (jnp.arange(fov_pix) - fov_pix / 2 + 0.5)
UVs = 2 * jnp.pi * (idx_pix / pupil_npix)
UU, VV = jnp.meshgrid(UVs, UVs)

def map_nufft_over_pupil_cube(pupil):
    psfs = jax.vmap(nufft2, in_axes=[0, None, None])(
        pupil,
        UU.flatten(), 
        VV.flatten()
    ).reshape(n_pupils, fov_pix, fov_pix)
    return jnp.max((psfs**2).real)

@jax.jit
def without_grad(pupil_cube):
    v = map_nufft_over_pupil_cube(pupil_cube)
    return v

@jax.jit
def with_grad(pupil_cube):
    v, g = jax.value_and_grad(map_nufft_over_pupil_cube)(pupil_cube)
    return v, g

for i in range(n_warmup):
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    v = without_grad(pupil_cube)
    v.block_until_ready()

for i in range(n_trials):
    start = time.perf_counter()
    v = without_grad(pupil_cube)
    v.block_until_ready()
    nufft_time = time.perf_counter() - start

    start = time.perf_counter()
    v, g = with_grad(pupil_cube)
    v.block_until_ready()
    nufft_with_grad_time = time.perf_counter() - start

    print(f"{nufft_time=} {nufft_with_grad_time=} {nufft_with_grad_time/nufft_time=}")

@joseph-long joseph-long changed the title Computing gradient of a nufft2 costs 13x more than the nufft2 alone Computing gradient of a nufft2 costs 13x more than the nufft2 alone on GPU Mar 7, 2024
@dfm
Collaborator

dfm commented Mar 7, 2024

Interesting! Maybe there are some memory-transfer or re-ordering issues that are specific to the GPU. I'll try to dig a little.

@dfm
Collaborator

dfm commented Mar 7, 2024

On my workstation with a crappy GPU I get slower runtimes for everything, but the extra factor for value and grad is 4x.

To help diagnose the issues, I started looking at the jaxpr for these operations:
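(For anyone following along, jaxprs like the two below can be produced with jax.make_jaxpr; a minimal sketch, assuming the without_grad and with_grad functions from the previous comment:)

print(jax.make_jaxpr(without_grad)(pupil_cube))
print(jax.make_jaxpr(with_grad)(pupil_cube))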

without grad
{ lambda ; a:c128[50,128,128]. let
    b:f64[] = pjit[
      name=without_grad
      jaxpr={ lambda c:f64[16384] d:f64[16384]; e:c128[50,128,128]. let
          f:c128[50,16384] = pjit[
            name=nufft2
            jaxpr={ lambda ; g:c128[50,128,128] h:f64[16384] i:f64[16384]. let
                j:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] h
                k:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] i
                l:c128[1,50,128,128] = reshape[
                  dimensions=None
                  new_sizes=(1, 50, 128, 128)
                ] g
                m:c128[1,50,16384] = nufft2[
                  eps=1e-06
                  iflag=-1
                  output_shape=None
                ] l j k
                n:c128[50,16384] = reshape[
                  dimensions=None
                  new_sizes=(50, 16384)
                ] m
              in (n,) }
          ] e c d
          o:c128[50,16384] = integer_pow[y=2] f
          p:f64[50,16384] = real o
          q:f64[] = reduce_max[axes=(0, 1)] p
        in (q,) }
    ] a
  in (b,) }
with grad
{ lambda ; a:c128[50,128,128]. let
    b:f64[] c:c128[50,128,128] = pjit[
      name=with_grad
      jaxpr={ lambda d:f64[16384] e:f64[16384]; f:c128[50,128,128]. let
          g:c128[50,16384] h:f64[1,16384] i:f64[1,16384] = pjit[
            name=nufft2
            jaxpr={ lambda ; j:c128[50,128,128] k:f64[16384] l:f64[16384]. let
                m:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] k
                n:f64[1,16384] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 16384)
                ] l
                o:c128[1,50,128,128] = reshape[
                  dimensions=None
                  new_sizes=(1, 50, 128, 128)
                ] j
                p:c128[1,50,16384] = nufft2[
                  eps=1e-06
                  iflag=-1
                  output_shape=None
                ] o m n
                q:c128[50,16384] = reshape[
                  dimensions=None
                  new_sizes=(50, 16384)
                ] p
              in (q, m, n) }
          ] f d e
          r:c128[50,16384] = integer_pow[y=2] g
          s:c128[50,16384] = integer_pow[y=1] g
          t:c128[50,16384] = mul (2+0j) s
          u:f64[50,16384] = real r
          v:f64[] = reduce_max[axes=(0, 1)] u
          w:f64[1,1] = reshape[dimensions=None new_sizes=(1, 1)] v
          x:bool[50,16384] = eq u w
          y:f64[50,16384] = convert_element_type[
            new_dtype=float64
            weak_type=False
          ] x
          z:f64[] = reduce_sum[axes=(0, 1)] y
          ba:f64[] = div 1.0 z
          bb:f64[50,16384] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(50, 16384)
          ] ba
          bc:f64[50,16384] = mul bb y
          bd:c128[50,16384] = complex bc 0.0
          be:c128[50,16384] = mul bd t
          bf:c128[50,128,128] = pjit[
            name=nufft2
            jaxpr={ lambda ; bg:f64[1,16384] bh:f64[1,16384] bi:c128[50,16384]. let
                bj:c128[1,50,16384] = reshape[
                  dimensions=None
                  new_sizes=(1, 50, 16384)
                ] bi
                bk:c128[1,50,128,128] = pjit[
                  name=nufft1
                  jaxpr={ lambda ; bl:c128[1,50,16384] bm:f64[1,16384] bn:f64[1,16384]. let
                      bo:f64[1,1,16384] = broadcast_in_dim[
                        broadcast_dimensions=(0, 2)
                        shape=(1, 1, 16384)
                      ] bm
                      bp:f64[1,1,16384] = broadcast_in_dim[
                        broadcast_dimensions=(0, 2)
                        shape=(1, 1, 16384)
                      ] bn
                      bq:f64[1,16384] = reshape[
                        dimensions=None
                        new_sizes=(1, 16384)
                      ] bo
                      br:f64[1,16384] = reshape[
                        dimensions=None
                        new_sizes=(1, 16384)
                      ] bp
                      bs:c128[1,50,128,128] = nufft1[
                        eps=1e-06
                        iflag=-1
                        output_shape=[128 128]
                      ] bl bq br
                    in (bs,) }
                ] bj bg bh
                bt:c128[50,128,128] = reshape[
                  dimensions=None
                  new_sizes=(50, 128, 128)
                ] bk
              in (bt,) }
          ] h i be
        in (v, bf) }
    ] a
  in (b, c) }

Some notes about this. As expected, computing the value requires one nufft2:

                      bc:c128[1,50,16384] = nufft2[
                        eps=1e-06
                        iflag=-1
                        output_shape=None
                      ] z ba bb

and computing the value and grad requires the same nufft2, and a nufft1 with the following signature:

                            cu:c128[1,50,128,128] = nufft1[
                              eps=1e-06
                              iflag=-1
                              output_shape=[128 128]
                            ] cn cs ct

which seems about right.
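
(Equivalently, the type-1 transform shows up if you pull a cotangent back through the type-2 transform with jax.vjp; a minimal sketch, reusing the arrays from the scripts above:)

# The VJP of the vmapped nufft2 is implemented with a nufft1,
# which is exactly the call that appears in the "with grad" jaxpr.
f = lambda c: jax.vmap(nufft2, in_axes=[0, None, None])(c, UU.flatten(), VV.flatten())
out, vjp_fn = jax.vjp(f, pupil_cube)
print(jax.make_jaxpr(vjp_fn)(jnp.ones_like(out)))  # contains a nufft1 primitive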

That being said, there's a heck of a lot of reshaping, transposing, etc. that might be leading to issues!

Edited with slightly simpler jaxprs.

@dfm
Collaborator

dfm commented Mar 7, 2024

I was able to further isolate the issue by adding the following benchmark:

from jax_finufft import nufft2, nufft1

@jax.jit
def func(x):
    return nufft2(x, UU, VV)

result = jax.block_until_ready(func(pupil_cube))
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(func(pupil_cube))
print(time.perf_counter() - start)

@jax.jit
def func(x):
    return nufft1((pupil_npix, pupil_npix), x, UU, VV)

jax.block_until_ready(func(result))
start = time.perf_counter()
for i in range(n_trials):
    jax.block_until_ready(func(result))
print(time.perf_counter() - start)

The first number printed gives the cost of doing a single nufft2, and the second gives the cost of computing the nufft1 which is required to backprop through the initial nufft2. On some GPUs the nufft1 seems to be a factor of ~20 slower than the nufft2. We're not sure if that should be expected!

@lgarrison
Member

I've been looking into this, and what it seems to come down to is that the performance parameters just need to be tuned for the nufft1. Setting gpu_method=1 (NU-points driven instead of shared-memory driven) makes almost the whole discrepancy go away (from 20x to 1.3x). It's possible that other tuning parameters should be changed, too, but I think this means that we should implement #66 sooner rather than later!

Another question is whether cufinufft should be able to choose better tuning parameters automatically. I don't know the answer to this, we could ask in the finufft GH.

Here's the script I was using to play around with this, using the native cufinufft interface and not jax-finufft:

Script
import time

import cufinufft
import cupy as cp

pupil_npix = 128
fov_npix = 128
eps = 1e-6  # 1e-5 is 5x faster with method 2, and 3x slower with method 1
plan_kwargs = {}  # dict(gpu_method=1)  # method 1 for type 1 is 15-20x faster

pupil = cp.ones((pupil_npix, pupil_npix), dtype=cp.complex128)

# this affects the filling of the UV plane, also affecting performance
delta_lamd_per_pix = 0.5
idx_pix = delta_lamd_per_pix * (cp.arange(fov_npix) - fov_npix / 2 + 0.5)
UVs = 2 * cp.pi * (idx_pix / pupil_npix)

UU, VV = cp.meshgrid(UVs, UVs)
UU = UU.flatten()
VV = VV.flatten()

plan2 = cufinufft.Plan(2, pupil.shape, n_trans=1, eps=eps, dtype='complex128', **plan_kwargs)
plan2.setpts(UU, VV)
result2 = plan2.execute(pupil)

t = -time.perf_counter()
for i in range(nrep:=1000):
    plan2.execute(pupil)
t += time.perf_counter()
print(f"nufft2d2: {t:.4g} s")

plan1 = cufinufft.Plan(1, pupil.shape, n_trans=1, eps=eps, dtype='complex128', **plan_kwargs)
plan1.setpts(UU, VV)
result1 = plan1.execute(result2)

t = -time.perf_counter()
for i in range(nrep):
    plan1.execute(result2)
t += time.perf_counter()
print(f"nufft2d1: {t:.4g} s")

@dfm
Collaborator

dfm commented Mar 11, 2024

Interesting - thanks for looking into this, @lgarrison! I'll try to take a first stab at #66 (and think about how that'll play with differentiation, a point that didn't come up over there yet!) later this week.

@lgarrison
Member

and think about how that'll play with differentiation

Oh, right; if nufft2 prefers gpu_method=2 (which it seems to), and nufft1 prefers gpu_method=1, then we need a place to set the tuning parameters for each transform type separately. Tricky!
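
(With the native cufinufft interface that just means building the two plans with different options, e.g. something along these lines, reusing the setup from the script in my earlier comment:)

# Tune each transform type separately:
# type 2 (the forward nufft2) keeps the shared-memory method,
# type 1 (the transpose used in the gradient) uses the NU-points-driven method.
plan2 = cufinufft.Plan(2, pupil.shape, eps=eps, dtype='complex128', gpu_method=2)
plan1 = cufinufft.Plan(1, pupil.shape, eps=eps, dtype='complex128', gpu_method=1)
plan2.setpts(UU, VV)
plan1.setpts(UU, VV)

What the equivalent knobs should look like in jax-finufft is the #66 question.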

@dfm dfm linked a pull request Mar 15, 2024 that will close this issue
@dfm dfm closed this as completed in #68 Mar 15, 2024