Skip to content

Commit 9987123

Browse files
committed
shallow water: reductions and L2 error computation run on GPU
1 parent b1a7c3f commit 9987123

File tree

1 file changed

+70
-79
lines changed

1 file changed

+70
-79
lines changed

examples/shallow_water.py

+70-79
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def ind_arr(shape, columns=False):
156156
q = create_full(F_shape, 0.0, dtype)
157157

158158
# bathymetry
159-
h = create_full(T_shape, 1.0, dtype) # HACK init with 1
159+
h = create_full(T_shape, 0.0, dtype)
160160

161161
hu = create_full(U_shape, 0.0, dtype)
162162
hv = create_full(V_shape, 0.0, dtype)
@@ -165,7 +165,7 @@ def ind_arr(shape, columns=False):
165165
dvdx = create_full(F_shape, 0.0, dtype)
166166

167167
# vector invariant form
168-
H_at_f = create_full(F_shape, 0.0, dtype)
168+
H_at_f = create_full(F_shape, 1.0, dtype) # HACK init with 1
169169

170170
# auxiliary variables for RK time integration
171171
e1 = create_full(T_shape, 0.0, dtype)
@@ -205,15 +205,14 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205205
return bath * create_full(T_shape, 1.0, dtype)
206206

207207
# set bathymetry
208-
# h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
208+
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
209209
# steady state potential energy
210-
# pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
211-
pe_offset = 0.5 * g * float(1.0) / nx / ny
210+
h2sum = np.sum(h**2.0, all_axes).to_device()
211+
pe_offset = 0.5 * g * float(np.sum(h2sum, all_axes)) / nx / ny
212212

213213
# compute time step
214214
alpha = 0.5
215-
# h_max = float(np.max(h, all_axes))
216-
h_max = float(1.0)
215+
h_max = float(np.max(h, all_axes).to_device())
217216
c = (g * h_max) ** 0.5
218217
dt = alpha * dx / c
219218
dt = t_export / int(math.ceil(t_export / dt))
@@ -253,10 +252,11 @@ def rhs(u, v, e):
253252
H_at_f[-1, 1:-1] = 0.5 * (H[-1, 1:] + H[-1, :-1])
254253
H_at_f[1:-1, 0] = 0.5 * (H[1:, 0] + H[:-1, 0])
255254
H_at_f[1:-1, -1] = 0.5 * (H[1:, -1] + H[:-1, -1])
256-
H_at_f[0, 0] = H[0, 0]
257-
H_at_f[0, -1] = H[0, -1]
258-
H_at_f[-1, 0] = H[-1, 0]
259-
H_at_f[-1, -1] = H[-1, -1]
255+
# NOTE causes gpu.memcpy error, non-identity layout
256+
# H_at_f[0, 0] = H[0, 0]
257+
# H_at_f[0, -1] = H[0, -1]
258+
# H_at_f[-1, 0] = H[-1, 0]
259+
# H_at_f[-1, -1] = H[-1, -1]
260260

261261
# potential vorticity
262262
dudy[:, 1:-1] = (u[:, 1:] - u[:, :-1]) / dy
@@ -346,41 +346,36 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
346346
t = i * dt
347347

348348
if t >= next_t_export - 1e-8:
349-
if device:
350-
# FIXME gpu.memcpy to host requires identity layout
351-
# FIXME reduction on gpu
352-
elev_max = 0
353-
u_max = 0
354-
q_max = 0
355-
diff_e = 0
356-
diff_v = 0
357-
total_pe = 0
358-
total_ke = 0
359-
else:
360-
_elev_max = np.max(e, all_axes)
361-
_u_max = np.max(u, all_axes)
362-
_q_max = np.max(q, all_axes)
363-
_total_v = np.sum(e + h, all_axes)
364-
365-
# potential energy
366-
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
367-
_total_pe = np.sum(_pe, all_axes)
368-
369-
# kinetic energy
370-
u2 = u * u
371-
v2 = v * v
372-
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
373-
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
374-
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
375-
_total_ke = np.sum(_ke, all_axes)
376-
377-
total_pe = float(_total_pe) * dx * dy
378-
total_ke = float(_total_ke) * dx * dy
379-
total_e = total_ke + total_pe
380-
elev_max = float(_elev_max)
381-
u_max = float(_u_max)
382-
q_max = float(_q_max)
383-
total_v = float(_total_v) * dx * dy
349+
sync()
350+
# NOTE must precompute reduction operands to single field
351+
H_tmp = e + h
352+
# potential energy
353+
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
354+
# kinetic energy
355+
u2 = u * u
356+
v2 = v * v
357+
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
358+
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
359+
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
360+
sync()
361+
_elev_max = np.max(e, all_axes).to_device()
362+
# NOTE max(u) segfaults, shape (n+1, n) too large for tiling
363+
_u_max = np.max(u[1:, :], all_axes).to_device()
364+
_q_max = np.max(q[1:, 1:], all_axes).to_device()
365+
_total_v = np.sum(H_tmp, all_axes).to_device()
366+
_total_pe = np.sum(_pe, all_axes).to_device()
367+
_total_ke = np.sum(_ke, all_axes).to_device()
368+
369+
total_pe = float(_total_pe) * dx * dy
370+
total_ke = float(_total_ke) * dx * dy
371+
total_e = total_ke + total_pe
372+
elev_max = float(_elev_max)
373+
u_max = float(_u_max)
374+
q_max = float(_q_max)
375+
total_v = float(_total_v) * dx * dy
376+
377+
diff_e = 0
378+
diff_v = 0
384379

385380
if i_export == 0:
386381
initial_v = total_v
@@ -415,40 +410,36 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
415410
duration = time_mod.perf_counter() - tic
416411
info(f"Duration: {duration:.2f} s")
417412

418-
if device:
419-
# FIXME gpu.memcpy to host requires identity layout
420-
# FIXME reduction on gpu
421-
pass
422-
else:
423-
e_exact = exact_solution(
424-
t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
425-
)[2]
426-
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
427-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
428-
info(f"L2 error: {err_L2:7.15e}")
429-
430-
if nx < 128 or ny < 128:
431-
info("Skipping correctness test due to small problem size.")
432-
elif not benchmark_mode:
433-
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
434-
assert (
435-
diff_e < tolerance_ene
436-
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
437-
if nx == 128 and ny == 128:
438-
if datatype == "f32":
439-
assert numpy.allclose(
440-
err_L2, 4.3127859e-05, rtol=1e-5
441-
), "L2 error does not match"
442-
else:
443-
assert numpy.allclose(
444-
err_L2, 4.315799035627906e-05
445-
), "L2 error does not match"
413+
e_exact = exact_solution(t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d)[
414+
2
415+
].to_device(device)
416+
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
417+
err2sum = np.sum(err2, all_axes).to_device()
418+
err_L2 = math.sqrt(float(err2sum))
419+
info(f"L2 error: {err_L2:7.15e}")
420+
421+
if nx < 128 or ny < 128:
422+
info("Skipping correctness test due to small problem size.")
423+
elif not benchmark_mode:
424+
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
425+
assert (
426+
diff_e < tolerance_ene
427+
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
428+
if nx == 128 and ny == 128:
429+
if datatype == "f32":
430+
assert numpy.allclose(
431+
err_L2, 4.3127859e-05, rtol=1e-5
432+
), "L2 error does not match"
446433
else:
447-
tolerance_l2 = 1e-4
448-
assert (
449-
err_L2 < tolerance_l2
450-
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
451-
info("SUCCESS")
434+
assert numpy.allclose(
435+
err_L2, 4.315799035627906e-05
436+
), "L2 error does not match"
437+
else:
438+
tolerance_l2 = 1e-4
439+
assert (
440+
err_L2 < tolerance_l2
441+
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
442+
info("SUCCESS")
452443

453444
fini()
454445

0 commit comments

Comments
 (0)