Skip to content

Commit 9506c28

Browse files
authored
Better initialization for wave and shallow water benchmarks (#19)
1 parent 4520d97 commit 9506c28

File tree

2 files changed

+64
-47
lines changed

2 files changed

+64
-47
lines changed

examples/shallow_water.py

+37-34
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,26 @@ def run(n, backend, datatype, benchmark_mode):
5454
if backend == "sharpy":
5555
import sharpy as np
5656
from sharpy import fini, init, sync
57-
from sharpy.numpy import fromfunction as _fromfunction
5857

5958
device = os.getenv("SHARPY_DEVICE", "")
6059
create_full = partial(np.full, device=device)
61-
fromfunction = partial(_fromfunction, device=device)
60+
61+
def transpose(a):
"""Return ``a`` with its two axes swapped, using ``np.permute_dims``."""
62+
return np.permute_dims(a, [1, 0])
6263

6364
all_axes = [0, 1]
6465
init(False)
6566

6667
elif backend == "numpy":
6768
import numpy as np
68-
from numpy import fromfunction
6969

7070
if comm is not None:
7171
assert (
7272
comm.Get_size() == 1
7373
), "Numpy backend only supports serial execution."
7474

7575
create_full = np.full
76+
transpose = np.transpose
7677

7778
fini = sync = lambda x=None: None
7879
all_axes = None
@@ -110,34 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
110111
t_export = 0.02
111112
t_end = 1.0
112113

113-
# coordinate arrays
114-
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2,
116-
(nx, ny),
117-
dtype=dtype,
118-
)
119-
y_t_2d = fromfunction(
120-
lambda i, j: ymin + j * dy + dy / 2,
121-
(nx, ny),
122-
dtype=dtype,
123-
)
124-
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype)
125-
y_u_2d = fromfunction(
126-
lambda i, j: ymin + j * dy + dy / 2,
127-
(nx + 1, ny),
128-
dtype=dtype,
129-
)
130-
x_v_2d = fromfunction(
131-
lambda i, j: xmin + i * dx + dx / 2,
132-
(nx, ny + 1),
133-
dtype=dtype,
134-
)
135-
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype)
114+
def ind_arr(shape, columns=False):
115+
"""Return an index array of the given 2D ``shape``, cast to ``dtype``.

With ``columns=True`` element ``[i, j]`` holds the first-axis index ``i``;
otherwise it holds the second-axis index ``j``. The array is built from a
flat ``arange`` reduced with modulo and then reshaped. ``np``, ``transpose``
and ``dtype`` are taken from the enclosing scope.

NOTE(review): the flat index is created as int32, so ``nx * ny`` presumably
has to stay below 2**31 -- confirm for very large grids.
"""
116+
nx, ny = shape
117+
if columns:
118+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % nx
119+
ind = transpose(np.reshape(ind, (ny, nx)))
120+
else:
121+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % ny
122+
ind = np.reshape(ind, (nx, ny))
123+
return ind.astype(dtype)
136124

125+
# coordinate arrays
137126
T_shape = (nx, ny)
138127
U_shape = (nx + 1, ny)
139128
V_shape = (nx, ny + 1)
140129
F_shape = (nx + 1, ny + 1)
130+
sync()
131+
x_t_2d = xmin + ind_arr(T_shape, True) * dx + dx / 2
132+
y_t_2d = ymin + ind_arr(T_shape) * dy + dy / 2
133+
134+
x_u_2d = xmin + ind_arr(U_shape, True) * dx
135+
y_u_2d = ymin + ind_arr(U_shape) * dy + dy / 2
136+
137+
x_v_2d = xmin + ind_arr(V_shape, True) * dx + dx / 2
138+
y_v_2d = ymin + ind_arr(V_shape) * dy
139+
sync()
141140

142141
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
143142
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))
@@ -205,14 +204,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205204
bath = 1.0
206205
return bath * create_full(T_shape, 1.0, dtype)
207206

208-
# initial elevation
209-
u0, v0, e0 = exact_solution(
210-
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
211-
)
212-
e[:, :] = e0
213-
u[:, :] = u0
214-
v[:, :] = v0
215-
216207
# set bathymetry
217208
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
218209
# steady state potential energy
@@ -329,6 +320,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
329320
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
330321
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)
331322

323+
# warm up jit cache
324+
step(u, v, e, u1, v1, e1, u2, v2, e2)
325+
sync()
326+
327+
# initial solution
328+
u0, v0, e0 = exact_solution(
329+
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
330+
)
331+
e[:, :] = e0
332+
u[:, :] = u0
333+
v[:, :] = v0
334+
332335
t = 0
333336
i_export = 0
334337
next_t_export = 0

examples/wave_equation.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,26 @@ def run(n, backend, datatype, benchmark_mode):
5454
if backend == "sharpy":
5555
import sharpy as np
5656
from sharpy import fini, init, sync
57-
from sharpy.numpy import fromfunction as _fromfunction
5857

5958
device = os.getenv("SHARPY_DEVICE", "")
6059
create_full = partial(np.full, device=device)
61-
fromfunction = partial(_fromfunction, device=device)
60+
61+
def transpose(a):
"""Return ``a`` with its two axes swapped, using ``np.permute_dims``."""
62+
return np.permute_dims(a, [1, 0])
6263

6364
all_axes = [0, 1]
6465
init(False)
6566

6667
elif backend == "numpy":
6768
import numpy as np
68-
from numpy import fromfunction
6969

7070
if comm is not None:
7171
assert (
7272
comm.Get_size() == 1
7373
), "Numpy backend only supports serial execution."
7474

7575
create_full = np.full
76+
transpose = np.transpose
7677

7778
fini = sync = lambda x=None: None
7879
all_axes = None
@@ -110,17 +111,23 @@ def run(n, backend, datatype, benchmark_mode):
110111
t_export = 0.02
111112
t_end = 1.0
112113

113-
# coordinate arrays
114-
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype
116-
)
117-
y_t_2d = fromfunction(
118-
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype
119-
)
114+
def ind_arr(shape, columns=False):
115+
"""Return an index array of the given 2D ``shape``, cast to ``dtype``.

With ``columns=True`` element ``[i, j]`` holds the first-axis index ``i``;
otherwise it holds the second-axis index ``j``. The array is built from a
flat ``arange`` reduced with modulo and then reshaped. ``np``, ``transpose``
and ``dtype`` are taken from the enclosing scope.

NOTE(review): the flat index is created as int32, so ``nx * ny`` presumably
has to stay below 2**31 -- confirm for very large grids.
"""
116+
nx, ny = shape
117+
if columns:
118+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % nx
119+
ind = transpose(np.reshape(ind, (ny, nx)))
120+
else:
121+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % ny
122+
ind = np.reshape(ind, (nx, ny))
123+
return ind.astype(dtype)
120124

125+
# coordinate arrays
121126
T_shape = (nx, ny)
122127
U_shape = (nx + 1, ny)
123128
V_shape = (nx, ny + 1)
129+
x_t_2d = xmin + ind_arr(T_shape, True) * dx + dx / 2
130+
y_t_2d = ymin + ind_arr(T_shape) * dy + dy / 2
124131

125132
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
126133
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))
@@ -162,9 +169,6 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
162169
sol_t = numpy.cos(2 * omega * t)
163170
return amp * sol_x * sol_y * sol_t
164171

165-
# initial elevation
166-
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
167-
168172
# compute time step
169173
alpha = 0.5
170174
c = (g * h) ** 0.5
@@ -215,6 +219,16 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
215219
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
216220
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)
217221

222+
# warm up jit cache
223+
step(u, v, e, u1, v1, e1, u2, v2, e2)
224+
sync()
225+
226+
# initial solution
227+
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
228+
u[:, :] = create_full(U_shape, 0.0, dtype)
229+
v[:, :] = create_full(V_shape, 0.0, dtype)
230+
sync()
231+
218232
t = 0
219233
i_export = 0
220234
next_t_export = 0

0 commit comments

Comments
 (0)