Skip to content

Commit

Permalink
Make it generated
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2023
1 parent 7b8ca58 commit 7a9b556
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 47 deletions.
93 changes: 48 additions & 45 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,14 @@ end
end)
end

resids = map(x -> "$x.resid", sol_syms)
resids = map(x -> Symbol("$(x)_resid"), sol_syms)
for (sym, resid) in zip(sol_syms, resids)
push!(calls, :($(resid) = $(sym).resid))
end

push!(calls,
quote
resids = $(Tuple(resids))
resids = tuple($(Tuple(resids)...))
minfu, idx = findmin(DEFAULT_NORM, resids)

Check warning on line 196 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L195-L196

Added lines #L195 - L196 were not covered by tests
end)

Expand All @@ -198,8 +201,7 @@ end
quote
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,

Check warning on line 203 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L202-L203

Added lines #L202 - L203 were not covered by tests
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats,
original = $(sol_syms[i]))
$(sol_syms[i]).resid; $(sol_syms[i]).retcode, $(sol_syms[i]).stats)
end
end)
end
Expand All @@ -210,53 +212,54 @@ end

## General shared polyalg functions

function perform_step!(cache::Union{RobustMultiNewtonCache,
FastShortcutNonlinearPolyalgCache})
current = cache.current
1 current length(cache.caches) || error("Current choices shouldn't get here!")

current_cache = cache.caches[current]
while not_terminated(current_cache)
perform_step!(current_cache)
@generated function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache{iip, N},
FastShortcutNonlinearPolyalgCache{iip, N}}) where {iip, N}
calls = [
quote
1 cache.current length(cache.caches) ||
error("Current choices shouldn't get here!")
end,
]

cache_syms = [gensym("cache") for i in 1:N]
sol_syms = [gensym("sol") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
if SciMLBase.successful_retcode($(sol_syms[i]))
stats = $(sol_syms[i]).stats
u = $(sol_syms[i]).u
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
fu; retcode = ReturnCode.Success, stats,
original = $(sol_syms[i]))
end
cache.current = $(i + 1)

Check warning on line 240 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L240

Added line #L240 was not covered by tests
end
end)
end

return nothing
end

function SciMLBase.solve!(cache::Union{RobustMultiNewtonCache,
FastShortcutNonlinearPolyalgCache})
current = cache.current
1 current length(cache.caches) || error("Current choices shouldn't get here!")

current_cache = cache.caches[current]
while current length(cache.caches) # && !all(terminated[current:end])
sol_tmp = solve!(current_cache)
SciMLBase.successful_retcode(sol_tmp) && break
current += 1
cache.current = current
current_cache = cache.caches[current]
resids = map(x -> Symbol("$(x)_resid"), cache_syms)
for (sym, resid) in zip(cache_syms, resids)
push!(calls, :($(resid) = get_fu($(sym))))
end
push!(calls,
quote
retcode = ReturnCode.MaxIters

Check warning on line 251 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L251

Added line #L251 was not covered by tests

if current length(cache.caches)
retcode = ReturnCode.Success

stats = cache.caches[current].stats
u = cache.caches[current].u
fu = get_fu(cache.caches[current])

return SciMLBase.build_solution(cache.caches[1].prob, cache.alg, u, fu;
retcode, stats)
else
retcode = ReturnCode.MaxIters
fus = tuple($(Tuple(resids)...))
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u

Check warning on line 256 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L253-L256

Added lines #L253 - L256 were not covered by tests

fus = get_fu.(cache.caches)
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,

Check warning on line 258 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L258

Added line #L258 was not covered by tests
fus[idx]; retcode, stats)
end)

return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats)
end
return Expr(:block, calls...)
end

function SciMLBase.reinit!(cache::Union{RobustMultiNewtonCache,

Check warning on line 265 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L265

Added line #L265 was not covered by tests
Expand Down
4 changes: 2 additions & 2 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
function perform_step!(cache::DFSaneCache{true})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)
f = (dx, x) -> cache.prob.f(dx, x, cache.p)

T = eltype(cache.uₙ)
n = cache.stats.nsteps
Expand Down Expand Up @@ -208,7 +208,7 @@ end
function perform_step!(cache::DFSaneCache{false})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)
f = x -> cache.prob.f(x, cache.p)

T = eltype(cache.uₙ)
n = cache.stats.nsteps
Expand Down

0 comments on commit 7a9b556

Please sign in to comment.