Skip to content

Commit 186f13e

Browse files
committed
Consistent serial and threading perf
1 parent 403bb08 commit 186f13e

File tree

2 files changed

+59
-67
lines changed

2 files changed

+59
-67
lines changed

src/pusher.jl

+57-65
Original file line numberDiff line numberDiff line change
@@ -124,90 +124,82 @@ function solve(prob::TraceProblem, ensemblealg::BasicEnsembleAlgorithm=EnsembleS
124124
trajectories::Int=1, savestepinterval::Int=1,
125125
isoutofdomain::Function=ODE_DEFAULT_ISOUTOFDOMAIN)
126126

127-
sols = __solve(ensemblealg, prob, trajectories, savestepinterval, isoutofdomain)
127+
sols = _solve(ensemblealg, prob, trajectories, savestepinterval, isoutofdomain)
128128
end
129129

130-
function __solve(::EnsembleSerial, prob, trajectories, savestepinterval, isoutofdomain)
131-
sols = Vector{TraceSolution}(undef, trajectories)
132-
# prepare advancing
133-
(; tspan, dt, p) = prob
134-
ttotal = tspan[2] - tspan[1]
135-
nt = round(Int, ttotal / dt)
136-
nout = nt ÷ savestepinterval + 1
137-
traj = zeros(eltype(prob.u0), 6, nout)
138-
139-
paramBoris = BorisMethod()
140-
xv = similar(prob.u0)
141-
nout = nt ÷ savestepinterval + 1
130+
function _solve(::EnsembleSerial, prob, trajectories, savestepinterval, isoutofdomain)
131+
sols, ttotal, nt, nout = _prepare(prob, trajectories, savestepinterval)
142132

143-
for i in 1:trajectories
144-
_boris!(sols, prob, i, xv, paramBoris, p, dt, savestepinterval, nout, ttotal, nt,
145-
traj, tspan, isoutofdomain)
146-
end
133+
_boris!(sols, prob, 1:trajectories, savestepinterval, ttotal, nt, nout, isoutofdomain)
147134

148135
sols
149136
end
150137

151-
function __solve(::EnsembleThreads, prob, trajectories, savestepinterval, isoutofdomain)
152-
sols = Vector{TraceSolution}(undef, trajectories)
153-
# prepare advancing
154-
(; tspan, dt, p) = prob
155-
ttotal = tspan[2] - tspan[1]
156-
nt = round(Int, ttotal / dt)
157-
nout = nt ÷ savestepinterval + 1
138+
function _solve(::EnsembleThreads, prob, trajectories, savestepinterval, isoutofdomain)
139+
sols, ttotal, nt, nout = _prepare(prob, trajectories, savestepinterval)
158140

159141
nchunks = Threads.nthreads()
160142
Threads.@threads for (irange, ichunk) in chunks(1:trajectories, nchunks)
161-
paramBoris = BorisMethod()
162-
xv = similar(prob.u0)
163-
traj = zeros(eltype(prob.u0), 6, nout)
164-
165-
for i in irange
166-
_boris!(sols, prob, i, xv, paramBoris, p, dt, savestepinterval, nout, ttotal, nt,
167-
traj, tspan, isoutofdomain)
168-
end
143+
_boris!(sols, prob, irange, savestepinterval, ttotal, nt, nout, isoutofdomain)
169144
end
170145

171146
sols
172147
end
173148

174-
function _boris!(sols, prob, i,
175-
xv, paramBoris, p, dt, savestepinterval, nout, ttotal, nt, traj, tspan, isoutofdomain)
176-
# set initial conditions for each trajectory
177-
iout = 1
178-
new_prob = prob.prob_func(prob, i, false)
179-
xv .= new_prob.u0
180-
traj[:,1] = xv
181-
182-
# push velocity back in time by 1/2 dt
183-
update_velocity!(xv, paramBoris, p, -0.5*dt)
184-
185-
for it in 1:nt
186-
update_velocity!(xv, paramBoris, p, dt)
187-
update_location!(xv, dt)
188-
if it % savestepinterval == 0
189-
iout += 1
190-
traj[:,iout] .= xv
149+
"Prepare for advancing."
150+
function _prepare(prob, trajectories, savestepinterval)
151+
(; tspan, dt) = prob
152+
ttotal = tspan[2] - tspan[1]
153+
nt = round(Int, ttotal / dt)
154+
nout = nt ÷ savestepinterval + 1
155+
sols = Vector{TraceSolution}(undef, trajectories)
156+
157+
sols, ttotal, nt, nout
158+
end
159+
160+
function _boris!(sols, prob, irange, savestepinterval, ttotal, nt, nout, isoutofdomain)
161+
(; tspan, dt, p, u0) = prob
162+
paramBoris = BorisMethod()
163+
xv = similar(u0)
164+
traj = zeros(eltype(u0), 6, nout)
165+
166+
for i in irange
167+
# set initial conditions for each trajectory i
168+
iout = 1
169+
new_prob = prob.prob_func(prob, i, false)
170+
xv .= new_prob.u0
171+
traj[:,1] = xv
172+
173+
# push velocity back in time by 1/2 dt
174+
update_velocity!(xv, paramBoris, p, -0.5*dt)
175+
176+
for it in 1:nt
177+
update_velocity!(xv, paramBoris, p, dt)
178+
update_location!(xv, dt)
179+
if it % savestepinterval == 0
180+
iout += 1
181+
traj[:,iout] .= xv
182+
end
183+
isoutofdomain(xv) && break
191184
end
192-
isoutofdomain(xv) && break
193-
end
194185

195-
if iout == nout # regular termination
196-
dtfinal = ttotal - nt*dt
197-
if dtfinal > 1e-3 # final step if needed
198-
update_velocity!(xv, paramBoris, p, dtfinal)
199-
update_location!(xv, dtfinal)
200-
traj_save = hcat(traj, xv)
201-
t = [collect(tspan[1]:dt*savestepinterval:tspan[2])..., tspan[2]]
202-
else
203-
traj_save = copy(traj)
204-
t = collect(tspan[1]:dt*savestepinterval:tspan[2])
186+
if iout == nout # regular termination
187+
dtfinal = ttotal - nt*dt
188+
if dtfinal > 1e-3 # final step if needed
189+
update_velocity!(xv, paramBoris, p, dtfinal)
190+
update_location!(xv, dtfinal)
191+
traj_save = hcat(traj, xv)
192+
t = [collect(tspan[1]:dt*savestepinterval:tspan[2])..., tspan[2]]
193+
else
194+
traj_save = copy(traj)
195+
t = collect(tspan[1]:dt*savestepinterval:tspan[2])
196+
end
197+
else # early termination or savestepinterval != 1
198+
traj_save = traj[:, 1:iout]
199+
t = collect(tspan[1]:dt*savestepinterval:tspan[1]+dt*savestepinterval*iout)
205200
end
206-
else # early termination or savestepinterval != 1
207-
traj_save = traj[:, 1:iout]
208-
t = collect(tspan[1]:dt*savestepinterval:tspan[1]+dt*savestepinterval*iout)
201+
sols[i] = TraceSolution(traj_save, t)
209202
end
210-
sols[i] = TraceSolution(traj_save, t)
211203

212204
return
213205
end

test/runtests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,8 @@ end
393393

394394
prob = TraceProblem(stateinit, tspan, dt, param; prob_func=prob_func_boris)
395395
trajectories = 4
396-
sols = TestParticle.solve(prob, EnsembleThreads();
397-
savestepinterval=1000, trajectories)
396+
savestepinterval = 1000
397+
sols = TestParticle.solve(prob, EnsembleThreads(); savestepinterval, trajectories)
398398
@test sum(x -> sum(@view x.u[:,end]), sols) -1.4065273620640622e6
399399
end
400400
end

0 commit comments

Comments
 (0)