From 03d1e85051f1b39234563ab5fd8b6542ec7d7500 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 Oct 2023 16:05:06 -0400 Subject: [PATCH] make DFSane conform to the current code style --- src/dfsane.jl | 146 +++++++++++++++------------------------------ test/basictests.jl | 63 ++++++++++--------- 2 files changed, 83 insertions(+), 126 deletions(-) diff --git a/src/dfsane.jl b/src/dfsane.jl index a77703906..2a3319796 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -57,97 +57,51 @@ struct DFSane{T, F} <: AbstractNonlinearSolveAlgorithm max_inner_iterations::Int end -function DFSane(; σ_min = 1e-10, - σ_max = 1e+10, - σ_1 = 1.0, - M = 10, - γ = 1e-4, - τ_min = 0.1, - τ_max = 0.5, - n_exp = 2, - η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2, - max_inner_iterations = 1000) - return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min, - σ_max, - σ_1, - M, - γ, - τ_min, - τ_max, - n_exp, - η_strategy, - max_inner_iterations) +function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4, τ_min = 0.1, + τ_max = 0.5, n_exp = 2, η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2, + max_inner_iterations = 1000) + return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, + n_exp, η_strategy, max_inner_iterations) end -mutable struct DFSaneCache{iip, algType, uType, resType, T, pType, - INType, - tolType, - probType} - f::Function - alg::algType - uₙ::uType - uₙ₋₁::uType - fuₙ::resType - fuₙ₋₁::resType - 𝒹::uType - ℋ::Vector{T} - f₍ₙₒᵣₘ₎ₙ₋₁::T - f₍ₙₒᵣₘ₎₀::T - M::Int - σₙ::T - σₘᵢₙ::T - σₘₐₓ::T - α₁::T - γ::T - τₘᵢₙ::T - τₘₐₓ::T +@concrete mutable struct DFSaneCache{iip} + f + alg + uₙ + uₙ₋₁ + fuₙ + fuₙ₋₁ + 𝒹 + ℋ + f₍ₙₒᵣₘ₎ₙ₋₁ + f₍ₙₒᵣₘ₎₀ + M + σₙ + σₘᵢₙ + σₘₐₓ + α₁ + γ + τₘᵢₙ + τₘₐₓ nₑₓₚ::Int - p::pType + p force_stop::Bool maxiters::Int - internalnorm::INType + internalnorm retcode::SciMLBase.ReturnCode.T - abstol::tolType - prob::probType + abstol + prob stats::NLStats - function DFSaneCache{iip}(f::Function, alg::algType, uₙ::uType, uₙ₋₁::uType, - fuₙ::resType, fuₙ₋₁::resType, 𝒹::uType, ℋ::Vector{T}, - f₍ₙₒᵣₘ₎ₙ₋₁::T, f₍ₙₒᵣₘ₎₀::T, M::Int, σₙ::T, σₘᵢₙ::T, σₘₐₓ::T, - α₁::T, γ::T, τₘᵢₙ::T, τₘₐₓ::T, nₑₓₚ::Int, p::pType, - force_stop::Bool, maxiters::Int, internalnorm::INType, - retcode::SciMLBase.ReturnCode.T, abstol::tolType, - prob::probType, - stats::NLStats) where {iip, algType, uType, - resType, T, pType, INType, - tolType, - probType - } - new{iip, algType, uType, resType, T, pType, INType, tolType, - probType - }(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, M, σₙ, - σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, - τₘₐₓ, nₑₓₚ, p, force_stop, maxiters, internalnorm, - retcode, - abstol, prob, stats) - end end -function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, - args...; - alias_u0 = false, - maxiters = 1000, - abstol = 1e-6, - internalnorm = DEFAULT_NORM, - kwargs...) where {uType, iip} - if alias_u0 - uₙ = prob.u0 - else - uₙ = deepcopy(prob.u0) - end +function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...; + alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + kwargs...) where {uType, iip} + uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0) p = prob.p T = eltype(uₙ) σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min), - T(alg.τ_max) + T(alg.τ_max) α₁ = one(T) γ = T(alg.γ) f₍ₙₒᵣₘ₎ₙ₋₁ = α₁ @@ -169,10 +123,10 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, ℋ = fill(f₍ₙₒᵣₘ₎ₙ₋₁, M) return DFSaneCache{iip}(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, - M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, - τₘₐₓ, nₑₓₚ, p, false, maxiters, - internalnorm, ReturnCode.Default, abstol, prob, - NLStats(1, 0, 0, 0, 0)) + M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, + τₘₐₓ, nₑₓₚ, p, false, maxiters, + internalnorm, ReturnCode.Default, abstol, prob, + NLStats(1, 0, 0, 0, 0)) end function perform_step!(cache::DFSaneCache{true}) @@ -202,10 +156,9 @@ function perform_step!(cache::DFSaneCache{true}) f₍ₙₒᵣₘ₎ₙ ≤ 𝒸 && break - α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / - (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), - τₘᵢₙ * α₊, - τₘₐₓ * α₊) + α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), + τₘᵢₙ * α₊, + τₘₐₓ * α₊) @. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹 f(cache.fuₙ, cache.uₙ) @@ -214,8 +167,8 @@ function perform_step!(cache::DFSaneCache{true}) f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), - τₘᵢₙ * α₋, - τₘₐₓ * α₋) + τₘᵢₙ * α₋, + τₘₐₓ * α₋) @. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹 f(cache.fuₙ, cache.uₙ) @@ -279,10 +232,8 @@ function perform_step!(cache::DFSaneCache{false}) f₍ₙₒᵣₘ₎ₙ ≤ 𝒸 && break - α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / - (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), - τₘᵢₙ * α₊, - τₘₐₓ * α₊) + α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), + τₘᵢₙ * α₊, τₘₐₓ * α₊) cache.uₙ = @. cache.uₙ₋₁ - α₋ * cache.𝒹 cache.fuₙ = f(cache.uₙ) @@ -291,8 +242,7 @@ function perform_step!(cache::DFSaneCache{false}) f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), - τₘᵢₙ * α₋, - τₘₐₓ * α₋) + τₘᵢₙ * α₋, τₘₐₓ * α₋) cache.uₙ = @. cache.uₙ₋₁ + α₊ * cache.𝒹 cache.fuₙ = f(cache.uₙ) @@ -341,12 +291,12 @@ function SciMLBase.solve!(cache::DFSaneCache) cache.retcode = ReturnCode.Success end - SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ; - retcode = cache.retcode, stats = cache.stats) + return SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ; + retcode = cache.retcode, stats = cache.stats) end function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.uₙ, u0) diff --git a/test/basictests.jl b/test/basictests.jl index 0ad70d8eb..f2dba1914 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -391,18 +391,17 @@ end end end - # --- DFSane tests --- @testset "DFSane" begin - function benchmark_nlsolve_oop(f, u0, p=2.0) + function benchmark_nlsolve_oop(f, u0, p = 2.0) prob = NonlinearProblem{false}(f, u0, p) - return solve(prob, DFSane(), abstol=1e-9) + return solve(prob, DFSane(), abstol = 1e-9) end - function benchmark_nlsolve_iip(f, u0, p=2.0) + function benchmark_nlsolve_iip(f, u0, p = 2.0) prob = NonlinearProblem{true}(f, u0, p) - return solve(prob, DFSane(), abstol=1e-9) + return solve(prob, DFSane(), abstol = 1e-9) end u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) @@ -413,7 +412,7 @@ end @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), DFSane(), - abstol=1e-9) + abstol = 1e-9) @test (@ballocated solve!($cache)) < 200 end @@ -424,11 +423,10 @@ end @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), - DFSane(), abstol=1e-9) + DFSane(), abstol = 1e-9) @test (@ballocated solve!($cache)) ≤ 64 end - @testset "[OOP] [Immutable AD]" begin broken_forwarddiff = [1.6, 2.9, 3.0, 3.5, 4.0, 81.0] for p in 1.1:0.1:100.0 @@ -437,14 +435,14 @@ end if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res) @test_broken all(res .≈ sqrt(p)) @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) + @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) elseif p in broken_forwarddiff @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) + @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) else @test all(res .≈ sqrt(p)) @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p))) + @SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p))) end end end @@ -456,12 +454,22 @@ end if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res) @test_broken res ≈ sqrt(p) - @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) + @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + 1.0, + p).u, + p)) ≈ 1 / (2 * sqrt(p)) elseif p in broken_forwarddiff - @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) + @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + 1.0, + p).u, + p)) ≈ 1 / (2 * sqrt(p)) else @test res ≈ sqrt(p) - @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)), 1 / (2 * sqrt(p))) + @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + 1.0, + p).u, + p)), + 1 / (2 * sqrt(p))) end end end @@ -475,20 +483,19 @@ end # Iterator interface function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip} probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin]) - cache = init(probN, DFSane(); maxiters=100, abstol=1e-10) + cache = init(probN, DFSane(); maxiters = 100, abstol = 1e-10) sols = zeros(length(p_range)) for (i, p) in enumerate(p_range) - reinit!(cache, iip ? [cache.uₙ[1]] : cache.uₙ; p=p) + reinit!(cache, iip ? [cache.uₙ[1]] : cache.uₙ; p = p) sol = solve!(cache) sols[i] = iip ? sol.u[1] : sol.u end return sols end - p = range(0.01, 2, length=200) + p = range(0.01, 2, length = 200) @test abs.(nlprob_iterator_interface(quadratic_f, p, Val(false))) ≈ sqrt.(p) @test abs.(nlprob_iterator_interface(quadratic_f!, p, Val(true))) ≈ sqrt.(p) - # Test that `DFSane` passes a test that `NewtonRaphson` fails on. @testset "Newton Raphson Fails" begin u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0] @@ -518,18 +525,18 @@ end η_strategy) for options in list_of_options local probN, sol, alg - alg = DFSane(σ_min=options[1], - σ_max=options[2], - σ_1=options[3], - M=options[4], - γ=options[5], - τ_min=options[6], - τ_max=options[7], - n_exp=options[8], - η_strategy=options[9]) + alg = DFSane(σ_min = options[1], + σ_max = options[2], + σ_1 = options[3], + M = options[4], + γ = options[5], + τ_min = options[6], + τ_max = options[7], + n_exp = options[8], + η_strategy = options[9]) probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0) - sol = solve(probN, alg, abstol=1e-11) + sol = solve(probN, alg, abstol = 1e-11) println(abs.(quadratic_f(sol.u, 2.0))) @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) end