From c428bd918379700a940b6e9082b16c0a690c254f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Oct 2023 20:02:53 -0400 Subject: [PATCH] Fix 23 test problems --- src/broyden.jl | 5 ++--- src/dfsane.jl | 27 +++++++++------------------ src/gaussnewton.jl | 16 +++++----------- src/klement.jl | 5 ++--- src/pseudotransient.jl | 24 ++---------------------- src/raphson.jl | 5 ++--- src/trustRegion.jl | 5 ++--- src/utils.jl | 36 +++++++++++++++++++----------------- test/23_test_problems.jl | 3 ++- 9 files changed, 45 insertions(+), 81 deletions(-) diff --git a/src/broyden.jl b/src/broyden.jl index 557acb813..5fcbd3d51 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -185,9 +185,8 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca cache.u = u0 cache.fu = cache.f(cache.u, p) end - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) cache.abstol = abstol diff --git a/src/dfsane.jl b/src/dfsane.jl index d88f18db0..aca13c344 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -114,12 +114,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args. ๐’น, uโ‚™โ‚‹โ‚, fuโ‚™, fuโ‚™โ‚‹โ‚ = copy(uโ‚™), copy(uโ‚™), copy(uโ‚™), copy(uโ‚™) if iip - # f = (dx, x) -> prob.f(dx, x, p) - # f(fuโ‚™โ‚‹โ‚, uโ‚™โ‚‹โ‚) prob.f(fuโ‚™โ‚‹โ‚, uโ‚™โ‚‹โ‚, p) else - # f = (x) -> prob.f(x, p) - fuโ‚™โ‚‹โ‚ = prob.f(uโ‚™โ‚‹โ‚, p) # f(uโ‚™โ‚‹โ‚) + fuโ‚™โ‚‹โ‚ = prob.f(uโ‚™โ‚‹โ‚, p) end fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚ = norm(fuโ‚™โ‚‹โ‚)^nโ‚‘โ‚“โ‚š @@ -127,10 +124,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args. โ„‹ = fill(fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚, M) - abstol, reltol, termination_condition = _init_termination_elements(abstol, - reltol, - termination_condition, - T) + abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol, + termination_condition, T) mode = DiffEqBase.get_termination_mode(termination_condition) @@ -167,14 +162,13 @@ function perform_step!(cache::DFSaneCache{true}) f(cache.fuโ‚™, cache.uโ‚™) fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™ = norm(cache.fuโ‚™)^nโ‚‘โ‚“โ‚š - for _ in 1:(cache.alg.max_inner_iterations) + for jjj in 1:(cache.alg.max_inner_iterations) ๐’ธ = fฬ„ + ฮท - ฮณ * ฮฑโ‚Š^2 * fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚ fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™ โ‰ค ๐’ธ && break ฮฑโ‚Š = clamp(ฮฑโ‚Š^2 * fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚ / (fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™ + (T(2) * ฮฑโ‚Š - T(1)) * fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚), - ฯ„โ‚˜แตขโ‚™ * ฮฑโ‚Š, - ฯ„โ‚˜โ‚โ‚“ * ฮฑโ‚Š) + ฯ„โ‚˜แตขโ‚™ * ฮฑโ‚Š, ฯ„โ‚˜โ‚โ‚“ * ฮฑโ‚Š) @. cache.uโ‚™ = cache.uโ‚™โ‚‹โ‚ - ฮฑโ‚‹ * cache.๐’น f(cache.fuโ‚™, cache.uโ‚™) @@ -183,8 +177,7 @@ function perform_step!(cache::DFSaneCache{true}) fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™ .โ‰ค ๐’ธ && break ฮฑโ‚‹ = clamp(ฮฑโ‚‹^2 * fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚ / (fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™ + (T(2) * ฮฑโ‚‹ - T(1)) * fโ‚โ‚™โ‚’แตฃโ‚˜โ‚Žโ‚™โ‚‹โ‚), - ฯ„โ‚˜แตขโ‚™ * ฮฑโ‚‹, - ฯ„โ‚˜โ‚โ‚“ * ฮฑโ‚‹) + ฯ„โ‚˜แตขโ‚™ * ฮฑโ‚‹, ฯ„โ‚˜โ‚โ‚“ * ฮฑโ‚‹) @. cache.uโ‚™ = cache.uโ‚™โ‚‹โ‚ + ฮฑโ‚Š * cache.๐’น f(cache.fuโ‚™, cache.uโ‚™) @@ -207,7 +200,7 @@ function perform_step!(cache::DFSaneCache{true}) # Spectral parameter bounds check if abs(cache.ฯƒโ‚™) > ฯƒโ‚˜โ‚โ‚“ || abs(cache.ฯƒโ‚™) < ฯƒโ‚˜แตขโ‚™ test_norm = sqrt(sum(abs2, cache.fuโ‚™โ‚‹โ‚)) - cache.ฯƒโ‚™ = clamp(1.0 / test_norm, 1, 1e5) + cache.ฯƒโ‚™ = clamp(T(1) / test_norm, T(1), T(1e5)) end # Take step @@ -283,7 +276,7 @@ function perform_step!(cache::DFSaneCache{false}) # Spectral parameter bounds check if abs(cache.ฯƒโ‚™) > ฯƒโ‚˜โ‚โ‚“ || abs(cache.ฯƒโ‚™) < ฯƒโ‚˜แตขโ‚™ test_norm = sqrt(sum(abs2, cache.fuโ‚™โ‚‹โ‚)) - cache.ฯƒโ‚™ = clamp(1.0 / test_norm, 1, 1e5) + cache.ฯƒโ‚™ = clamp(T(1) / test_norm, T(1), T(1e5)) end # Take step @@ -337,9 +330,7 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uโ‚™; p = cache.p T = eltype(cache.uโ‚™) cache.ฯƒโ‚™ = T(cache.alg.ฯƒ_1) - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) cache.abstol = abstol diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 61ce98c76..b066f6169 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -118,10 +118,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: nothing return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf, - linsolve, J, - Jแต€J, Jแต€f, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, - reltol, - prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition) + linsolve, J, Jแต€J, Jแต€f, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, + abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition) end function perform_step!(cache::GaussNewtonCache{true}) @@ -147,10 +145,7 @@ function perform_step!(cache::GaussNewtonCache{true}) @. u = u - du f(cache.fu_new, u, p) - (termination_condition(cache.fu_new .- cache.fu1, - cache.u, - u_prev, - cache.abstol, + (termination_condition(cache.fu_new .- cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) || termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) && (cache.force_stop = true) @@ -217,9 +212,8 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache cache.u = u0 cache.fu1 = cache.f(cache.u, p) end - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) cache.abstol = abstol diff --git a/src/klement.jl b/src/klement.jl index 8fc44be59..e60aeee9b 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -238,10 +238,9 @@ function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = ca cache.fu = cache.f(cache.u, p) end - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) + cache.abstol = abstol cache.reltol = reltol cache.termination_condition = termination_condition diff --git a/src/pseudotransient.jl b/src/pseudotransient.jl index 306e0758d..64e4f258c 100644 --- a/src/pseudotransient.jl +++ b/src/pseudotransient.jl @@ -51,7 +51,7 @@ function PseudoTransient(; concrete_jac = nothing, linsolve = nothing, return PseudoTransient{_unwrap_val(concrete_jac)}(ad, linsolve, precs, alpha_initial) end -@concrete mutable struct PseudoTransientCache{iip} +@concrete mutable struct PseudoTransientCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u @@ -78,8 +78,6 @@ end tc_storage end -isinplace(::PseudoTransientCache{iip}) where {iip} = iip - function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient, args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, termination_condition = nothing, internalnorm = DEFAULT_NORM, @@ -174,22 +172,6 @@ function perform_step!(cache::PseudoTransientCache{false}) return nothing end -function SciMLBase.solve!(cache::PseudoTransientCache) - while !cache.force_stop && cache.stats.nsteps < cache.maxiters - perform_step!(cache) - cache.stats.nsteps += 1 - end - - if cache.stats.nsteps == cache.maxiters - cache.retcode = ReturnCode.MaxIters - else - cache.retcode = ReturnCode.Success - end - - return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1; - cache.retcode, cache.stats) -end - function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p, alpha_new, abstol = cache.abstol, reltol = cache.reltol, @@ -205,9 +187,7 @@ function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = c cache.fu1 = cache.f(cache.u, p) end - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) cache.alpha = convert(eltype(cache.u), alpha_new) diff --git a/src/raphson.jl b/src/raphson.jl index a34d860ce..6e2a502bb 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -179,10 +179,9 @@ function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cac cache.fu1 = cache.f(cache.u, p) end - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) + cache.abstol = abstol cache.reltol = reltol cache.termination_condition = termination_condition diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 3b14d0a38..cf9f41af0 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -736,9 +736,8 @@ function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache cache.u = u0 cache.fu = cache.f(cache.u, p) end - termination_condition = _get_reinit_termination_condition(cache, - abstol, - reltol, + + termination_condition = _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) cache.abstol = abstol diff --git a/src/utils.jl b/src/utils.jl index 718eef22f..87a80e4ed 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -218,14 +218,16 @@ end function _init_termination_elements(abstol, reltol, termination_condition, ::Type{T}; mode = NLSolveTerminationMode.AbsNorm) where {T} if termination_condition !== nothing - abstol !== nothing ? - (abstol != termination_condition.abstol ? - error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") : - nothing) : nothing - reltol !== nothing ? - (reltol != termination_condition.abstol ? - error("Incompatible relative tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") : - nothing) : nothing + if abstol !== nothing && abstol != termination_condition.abstol + error("Incompatible absolute tolerances found. The tolerances supplied as the \ + keyword argument and the one supplied in the termination condition should \ + be same.") + end + if reltol !== nothing && reltol != termination_condition.reltol + error("Incompatible relative tolerances found. The tolerances supplied as the \ + keyword argument and the one supplied in the termination condition should \ + be same.") + end abstol = _get_tolerance(abstol, termination_condition.abstol, T) reltol = _get_tolerance(reltol, termination_condition.reltol, T) return abstol, reltol, termination_condition @@ -239,18 +241,18 @@ end function _get_reinit_termination_condition(cache, abstol, reltol, termination_condition) if termination_condition != cache.termination_condition - if abstol != cache.abstol - if abstol != termination_condition.abstol - error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") - end + if abstol != cache.abstol && abstol != termination_condition.abstol + error("Incompatible absolute tolerances found. The tolerances supplied as the \ + keyword argument and the one supplied in the termination condition \ + should be same.") end - if reltol != cache.reltol - if reltol != termination_condition.reltol - error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") - end + if reltol != cache.reltol && reltol != termination_condition.reltol + error("Incompatible absolute tolerances found. The tolerances supplied as the \ + keyword argument and the one supplied in the termination condition \ + should be same.") end - termination_condition + return termination_condition else # Build the termination_condition with new abstol and reltol return NLSolveTerminationCondition{ diff --git a/test/23_test_problems.jl b/test/23_test_problems.jl index 091088ab2..77274d34a 100644 --- a/test/23_test_problems.jl +++ b/test/23_test_problems.jl @@ -11,8 +11,9 @@ function test_on_library(problems, dicts, alg_ops, broken_tests, ฯต = 1e-4) @testset "$idx: $(dict["title"])" begin for alg in alg_ops try - sol = solve(nlprob, alg, abstol = 1e-18, reltol = 1e-18) + sol = solve(nlprob, alg) problem(res, sol.u, nothing) + broken = idx in broken_tests[alg] ? true : false @test norm(res)โ‰คฯต broken=broken catch