Skip to content

Commit

Permalink
Make it generated
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2023
1 parent d40b901 commit 3724fbc
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 47 deletions.
93 changes: 48 additions & 45 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,14 @@ end
end)
end

resids = map(x -> "$x.resid", sol_syms)
resids = map(x -> Symbol("$(x)_resid"), sol_syms)
for (sym, resid) in zip(sol_syms, resids)
push!(calls, :($(resid) = $(sym).resid))
end

push!(calls,
quote
resids = $(Tuple(resids))
resids = tuple($(Tuple(resids)...))
minfu, idx = findmin(DEFAULT_NORM, resids)
end)

Expand All @@ -198,8 +201,7 @@ end
quote
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats,
original = $(sol_syms[i]))
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats)
end
end)
end
Expand All @@ -210,53 +212,54 @@ end

## General shared polyalg functions

function perform_step!(cache::Union{RobustMultiNewtonCache,
FastShortcutNonlinearPolyalgCache})
current = cache.current
1 current length(cache.caches) || error("Current choices shouldn't get here!")

current_cache = cache.caches[current]
while not_terminated(current_cache)
perform_step!(current_cache)
@generated function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache{iip, N},
FastShortcutNonlinearPolyalgCache{iip, N}}) where {iip, N}
calls = [
quote
1 cache.current length(cache.caches) ||
error("Current choices shouldn't get here!")
end,
]

cache_syms = [gensym("cache") for i in 1:N]
sol_syms = [gensym("sol") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
if SciMLBase.successful_retcode($(sol_syms[i]))
stats = $(sol_syms[i]).stats
u = $(sol_syms[i]).u
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
fu; retcode = ReturnCode.Success, stats,
original = $(sol_syms[i]))
end
cache.current = $(i + 1)
end
end)
end

return nothing
end

function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache,
FastShortcutNonlinearPolyalgCache})
current = cache.current
1 current length(cache.caches) || error("Current choices shouldn't get here!")

current_cache = cache.caches[current]
while current length(cache.caches) # && !all(terminated[current:end])
sol_tmp = solve!(current_cache)
SciMLBase.successful_retcode(sol_tmp) && break
current += 1
cache.current = current
current_cache = cache.caches[current]
resids = map(x -> Symbol("$(x)_resid"), cache_syms)
for (sym, resid) in zip(cache_syms, resids)
push!(calls, :($(resid) = get_fu($(sym))))
end
push!(calls,
quote
retcode = ReturnCode.MaxIters

if current length(cache.caches)
retcode = ReturnCode.Success

stats = cache.caches[current].stats
u = cache.caches[current].u
fu = get_fu(cache.caches[current])

return SciMLBase.build_solution(cache.caches[1].prob, cache.alg, u, fu;
retcode, stats)
else
retcode = ReturnCode.MaxIters
fus = tuple($(Tuple(resids)...))
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u

fus = get_fu.(cache.caches)
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,
fus[idx]; retcode, stats)
end)

return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats)
end
return Expr(:block, calls...)
end

function SciMLBase.reinit!(cache::Union{RobustMultiNewtonCache,
Expand Down
4 changes: 2 additions & 2 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
function perform_step!(cache::DFSaneCache{true})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)
f = (dx, x) -> cache.prob.f(dx, x, cache.p)

T = eltype(cache.uₙ)
n = cache.stats.nsteps
Expand Down Expand Up @@ -208,7 +208,7 @@ end
function perform_step!(cache::DFSaneCache{false})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)
f = x -> cache.prob.f(x, cache.p)

T = eltype(cache.uₙ)
n = cache.stats.nsteps
Expand Down

0 comments on commit 3724fbc

Please sign in to comment.