@@ -54,25 +54,26 @@ def run(n, backend, datatype, benchmark_mode):
54
54
if backend == "sharpy" :
55
55
import sharpy as np
56
56
from sharpy import fini , init , sync
57
- from sharpy .numpy import fromfunction as _fromfunction
58
57
59
58
device = os .getenv ("SHARPY_DEVICE" , "" )
60
59
create_full = partial (np .full , device = device )
61
- fromfunction = partial (_fromfunction , device = device )
60
+
61
+ def transpose (a ):
62
+ return np .permute_dims (a , [1 , 0 ])
62
63
63
64
all_axes = [0 , 1 ]
64
65
init (False )
65
66
66
67
elif backend == "numpy" :
67
68
import numpy as np
68
- from numpy import fromfunction
69
69
70
70
if comm is not None :
71
71
assert (
72
72
comm .Get_size () == 1
73
73
), "Numpy backend only supports serial execution."
74
74
75
75
create_full = np .full
76
+ transpose = np .transpose
76
77
77
78
fini = sync = lambda x = None : None
78
79
all_axes = None
@@ -110,34 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
110
111
t_export = 0.02
111
112
t_end = 1.0
112
113
113
- # coordinate arrays
114
- x_t_2d = fromfunction (
115
- lambda i , j : xmin + i * dx + dx / 2 ,
116
- (nx , ny ),
117
- dtype = dtype ,
118
- )
119
- y_t_2d = fromfunction (
120
- lambda i , j : ymin + j * dy + dy / 2 ,
121
- (nx , ny ),
122
- dtype = dtype ,
123
- )
124
- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
125
- y_u_2d = fromfunction (
126
- lambda i , j : ymin + j * dy + dy / 2 ,
127
- (nx + 1 , ny ),
128
- dtype = dtype ,
129
- )
130
- x_v_2d = fromfunction (
131
- lambda i , j : xmin + i * dx + dx / 2 ,
132
- (nx , ny + 1 ),
133
- dtype = dtype ,
134
- )
135
- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
114
+ def ind_arr (shape , columns = False ):
115
+ """Construct an (nx, ny) array where each row/col is an arange"""
116
+ nx , ny = shape
117
+ if columns :
118
+ ind = np .arange (0 , nx * ny , 1 , dtype = np .int32 ) % nx
119
+ ind = transpose (np .reshape (ind , (ny , nx )))
120
+ else :
121
+ ind = np .arange (0 , nx * ny , 1 , dtype = np .int32 ) % ny
122
+ ind = np .reshape (ind , (nx , ny ))
123
+ return ind .astype (dtype )
136
124
125
+ # coordinate arrays
137
126
T_shape = (nx , ny )
138
127
U_shape = (nx + 1 , ny )
139
128
V_shape = (nx , ny + 1 )
140
129
F_shape = (nx + 1 , ny + 1 )
130
+ sync ()
131
+ x_t_2d = xmin + ind_arr (T_shape , True ) * dx + dx / 2
132
+ y_t_2d = ymin + ind_arr (T_shape ) * dy + dy / 2
133
+
134
+ x_u_2d = xmin + ind_arr (U_shape , True ) * dx
135
+ y_u_2d = ymin + ind_arr (U_shape ) * dy + dy / 2
136
+
137
+ x_v_2d = xmin + ind_arr (V_shape , True ) * dx + dx / 2
138
+ y_v_2d = ymin + ind_arr (V_shape ) * dy
139
+ sync ()
141
140
142
141
dofs_T = int (numpy .prod (numpy .asarray (T_shape )))
143
142
dofs_U = int (numpy .prod (numpy .asarray (U_shape )))
@@ -205,14 +204,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205
204
bath = 1.0
206
205
return bath * create_full (T_shape , 1.0 , dtype )
207
206
208
- # inital elevation
209
- u0 , v0 , e0 = exact_solution (
210
- 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
211
- )
212
- e [:, :] = e0
213
- u [:, :] = u0
214
- v [:, :] = v0
215
-
216
207
# set bathymetry
217
208
h [:, :] = bathymetry (x_t_2d , y_t_2d , lx , ly )
218
209
# steady state potential energy
@@ -329,6 +320,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
329
320
v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
330
321
e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
331
322
323
+ # warm up jit cache
324
+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
325
+ sync ()
326
+
327
+ # initial solution
328
+ u0 , v0 , e0 = exact_solution (
329
+ 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
330
+ )
331
+ e [:, :] = e0
332
+ u [:, :] = u0
333
+ v [:, :] = v0
334
+
332
335
t = 0
333
336
i_export = 0
334
337
next_t_export = 0
0 commit comments