@@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
163
163
164
164
linsolve = get_linear_solver (alg. descent)
165
165
initialization_cache = __internal_init (prob, alg. initialization, alg, f, fu, u, p;
166
- linsolve,
167
- maxiters, internalnorm)
166
+ linsolve, maxiters, internalnorm)
168
167
169
168
abstol, reltol, termination_cache = init_termination_cache (abstol, reltol, fu, u,
170
169
termination_condition)
@@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
222
221
new_jacobian = true
223
222
@static_timeit cache. timer " jacobian init/reinit" begin
224
223
if get_nsteps (cache) == 0 # First Step is special ignore kwargs
225
- J_init = __internal_solve! (cache. initialization_cache,
226
- cache. fu,
227
- cache. u,
224
+ J_init = __internal_solve! (cache. initialization_cache, cache. fu, cache. u,
228
225
Val (false ))
229
226
if INV
230
227
if jacobian_initialized_preinverted (cache. initialization_cache. alg)
@@ -283,52 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
283
280
@static_timeit cache. timer " descent" begin
284
281
if cache. trustregion_cache != = nothing &&
285
282
hasfield (typeof (cache. trustregion_cache), :trust_region )
286
- δu, descent_success, descent_intermediates = __internal_solve! (cache. descent_cache,
287
- J, cache. fu, cache. u; new_jacobian,
288
- trust_region = cache. trustregion_cache. trust_region)
283
+ descent_result = __internal_solve! (cache. descent_cache, J, cache. fu, cache. u;
284
+ new_jacobian, trust_region = cache. trustregion_cache. trust_region)
289
285
else
290
- δu, descent_success, descent_intermediates = __internal_solve! (cache. descent_cache,
291
- J, cache . fu, cache . u; new_jacobian)
286
+ descent_result = __internal_solve! (cache. descent_cache, J, cache . fu, cache . u;
287
+ new_jacobian)
292
288
end
293
289
end
294
290
295
- if descent_success
296
- if GB === :LineSearch
297
- @static_timeit cache. timer " linesearch" begin
298
- needs_reset, α = __internal_solve! (cache. linesearch_cache, cache. u, δu)
299
- end
300
- if needs_reset && cache. steps_since_last_reset > 5 # Reset after a burn-in period
301
- cache. force_reinit = true
302
- else
303
- @static_timeit cache. timer " step" begin
304
- @bb axpy! (α, δu, cache. u)
305
- evaluate_f! (cache, cache. u, cache. p)
306
- end
307
- end
308
- elseif GB === :TrustRegion
309
- @static_timeit cache. timer " trustregion" begin
310
- tr_accepted, u_new, fu_new = __internal_solve! (cache. trustregion_cache, J,
311
- cache. fu, cache. u, δu, descent_intermediates)
312
- if tr_accepted
313
- @bb copyto! (cache. u, u_new)
314
- @bb copyto! (cache. fu, fu_new)
315
- end
316
- if hasfield (typeof (cache. trustregion_cache), :shrink_counter ) &&
317
- cache. trustregion_cache. shrink_counter > cache. max_shrink_times
318
- cache. retcode = ReturnCode. ShrinkThresholdExceeded
319
- cache. force_stop = true
320
- end
321
- end
322
- α = true
323
- elseif GB === :None
291
+ if descent_result. success
292
+ if GB === :None
324
293
@static_timeit cache. timer " step" begin
325
- @bb axpy! (1 , δu, cache. u)
294
+ if descent_result. u != = missing
295
+ @bb copyto! (cache. u, descent_result. u)
296
+ elseif descent_result. δu != = missing
297
+ @bb axpy! (1 , descent_result. δu, cache. u)
298
+ else
299
+ error (" This shouldn't occur. `$(cache. alg. descent) ` is incorrectly \
300
+ specified." )
301
+ end
326
302
evaluate_f! (cache, cache. u, cache. p)
327
303
end
328
304
α = true
329
305
else
330
- error (" Unknown Globalization Strategy: $(GB) . Allowed values are (:LineSearch, \
331
- :TrustRegion, :None)" )
306
+ δu = descent_result. δu
307
+ @assert δu!= = missing " Descent Supporting LineSearch or TrustRegion must return a `δu`."
308
+
309
+ if GB === :LineSearch
310
+ @static_timeit cache. timer " linesearch" begin
311
+ needs_reset, α = __internal_solve! (cache. linesearch_cache, cache. u, δu)
312
+ end
313
+ if needs_reset && cache. steps_since_last_reset > 5 # Reset after a burn-in period
314
+ cache. force_reinit = true
315
+ else
316
+ @static_timeit cache. timer " step" begin
317
+ @bb axpy! (α, δu, cache. u)
318
+ evaluate_f! (cache, cache. u, cache. p)
319
+ end
320
+ end
321
+ elseif GB === :TrustRegion
322
+ @static_timeit cache. timer " trustregion" begin
323
+ tr_accepted, u_new, fu_new = __internal_solve! (cache. trustregion_cache,
324
+ J, cache. fu, cache. u, δu, descent_result. extras)
325
+ if tr_accepted
326
+ @bb copyto! (cache. u, u_new)
327
+ @bb copyto! (cache. fu, fu_new)
328
+ α = true
329
+ else
330
+ α = false
331
+ end
332
+ if hasfield (typeof (cache. trustregion_cache), :shrink_counter ) &&
333
+ cache. trustregion_cache. shrink_counter > cache. max_shrink_times
334
+ cache. retcode = ReturnCode. ShrinkThresholdExceeded
335
+ cache. force_stop = true
336
+ end
337
+ end
338
+ else
339
+ error (" Unknown Globalization Strategy: $(GB) . Allowed values are \
340
+ (:LineSearch, :TrustRegion, :None)" )
341
+ end
332
342
end
333
343
check_and_update! (cache, cache. fu, cache. u, cache. u_cache)
334
344
else
0 commit comments