diff --git a/src/default.jl b/src/default.jl index c7092341b..74ef935d7 100644 --- a/src/default.jl +++ b/src/default.jl @@ -65,6 +65,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::RobustMultiNe TrustRegion(; linsolve, precs, radius_update_scheme = RadiusUpdateSchemes.Fan, adkwargs...)) + # Partially Type Unstable but can't do much since some upstream caches -- LineSearches + # and SparseDiffTools cause the instability return RobustMultiNewtonCache{iip}(map(solver -> SciMLBase.__init(prob, solver, args...; kwargs...), algs), alg, 1) end @@ -139,36 +141,56 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, end # This version doesn't allocate all the caches! -function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, +@generated function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::FastShortcutNonlinearPolyalg, args...; kwargs...) where {uType, iip} - @unpack adkwargs, linsolve, precs = alg + calls = [:(@unpack adkwargs, linsolve, precs = alg)] algs = [ - iip ? Klement() : nothing, # Klement not yet implemented for IIP - iip ? Broyden() : nothing, # Broyden not yet implemented for IIP - NewtonRaphson(; linsolve, precs, adkwargs...), - NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...), - TrustRegion(; linsolve, precs, adkwargs...), - TrustRegion(; linsolve, precs, - radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...), + iip ? :(Klement()) : nothing, # Klement not yet implemented for IIP + iip ? :(Broyden()) : nothing, # Broyden not yet implemented for IIP + :(NewtonRaphson(; linsolve, precs, adkwargs...)), + :(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)), + :(TrustRegion(; linsolve, precs, adkwargs...)), + :(TrustRegion(; linsolve, precs, + radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...)), ] filter!(!isnothing, algs) - - sols = Vector{SciMLBase.NonlinearSolution}(undef, length(algs)) - - for (i, solver) in enumerate(algs) - sols[i] = SciMLBase.__solve(prob, solver, args...; kwargs...) - if SciMLBase.successful_retcode(sols[i]) - return SciMLBase.build_solution(prob, alg, sols[i].u, sols[i].resid; - sols[i].retcode, sols[i].stats, original = sols[i]) - end + counter = 1 + sol_syms = [gensym("sol") for i in 1:length(algs)] + for i in 1:length(algs) + cur_sol = sol_syms[i] + push!(calls, + quote + $(cur_sol) = SciMLBase.__solve(prob, $(algs[i]), args...; kwargs...) + if SciMLBase.successful_retcode($(cur_sol)) + return SciMLBase.build_solution(prob, alg, $(cur_sol).u, + $(cur_sol).resid; $(cur_sol).retcode, $(cur_sol).stats, + original = $(cur_sol)) + end + end) end - resids = map(Base.Fix2(getproperty, resid), sols) - minfu, idx = findmin(DEFAULT_NORM, resids) + resids = map(x -> "$x.resid", sol_syms) + + push!(calls, + quote + resids = $(Tuple(resids)) + minfu, idx = findmin(DEFAULT_NORM, resids) + end) + + for i in 1:length(algs) + push!(calls, + quote + if idx == $i + return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u, + $(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats, + original = $(sol_syms[i])) + end + end) + end + push!(calls, :(error("Current choices shouldn't get here!"))) - return SciMLBase.build_solution(prob, alg, sols[idx].u, sols[idx].resid; - sols[idx].retcode, sols[idx].stats, original = sols[idx]) + return Expr(:block, calls...) end ## General shared polyalg functions