From 7a9b55614f544d655277d9531e8f49428de2522d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Oct 2023 18:21:13 -0400 Subject: [PATCH] Make it generated --- src/default.jl | 93 ++++++++++++++++++++++++++------------------------ src/dfsane.jl | 4 +-- 2 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/default.jl b/src/default.jl index b7c16d735..913f994fe 100644 --- a/src/default.jl +++ b/src/default.jl @@ -185,11 +185,14 @@ end end) end - resids = map(x -> "$x.resid", sol_syms) + resids = map(x -> Symbol("$(x)_resid"), sol_syms) + for (sym, resid) in zip(sol_syms, resids) + push!(calls, :($(resid) = $(sym).resid)) + end push!(calls, quote - resids = $(Tuple(resids)) + resids = tuple($(Tuple(resids)...)) minfu, idx = findmin(DEFAULT_NORM, resids) end) @@ -198,8 +201,7 @@ end quote if idx == $i return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u, - $(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats, - original = $(sol_syms[i])) + $(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats) end end) end @@ -210,53 +212,54 @@ end ## General shared polyalg functions -function perform_step!(cache::Union{RobustMultiNewtonCache, - FastShortcutNonlinearPolyalgCache}) - current = cache.current - 1 ≤ current ≤ length(cache.caches) || error("Current choices shouldn't get here!") - - current_cache = cache.caches[current] - while not_terminated(current_cache) - perform_step!(current_cache) +@generated function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache{iip, N}, + FastShortcutNonlinearPolyalgCache{iip, N}}) where {iip, N} + calls = [ + quote + 1 ≤ cache.current ≤ length(cache.caches) || + error("Current choices shouldn't get here!") + end, + ] + + cache_syms = [gensym("cache") for i in 1:N] + sol_syms = [gensym("sol") for i in 1:N] + for i in 1:N + push!(calls, + quote + $(cache_syms[i]) = cache.caches[$(i)] + if $(i) == cache.current + $(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i])) + if SciMLBase.successful_retcode($(sol_syms[i])) + stats = $(sol_syms[i]).stats + u = $(sol_syms[i]).u + fu = get_fu($(cache_syms[i])) + return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u, + fu; retcode = ReturnCode.Success, stats, + original = $(sol_syms[i])) + end + cache.current = $(i + 1) + end + end) end - return nothing -end - -function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache, - FastShortcutNonlinearPolyalgCache}) - current = cache.current - 1 ≤ current ≤ length(cache.caches) || error("Current choices shouldn't get here!") - - current_cache = cache.caches[current] - while current ≤ length(cache.caches) # && !all(terminated[current:end]) - sol_tmp = solve!(current_cache) - SciMLBase.successful_retcode(sol_tmp) && break - current += 1 - cache.current = current - current_cache = cache.caches[current] + resids = map(x -> Symbol("$(x)_resid"), cache_syms) + for (sym, resid) in zip(cache_syms, resids) + push!(calls, :($(resid) = get_fu($(sym)))) end + push!(calls, + quote + retcode = ReturnCode.MaxIters - if current ≤ length(cache.caches) - retcode = ReturnCode.Success - - stats = cache.caches[current].stats - u = cache.caches[current].u - fu = get_fu(cache.caches[current]) - - return SciMLBase.build_solution(cache.caches[1].prob, cache.alg, u, fu; - retcode, stats) - else - retcode = ReturnCode.MaxIters + fus = tuple($(Tuple(resids)...)) + minfu, idx = findmin(cache.caches[1].internalnorm, fus) + stats = cache.caches[idx].stats + u = cache.caches[idx].u - fus = get_fu.(cache.caches) - minfu, idx = findmin(cache.caches[1].internalnorm, fus) - stats = cache.caches[idx].stats - u = cache.caches[idx].u + return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, + fus[idx]; retcode, stats) + end) - return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx]; - retcode, stats) - end + return Expr(:block, calls...) end function SciMLBase.reinit!(cache::Union{RobustMultiNewtonCache, diff --git a/src/dfsane.jl b/src/dfsane.jl index a4bf30c8a..13de5ff6a 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -131,7 +131,7 @@ end function perform_step!(cache::DFSaneCache{true}) @unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache - f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p) + f = (dx, x) -> cache.prob.f(dx, x, cache.p) T = eltype(cache.uₙ) n = cache.stats.nsteps @@ -208,7 +208,7 @@ end function perform_step!(cache::DFSaneCache{false}) @unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache - f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p) + f = x -> cache.prob.f(x, cache.p) T = eltype(cache.uₙ) n = cache.stats.nsteps