@@ -156,7 +156,7 @@ def ind_arr(shape, columns=False):
     q = create_full(F_shape, 0.0, dtype)
 
     # bathymetry
-    h = create_full(T_shape, 1.0, dtype)  # HACK init with 1
+    h = create_full(T_shape, 0.0, dtype)
 
     hu = create_full(U_shape, 0.0, dtype)
     hv = create_full(V_shape, 0.0, dtype)
@@ -165,7 +165,7 @@ def ind_arr(shape, columns=False):
     dvdx = create_full(F_shape, 0.0, dtype)
 
     # vector invariant form
-    H_at_f = create_full(F_shape, 0.0, dtype)
+    H_at_f = create_full(F_shape, 1.0, dtype)  # HACK init with 1
 
     # auxiliary variables for RK time integration
     e1 = create_full(T_shape, 0.0, dtype)
@@ -205,15 +205,14 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
         return bath * create_full(T_shape, 1.0, dtype)
 
     # set bathymetry
-    # h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
+    h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
     # steady state potential energy
-    # pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
-    pe_offset = 0.5 * g * float(1.0) / nx / ny
+    h2sum = np.sum(h**2.0, all_axes).to_device()
+    pe_offset = 0.5 * g * float(np.sum(h2sum, all_axes)) / nx / ny
 
     # compute time step
     alpha = 0.5
-    # h_max = float(np.max(h, all_axes))
-    h_max = float(1.0)
+    h_max = float(np.max(h, all_axes).to_device())
     c = (g * h_max) ** 0.5
     dt = alpha * dx / c
     dt = t_export / int(math.ceil(t_export / dt))
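
For orientation, the reduce-then-copy pattern introduced above (reduce on the device with np.max/np.sum over all_axes, move the scalar to the host with .to_device(), then convert with float()) mirrors the ordinary NumPy computation of the wave-speed-limited time step. A minimal NumPy-only sketch, with illustrative grid values standing in for the example's nx, dx, t_export, and bathymetry:

import math
import numpy

g = 9.81                           # gravity, as in the example
dx = 2.0 / 128                     # illustrative grid spacing
t_export = 0.02                    # illustrative export interval
h = numpy.full((128, 128), 1.0)    # stand-in bathymetry

alpha = 0.5
h_max = float(numpy.max(h))        # deepest water column
c = (g * h_max) ** 0.5             # shallow-water gravity wave speed
dt = alpha * dx / c                # CFL-limited time step
dt = t_export / int(math.ceil(t_export / dt))  # shrink dt so t_export is a whole number of steps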
@@ -253,10 +252,11 @@ def rhs(u, v, e):
         H_at_f[-1, 1:-1] = 0.5 * (H[-1, 1:] + H[-1, :-1])
         H_at_f[1:-1, 0] = 0.5 * (H[1:, 0] + H[:-1, 0])
         H_at_f[1:-1, -1] = 0.5 * (H[1:, -1] + H[:-1, -1])
-        H_at_f[0, 0] = H[0, 0]
-        H_at_f[0, -1] = H[0, -1]
-        H_at_f[-1, 0] = H[-1, 0]
-        H_at_f[-1, -1] = H[-1, -1]
+        # NOTE causes gpu.memcpy error, non-identity layout
+        # H_at_f[0, 0] = H[0, 0]
+        # H_at_f[0, -1] = H[0, -1]
+        # H_at_f[-1, 0] = H[-1, 0]
+        # H_at_f[-1, -1] = H[-1, -1]
 
         # potential vorticity
         dudy[:, 1:-1] = (u[:, 1:] - u[:, :-1]) / dy
@@ -346,41 +346,36 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
         t = i * dt
 
         if t >= next_t_export - 1e-8:
-            if device:
-                # FIXME gpu.memcpy to host requires identity layout
-                # FIXME reduction on gpu
-                elev_max = 0
-                u_max = 0
-                q_max = 0
-                diff_e = 0
-                diff_v = 0
-                total_pe = 0
-                total_ke = 0
-            else:
-                _elev_max = np.max(e, all_axes)
-                _u_max = np.max(u, all_axes)
-                _q_max = np.max(q, all_axes)
-                _total_v = np.sum(e + h, all_axes)
-
-                # potential energy
-                _pe = 0.5 * g * (e + h) * (e - h) + pe_offset
-                _total_pe = np.sum(_pe, all_axes)
-
-                # kinetic energy
-                u2 = u * u
-                v2 = v * v
-                u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
-                v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
-                _ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
-                _total_ke = np.sum(_ke, all_axes)
-
-                total_pe = float(_total_pe) * dx * dy
-                total_ke = float(_total_ke) * dx * dy
-                total_e = total_ke + total_pe
-                elev_max = float(_elev_max)
-                u_max = float(_u_max)
-                q_max = float(_q_max)
-                total_v = float(_total_v) * dx * dy
+            sync()
+            # NOTE must precompute reduction operands to single field
+            H_tmp = e + h
+            # potential energy
+            _pe = 0.5 * g * (e + h) * (e - h) + pe_offset
+            # kinetic energy
+            u2 = u * u
+            v2 = v * v
+            u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
+            v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
+            _ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
+            sync()
+            _elev_max = np.max(e, all_axes).to_device()
+            # NOTE max(u) segfaults, shape (n+1, n) too large for tiling
+            _u_max = np.max(u[1:, :], all_axes).to_device()
+            _q_max = np.max(q[1:, 1:], all_axes).to_device()
+            _total_v = np.sum(H_tmp, all_axes).to_device()
+            _total_pe = np.sum(_pe, all_axes).to_device()
+            _total_ke = np.sum(_ke, all_axes).to_device()
+
+            total_pe = float(_total_pe) * dx * dy
+            total_ke = float(_total_ke) * dx * dy
+            total_e = total_ke + total_pe
+            elev_max = float(_elev_max)
+            u_max = float(_u_max)
+            q_max = float(_q_max)
+            total_v = float(_total_v) * dx * dy
+
+            diff_e = 0
+            diff_v = 0
 
             if i_export == 0:
                 initial_v = total_v
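
The diagnostics above rely on the staggered grid used throughout this example: u is stored on (nx + 1, ny) faces (hence the tiling NOTE) and v on (nx, ny + 1) faces, so their squares are averaged back to the (nx, ny) cell centres before the kinetic energy is summed. A NumPy-only sketch of the same energy reduction, with illustrative shapes and values:

import numpy

nx = ny = 4                    # illustrative grid size
dx = dy = 0.5
g = 9.81
e = numpy.zeros((nx, ny))      # elevation at cell centres
h = numpy.ones((nx, ny))       # bathymetry at cell centres
u = numpy.ones((nx + 1, ny))   # x-velocity on cell faces
v = numpy.ones((nx, ny + 1))   # y-velocity on cell faces
pe_offset = 0.5 * g * float(numpy.sum(h**2.0)) / nx / ny

# potential energy relative to the steady state
pe = 0.5 * g * (e + h) * (e - h) + pe_offset
# average squared face velocities to cell centres, then form kinetic energy
u2_at_t = 0.5 * (u[1:, :] ** 2 + u[:-1, :] ** 2)
v2_at_t = 0.5 * (v[:, 1:] ** 2 + v[:, :-1] ** 2)
ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)

total_pe = float(numpy.sum(pe)) * dx * dy
total_ke = float(numpy.sum(ke)) * dx * dy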
@@ -415,40 +410,36 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
     duration = time_mod.perf_counter() - tic
     info(f"Duration: {duration:.2f} s")
 
-    if device:
-        # FIXME gpu.memcpy to host requires identity layout
-        # FIXME reduction on gpu
-        pass
-    else:
-        e_exact = exact_solution(
-            t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
-        )[2]
-        err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
-        err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
-        info(f"L2 error: {err_L2:7.15e}")
-
-        if nx < 128 or ny < 128:
-            info("Skipping correctness test due to small problem size.")
-        elif not benchmark_mode:
-            tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
-            assert (
-                diff_e < tolerance_ene
-            ), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
-            if nx == 128 and ny == 128:
-                if datatype == "f32":
-                    assert numpy.allclose(
-                        err_L2, 4.3127859e-05, rtol=1e-5
-                    ), "L2 error does not match"
-                else:
-                    assert numpy.allclose(
-                        err_L2, 4.315799035627906e-05
-                    ), "L2 error does not match"
+    e_exact = exact_solution(t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d)[
+        2
+    ].to_device(device)
+    err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
+    err2sum = np.sum(err2, all_axes).to_device()
+    err_L2 = math.sqrt(float(err2sum))
+    info(f"L2 error: {err_L2:7.15e}")
+
+    if nx < 128 or ny < 128:
+        info("Skipping correctness test due to small problem size.")
+    elif not benchmark_mode:
+        tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
+        assert (
+            diff_e < tolerance_ene
+        ), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
+        if nx == 128 and ny == 128:
+            if datatype == "f32":
+                assert numpy.allclose(
+                    err_L2, 4.3127859e-05, rtol=1e-5
+                ), "L2 error does not match"
             else:
-                tolerance_l2 = 1e-4
-                assert (
-                    err_L2 < tolerance_l2
-                ), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
-            info("SUCCESS")
+                assert numpy.allclose(
+                    err_L2, 4.315799035627906e-05
+                ), "L2 error does not match"
+        else:
+            tolerance_l2 = 1e-4
+            assert (
+                err_L2 < tolerance_l2
+            ), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
+        info("SUCCESS")
 
     fini()
 