@@ -124,90 +124,82 @@ function solve(prob::TraceProblem, ensemblealg::BasicEnsembleAlgorithm=EnsembleS
124
124
trajectories:: Int = 1 , savestepinterval:: Int = 1 ,
125
125
isoutofdomain:: Function = ODE_DEFAULT_ISOUTOFDOMAIN)
126
126
127
- sols = __solve (ensemblealg, prob, trajectories, savestepinterval, isoutofdomain)
127
+ sols = _solve (ensemblealg, prob, trajectories, savestepinterval, isoutofdomain)
128
128
end
129
129
130
- function __solve (:: EnsembleSerial , prob, trajectories, savestepinterval, isoutofdomain)
131
- sols = Vector {TraceSolution} (undef, trajectories)
132
- # prepare advancing
133
- (; tspan, dt, p) = prob
134
- ttotal = tspan[2 ] - tspan[1 ]
135
- nt = round (Int, ttotal / dt)
136
- nout = nt ÷ savestepinterval + 1
137
- traj = zeros (eltype (prob. u0), 6 , nout)
138
-
139
- paramBoris = BorisMethod ()
140
- xv = similar (prob. u0)
141
- nout = nt ÷ savestepinterval + 1
130
+ function _solve (:: EnsembleSerial , prob, trajectories, savestepinterval, isoutofdomain)
131
+ sols, ttotal, nt, nout = _prepare (prob, trajectories, savestepinterval)
142
132
143
- for i in 1 : trajectories
144
- _boris! (sols, prob, i, xv, paramBoris, p, dt, savestepinterval, nout, ttotal, nt,
145
- traj, tspan, isoutofdomain)
146
- end
133
+ _boris! (sols, prob, 1 : trajectories, savestepinterval, ttotal, nt, nout, isoutofdomain)
147
134
148
135
sols
149
136
end
150
137
151
- function __solve (:: EnsembleThreads , prob, trajectories, savestepinterval, isoutofdomain)
152
- sols = Vector {TraceSolution} (undef, trajectories)
153
- # prepare advancing
154
- (; tspan, dt, p) = prob
155
- ttotal = tspan[2 ] - tspan[1 ]
156
- nt = round (Int, ttotal / dt)
157
- nout = nt ÷ savestepinterval + 1
138
+ function _solve (:: EnsembleThreads , prob, trajectories, savestepinterval, isoutofdomain)
139
+ sols, ttotal, nt, nout = _prepare (prob, trajectories, savestepinterval)
158
140
159
141
nchunks = Threads. nthreads ()
160
142
Threads. @threads for (irange, ichunk) in chunks (1 : trajectories, nchunks)
161
- paramBoris = BorisMethod ()
162
- xv = similar (prob. u0)
163
- traj = zeros (eltype (prob. u0), 6 , nout)
164
-
165
- for i in irange
166
- _boris! (sols, prob, i, xv, paramBoris, p, dt, savestepinterval, nout, ttotal, nt,
167
- traj, tspan, isoutofdomain)
168
- end
143
+ _boris! (sols, prob, irange, savestepinterval, ttotal, nt, nout, isoutofdomain)
169
144
end
170
145
171
146
sols
172
147
end
173
148
174
- function _boris! (sols, prob, i,
175
- xv, paramBoris, p, dt, savestepinterval, nout, ttotal, nt, traj, tspan, isoutofdomain)
176
- # set initial conditions for each trajectory
177
- iout = 1
178
- new_prob = prob. prob_func (prob, i, false )
179
- xv .= new_prob. u0
180
- traj[:,1 ] = xv
181
-
182
- # push velocity back in time by 1/2 dt
183
- update_velocity! (xv, paramBoris, p, - 0.5 * dt)
184
-
185
- for it in 1 : nt
186
- update_velocity! (xv, paramBoris, p, dt)
187
- update_location! (xv, dt)
188
- if it % savestepinterval == 0
189
- iout += 1
190
- traj[:,iout] .= xv
149
+ " Prepare for advancing."
150
+ function _prepare (prob, trajectories, savestepinterval)
151
+ (; tspan, dt) = prob
152
+ ttotal = tspan[2 ] - tspan[1 ]
153
+ nt = round (Int, ttotal / dt)
154
+ nout = nt ÷ savestepinterval + 1
155
+ sols = Vector {TraceSolution} (undef, trajectories)
156
+
157
+ sols, ttotal, nt, nout
158
+ end
159
+
160
+ function _boris! (sols, prob, irange, savestepinterval, ttotal, nt, nout, isoutofdomain)
161
+ (; tspan, dt, p, u0) = prob
162
+ paramBoris = BorisMethod ()
163
+ xv = similar (u0)
164
+ traj = zeros (eltype (u0), 6 , nout)
165
+
166
+ for i in irange
167
+ # set initial conditions for each trajectory i
168
+ iout = 1
169
+ new_prob = prob. prob_func (prob, i, false )
170
+ xv .= new_prob. u0
171
+ traj[:,1 ] = xv
172
+
173
+ # push velocity back in time by 1/2 dt
174
+ update_velocity! (xv, paramBoris, p, - 0.5 * dt)
175
+
176
+ for it in 1 : nt
177
+ update_velocity! (xv, paramBoris, p, dt)
178
+ update_location! (xv, dt)
179
+ if it % savestepinterval == 0
180
+ iout += 1
181
+ traj[:,iout] .= xv
182
+ end
183
+ isoutofdomain (xv) && break
191
184
end
192
- isoutofdomain (xv) && break
193
- end
194
185
195
- if iout == nout # regular termination
196
- dtfinal = ttotal - nt* dt
197
- if dtfinal > 1e-3 # final step if needed
198
- update_velocity! (xv, paramBoris, p, dtfinal)
199
- update_location! (xv, dtfinal)
200
- traj_save = hcat (traj, xv)
201
- t = [collect (tspan[1 ]: dt* savestepinterval: tspan[2 ])... , tspan[2 ]]
202
- else
203
- traj_save = copy (traj)
204
- t = collect (tspan[1 ]: dt* savestepinterval: tspan[2 ])
186
+ if iout == nout # regular termination
187
+ dtfinal = ttotal - nt* dt
188
+ if dtfinal > 1e-3 # final step if needed
189
+ update_velocity! (xv, paramBoris, p, dtfinal)
190
+ update_location! (xv, dtfinal)
191
+ traj_save = hcat (traj, xv)
192
+ t = [collect (tspan[1 ]: dt* savestepinterval: tspan[2 ])... , tspan[2 ]]
193
+ else
194
+ traj_save = copy (traj)
195
+ t = collect (tspan[1 ]: dt* savestepinterval: tspan[2 ])
196
+ end
197
+ else # early termination or savestepinterval != 1
198
+ traj_save = traj[:, 1 : iout]
199
+ t = collect (tspan[1 ]: dt* savestepinterval: tspan[1 ]+ dt* savestepinterval* iout)
205
200
end
206
- else # early termination or savestepinterval != 1
207
- traj_save = traj[:, 1 : iout]
208
- t = collect (tspan[1 ]: dt* savestepinterval: tspan[1 ]+ dt* savestepinterval* iout)
201
+ sols[i] = TraceSolution (traj_save, t)
209
202
end
210
- sols[i] = TraceSolution (traj_save, t)
211
203
212
204
return
213
205
end
0 commit comments