Skip to content

Commit 7439142

Browse files
authored
Refactor Boris ensemble tracing (#128)
* Reduce allocations * Add Boris ensemble case for benchmark
1 parent 34ec481 commit 7439142

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

benchmark/benchmarks.jl

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ end
5656
SUITE["trace"]["numerical field"]["in place"] = @benchmarkable solve($prob_ip, Tsit5(); save_idxs=[1,2,3])
5757
SUITE["trace"]["numerical field"]["out of place"] = @benchmarkable solve($prob_oop, Tsit5(); save_idxs=[1,2,3])
5858
SUITE["trace"]["numerical field"]["Boris"] = @benchmarkable trace_trajectory($prob_boris; savestepinterval=10)
59+
SUITE["trace"]["numerical field"]["Boris ensemble"] = @benchmarkable trace_trajectory($prob_boris; savestepinterval=10, trajectories=2)
5960

6061
param_td = prepare(E_td, B_td, F_td)
6162
prob_ip = ODEProblem(trace!, stateinit, tspan, param_td) # in place

src/pusher.jl

+14-11
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,18 @@ function trace_trajectory(prob::TraceProblem; trajectories::Int=1,
130130
savestepinterval::Int=1, isoutofdomain::Function=ODE_DEFAULT_ISOUTOFDOMAIN)
131131

132132
sols = Vector{TraceSolution}(undef, trajectories)
133+
# prepare advancing
134+
xv = similar(prob.u0)
135+
(; tspan, dt, p) = prob
136+
ttotal = tspan[2] - tspan[1]
137+
nt = Int(ttotal ÷ dt)
138+
iout, nout = 1, nt ÷ savestepinterval + 1
139+
traj = zeros(eltype(prob.u0), 6, nout)
133140

134141
for i in 1:trajectories
142+
# set initial conditions for each trajectory
135143
new_prob = prob.prob_func(prob, i, false)
136-
(; u0, tspan, dt, p) = new_prob
137-
xv = copy(u0)
138-
# prepare advancing
139-
ttotal = tspan[2] - tspan[1]
140-
nt = Int(ttotal ÷ dt)
141-
iout, nout = 1, nt ÷ savestepinterval + 1
142-
traj = zeros(eltype(u0), 6, nout)
143-
144+
xv .= new_prob.u0
144145
traj[:,1] = xv
145146

146147
# push velocity back in time by 1/2 dt
@@ -161,16 +162,18 @@ function trace_trajectory(prob::TraceProblem; trajectories::Int=1,
161162
if dtfinal > 1e-3 # final step if needed
162163
update_velocity!(xv, p, dtfinal)
163164
update_location!(xv, dtfinal)
164-
traj = hcat(traj, xv)
165+
traj_save = hcat(traj, xv)
165166
t = [collect(tspan[1]:dt:tspan[2])..., tspan[2]]
166167
else
168+
traj_save = traj
167169
t = collect(tspan[1]:dt:tspan[2])
168170
end
169171
else # early termination
170-
traj = traj[:, 1:iout]
172+
traj_save = traj[:, 1:iout]
171173
t = collect(range(tspan[1], tspan[1]+iout*dt, step=dt))
172174
end
173-
sols[i] = TraceSolution(traj, t)
175+
sols[i] = TraceSolution(traj_save, t)
176+
iout = 1
174177
end
175178

176179
sols

0 commit comments

Comments
 (0)