diff --git a/src/default.jl b/src/default.jl index 90ac0d9ec..44ac26cad 100644 --- a/src/default.jl +++ b/src/default.jl @@ -138,8 +138,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::FastShortcutN ) end -function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::FastShortcutNonlinearPolyalg, args...; - kwargs...) where {uType, iip} +function SciMLBase.__solve(prob::NonlinearProblem{uType, false}, alg::FastShortcutNonlinearPolyalg, args...; + kwargs...) where {uType} adkwargs = alg.adkwargs linsolve = alg.linsolve @@ -190,6 +190,58 @@ function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::FastShortcut end +function SciMLBase.__solve(prob::NonlinearProblem{uType, true}, alg::FastShortcutNonlinearPolyalg, args...; + kwargs...) where {uType} + + adkwargs = alg.adkwargs + linsolve = alg.linsolve + precs = alg.precs + + sol1 = SciMLBase.__solve(prob, NewtonRaphson(;linsolve, precs, adkwargs...), args...; kwargs...) + if SciMLBase.successful_retcode(sol1) + return SciMLBase.build_solution(prob, alg, sol1.u, sol1.resid; + sol1.retcode, sol1.stats) + end + + sol2 = SciMLBase.__solve(prob, NewtonRaphson(;linsolve, precs, linesearch=BackTracking(), adkwargs...), args...; kwargs...) + if SciMLBase.successful_retcode(sol2) + return SciMLBase.build_solution(prob, alg, sol2.u, sol2.resid; + sol2.retcode, sol2.stats) + end + + sol3 = SciMLBase.__solve(prob, TrustRegion(;linsolve, precs, adkwargs...), args...; kwargs...) + if SciMLBase.successful_retcode(sol3) + return SciMLBase.build_solution(prob, alg, sol3.u, sol3.resid; + sol3.retcode, sol3.stats) + end + + sol4 = SciMLBase.__solve(prob, TrustRegion(;linsolve, precs, radius_update_scheme = RadiusUpdateSchemes.Bastin, adkwargs...), args...; kwargs...) + if SciMLBase.successful_retcode(sol4) + return SciMLBase.build_solution(prob, alg, sol4.u, sol4.resid; + sol4.retcode, sol4.stats) + end + + resids = (sol1.resid, sol2.resid, sol3.resid, sol4.resid) + minfu, idx = findmin(DEFAULT_NORM, resids) + + if idx == 1 + SciMLBase.build_solution(prob, alg, sol1.u, sol1.resid; + sol1.retcode, sol1.stats) + elseif idx == 2 + SciMLBase.build_solution(prob, alg, sol2.u, sol2.resid; + sol2.retcode, sol2.stats) + elseif idx == 3 + SciMLBase.build_solution(prob, alg, sol3.u, sol3.resid; + sol3.retcode, sol3.stats) + elseif idx == 4 + SciMLBase.build_solution(prob, alg, sol4.u, sol4.resid; + sol4.retcode, sol4.stats) + else + error("Unreachable reached, 박정석") + end + +end + ## General shared polyalg functions function perform_step!(cache::Union{RobustMultiNewtonCache, FastShortcutNonlinearPolyalgCache}) diff --git a/test/polyalgs.jl b/test/polyalgs.jl index 110f6228a..30890e0ae 100644 --- a/test/polyalgs.jl +++ b/test/polyalgs.jl @@ -5,4 +5,18 @@ u0 = [1.0, 1.0] probN = NonlinearProblem(f, u0) @time solver = solve(probN, abstol = 1e-9) @time solver = solve(probN, RobustMultiNewton(), abstol = 1e-9) -@time solver = solve(probN, FastShortcutNonlinearPolyalg(), abstol = 1e-9) \ No newline at end of file +@time solver = solve(probN, FastShortcutNonlinearPolyalg(), abstol = 1e-9) + +# https://github.com/SciML/NonlinearSolve.jl/issues/153 + +function f(du, u, p) + s1, s1s2, s2 = u + k1, c1, Δt = p + + du[1] = -0.25 * c1 * k1 * s1 * s2 + du[2] = 0.25 * c1 * k1 * s1 * s2 + du[3] = -0.25 * c1 * k1 * s1 * s2 +end + +prob = NonlinearProblem(f, [2.0,2.0,2.0], [1.0, 2.0, 2.5]) +sol = solve(prob) \ No newline at end of file