Skip to content

Commit

Permalink
Fix 23 test problems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 27, 2023
1 parent fefe476 commit c428bd9
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 81 deletions.
5 changes: 2 additions & 3 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,

termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
termination_condition)

cache.abstol = abstol
Expand Down
27 changes: 9 additions & 18 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,18 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
𝒹, uₙ₋₁, fuₙ, fuₙ₋₁ = copy(uₙ), copy(uₙ), copy(uₙ), copy(uₙ)

if iip
# f = (dx, x) -> prob.f(dx, x, p)
# f(fuₙ₋₁, uₙ₋₁)
prob.f(fuₙ₋₁, uₙ₋₁, p)
else
# f = (x) -> prob.f(x, p)
fuₙ₋₁ = prob.f(uₙ₋₁, p) # f(uₙ₋₁)
fuₙ₋₁ = prob.f(uₙ₋₁, p)
end

f₍ₙₒᵣₘ₎ₙ₋₁ = norm(fuₙ₋₁)^nₑₓₚ
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁

= fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
T)
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
termination_condition, T)

mode = DiffEqBase.get_termination_mode(termination_condition)

Expand Down Expand Up @@ -167,14 +162,13 @@ function perform_step!(cache::DFSaneCache{true})

f(cache.fuₙ, cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
for _ in 1:(cache.alg.max_inner_iterations)
for jjj in 1:(cache.alg.max_inner_iterations)
𝒸 =+ η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁

f₍ₙₒᵣₘ₎ₙ 𝒸 && break

α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
τₘᵢₙ * α₊, τₘₐₓ * α₊)
@. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹

f(cache.fuₙ, cache.uₙ)
Expand All @@ -183,8 +177,7 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break

α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)
τₘᵢₙ * α₋, τₘₐₓ * α₋)

@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹
f(cache.fuₙ, cache.uₙ)
Expand All @@ -207,7 +200,7 @@ function perform_step!(cache::DFSaneCache{true})
# Spectral parameter bounds check
if abs(cache.σₙ) > σₘₐₓ || abs(cache.σₙ) < σₘᵢₙ
test_norm = sqrt(sum(abs2, cache.fuₙ₋₁))
cache.σₙ = clamp(1.0 / test_norm, 1, 1e5)
cache.σₙ = clamp(T(1) / test_norm, T(1), T(1e5))
end

# Take step
Expand Down Expand Up @@ -283,7 +276,7 @@ function perform_step!(cache::DFSaneCache{false})
# Spectral parameter bounds check
if abs(cache.σₙ) > σₘₐₓ || abs(cache.σₙ) < σₘᵢₙ
test_norm = sqrt(sum(abs2, cache.fuₙ₋₁))
cache.σₙ = clamp(1.0 / test_norm, 1, 1e5)
cache.σₙ = clamp(T(1) / test_norm, T(1), T(1e5))
end

# Take step
Expand Down Expand Up @@ -337,9 +330,7 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
termination_condition)

cache.abstol = abstol
Expand Down
16 changes: 5 additions & 11 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
nothing

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
reltol,
prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
end

function perform_step!(cache::GaussNewtonCache{true})
Expand All @@ -147,10 +145,7 @@ function perform_step!(cache::GaussNewtonCache{true})
@. u = u - du
f(cache.fu_new, u, p)

(termination_condition(cache.fu_new .- cache.fu1,
cache.u,
u_prev,
cache.abstol,
(termination_condition(cache.fu_new .- cache.fu1, cache.u, u_prev, cache.abstol,
cache.reltol) ||
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
(cache.force_stop = true)
Expand Down Expand Up @@ -217,9 +212,8 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,

termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,

Check warning on line 216 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L216

Added line #L216 was not covered by tests
termination_condition)

cache.abstol = abstol
Expand Down
5 changes: 2 additions & 3 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,9 @@ function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = ca
cache.fu = cache.f(cache.u, p)
end

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
Expand Down
24 changes: 2 additions & 22 deletions src/pseudotransient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
return PseudoTransient{_unwrap_val(concrete_jac)}(ad, linsolve, precs, alpha_initial)
end

@concrete mutable struct PseudoTransientCache{iip}
@concrete mutable struct PseudoTransientCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
Expand All @@ -78,8 +78,6 @@ end
tc_storage
end

isinplace(::PseudoTransientCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
Expand Down Expand Up @@ -174,22 +172,6 @@ function perform_step!(cache::PseudoTransientCache{false})
return nothing
end

function SciMLBase.solve!(cache::PseudoTransientCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
alpha_new,
abstol = cache.abstol, reltol = cache.reltol,
Expand All @@ -205,9 +187,7 @@ function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = c
cache.fu1 = cache.f(cache.u, p)
end

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
termination_condition)

cache.alpha = convert(eltype(cache.u), alpha_new)
Expand Down
5 changes: 2 additions & 3 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cac
cache.fu1 = cache.f(cache.u, p)
end

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
Expand Down
5 changes: 2 additions & 3 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,8 @@ function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,

termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
termination_condition)

cache.abstol = abstol
Expand Down
36 changes: 19 additions & 17 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,16 @@ end
function _init_termination_elements(abstol, reltol, termination_condition,
::Type{T}; mode = NLSolveTerminationMode.AbsNorm) where {T}
if termination_condition !== nothing
abstol !== nothing ?
(abstol != termination_condition.abstol ?
error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") :
nothing) : nothing
reltol !== nothing ?
(reltol != termination_condition.abstol ?
error("Incompatible relative tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.") :
nothing) : nothing
if abstol !== nothing && abstol != termination_condition.abstol
error("Incompatible absolute tolerances found. The tolerances supplied as the \

Check warning on line 222 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L222

Added line #L222 was not covered by tests
keyword argument and the one supplied in the termination condition should \
be same.")
end
if reltol !== nothing && reltol != termination_condition.reltol
error("Incompatible relative tolerances found. The tolerances supplied as the \

Check warning on line 227 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L227

Added line #L227 was not covered by tests
keyword argument and the one supplied in the termination condition should \
be same.")
end
abstol = _get_tolerance(abstol, termination_condition.abstol, T)
reltol = _get_tolerance(reltol, termination_condition.reltol, T)
return abstol, reltol, termination_condition
Expand All @@ -239,18 +241,18 @@ end

function _get_reinit_termination_condition(cache, abstol, reltol, termination_condition)
if termination_condition != cache.termination_condition
if abstol != cache.abstol
if abstol != termination_condition.abstol
error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.")
end
if abstol != cache.abstol && abstol != termination_condition.abstol
error("Incompatible absolute tolerances found. The tolerances supplied as the \

Check warning on line 245 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L244-L245

Added lines #L244 - L245 were not covered by tests
keyword argument and the one supplied in the termination condition \
should be same.")
end

if reltol != cache.reltol
if reltol != termination_condition.reltol
error("Incompatible absolute tolerances found. The tolerances supplied as the keyword argument and the one supplied in the termination condition should be same.")
end
if reltol != cache.reltol && reltol != termination_condition.reltol
error("Incompatible absolute tolerances found. The tolerances supplied as the \

Check warning on line 251 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L250-L251

Added lines #L250 - L251 were not covered by tests
keyword argument and the one supplied in the termination condition \
should be same.")
end
termination_condition
return termination_condition

Check warning on line 255 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L255

Added line #L255 was not covered by tests
else
# Build the termination_condition with new abstol and reltol
return NLSolveTerminationCondition{
Expand Down
3 changes: 2 additions & 1 deletion test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4)
@testset "$idx: $(dict["title"])" begin
for alg in alg_ops
try
sol = solve(nlprob, alg, abstol = 1e-18, reltol = 1e-18)
sol = solve(nlprob, alg)
problem(res, sol.u, nothing)

broken = idx in broken_tests[alg] ? true : false
@test norm(res)ϵ broken=broken
catch
Expand Down

0 comments on commit c428bd9

Please sign in to comment.