diff --git a/src/dfsane.jl b/src/dfsane.jl index 47cdbbfd3..a77703906 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -68,21 +68,21 @@ function DFSane(; σ_min = 1e-10, η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2, max_inner_iterations = 1000) return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min, - σ_max, - σ_1, - M, - γ, - τ_min, - τ_max, - n_exp, - η_strategy, - max_inner_iterations) + σ_max, + σ_1, + M, + γ, + τ_min, + τ_max, + n_exp, + η_strategy, + max_inner_iterations) end -mutable struct DFSaneCache{iip, fType, algType, uType, resType, T, pType, +mutable struct DFSaneCache{iip, algType, uType, resType, T, pType, INType, tolType, probType} - f::fType + f::Function alg::algType uₙ::uType uₙ₋₁::uType @@ -109,19 +109,19 @@ mutable struct DFSaneCache{iip, fType, algType, uType, resType, T, pType, abstol::tolType prob::probType stats::NLStats - function DFSaneCache{iip}(f::fType, alg::algType, uₙ::uType, uₙ₋₁::uType, + function DFSaneCache{iip}(f::Function, alg::algType, uₙ::uType, uₙ₋₁::uType, fuₙ::resType, fuₙ₋₁::resType, 𝒹::uType, ℋ::Vector{T}, f₍ₙₒᵣₘ₎ₙ₋₁::T, f₍ₙₒᵣₘ₎₀::T, M::Int, σₙ::T, σₘᵢₙ::T, σₘₐₓ::T, α₁::T, γ::T, τₘᵢₙ::T, τₘₐₓ::T, nₑₓₚ::Int, p::pType, force_stop::Bool, maxiters::Int, internalnorm::INType, retcode::SciMLBase.ReturnCode.T, abstol::tolType, prob::probType, - stats::NLStats) where {iip, fType, algType, uType, + stats::NLStats) where {iip, algType, uType, resType, T, pType, INType, tolType, probType } - new{iip, fType, algType, uType, resType, T, pType, INType, tolType, + new{iip, algType, uType, resType, T, pType, INType, tolType, probType }(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, @@ -146,7 +146,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, p = prob.p T = eltype(uₙ) - σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min), T(alg.τ_max) + σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min), + T(alg.τ_max) α₁ = one(T) γ = T(alg.γ) f₍ₙₒᵣₘ₎ₙ₋₁ = α₁ @@ -262,16 +263,16 @@ function perform_step!(cache::DFSaneCache{false}) σₙ = sign(σₙ) * clamp(abs(σₙ), σₘᵢₙ, σₘₐₓ) # Line search direction - @. cache.𝒹 = -σₙ * cache.fuₙ₋₁ + cache.𝒹 = -σₙ * cache.fuₙ₋₁ - η = alg.ηₛ(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁) + η = alg.η_strategy(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁) f̄ = maximum(cache.ℋ) α₊ = α₁ α₋ = α₁ - @. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹 + cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹 - cache.fuₙ .= f(cache.uₙ) + cache.fuₙ = f(cache.uₙ) f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ for _ in 1:(cache.alg.max_inner_iterations) 𝒸 = f̄ + η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ @@ -282,9 +283,9 @@ function perform_step!(cache::DFSaneCache{false}) (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), τₘᵢₙ * α₊, τₘₐₓ * α₊) - @. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹 + cache.uₙ = @. cache.uₙ₋₁ - α₋ * cache.𝒹 - cache.fuₙ .= f(cache.uₙ) + cache.fuₙ = f(cache.uₙ) f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break @@ -293,8 +294,8 @@ function perform_step!(cache::DFSaneCache{false}) τₘᵢₙ * α₋, τₘₐₓ * α₋) - @. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹 - cache.fuₙ .= f(cache.uₙ) + cache.uₙ = @. cache.uₙ₋₁ + α₊ * cache.𝒹 + cache.fuₙ = f(cache.uₙ) f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ end @@ -303,11 +304,11 @@ function perform_step!(cache::DFSaneCache{false}) end # Update spectral parameter - @. cache.uₙ₋₁ = cache.uₙ - cache.uₙ₋₁ - @. cache.fuₙ₋₁ = cache.fuₙ - cache.fuₙ₋₁ + cache.uₙ₋₁ = @. cache.uₙ - cache.uₙ₋₁ + cache.fuₙ₋₁ = @. cache.fuₙ - cache.fuₙ₋₁ α₊ = sum(abs2, cache.uₙ₋₁) - @. cache.uₙ₋₁ = cache.uₙ₋₁ * cache.fuₙ₋₁ + cache.uₙ₋₁ = @. cache.uₙ₋₁ * cache.fuₙ₋₁ α₋ = sum(cache.uₙ₋₁) cache.σₙ = α₊ / α₋ @@ -318,8 +319,8 @@ function perform_step!(cache::DFSaneCache{false}) end # Take step - @. cache.uₙ₋₁ = cache.uₙ - @. cache.fuₙ₋₁ = cache.fuₙ + cache.uₙ₋₁ = cache.uₙ + cache.fuₙ₋₁ = cache.fuₙ cache.f₍ₙₒᵣₘ₎ₙ₋₁ = f₍ₙₒᵣₘ₎ₙ # Update history @@ -345,32 +346,34 @@ function SciMLBase.solve!(cache::DFSaneCache) end function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} - cache.p = p - if iip - recursivecopy!(cache.uₙ, u0) - recursivecopy!(cache.uₙ₋₁, u0) - cache.f(cache.fuₙ, cache.uₙ, p) - cache.f(cache.fuₙ₋₁, cache.uₙ, p) - else - cache.uₙ = u0 - cache.uₙ₋₁ = u0 - cache.fuₙ = cache.f(cache.uₙ, p) - cache.fuₙ₋₁ = cache.f(cache.uₙ, p) - end - - cache.f₍ₙₒᵣₘ₎ₙ₋₁ = norm(fuₙ₋₁)^nₑₓₚ - cache.f₍ₙₒᵣₘ₎₀ = cache.f₍ₙₒᵣₘ₎ₙ₋₁ - fill!(cache.ℋ, cache.f₍ₙₒᵣₘ₎ₙ₋₁, cache.M) - - T = eltype(cache.uₙ) - cache.σₙ = T(cache.alg.σ_1) - - cache.abstol = abstol - cache.maxiters = maxiters - cache.stats.nf = 1 - cache.stats.nsteps = 1 - cache.force_stop = false - cache.retcode = ReturnCode.Default - return cache + abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + cache.p = p + if iip + recursivecopy!(cache.uₙ, u0) + recursivecopy!(cache.uₙ₋₁, u0) + cache.f = (dx, x) -> cache.prob.f(dx, x, p) + cache.f(cache.fuₙ, cache.uₙ) + cache.f(cache.fuₙ₋₁, cache.uₙ) + else + cache.uₙ = u0 + cache.uₙ₋₁ = u0 + cache.f = (x) -> cache.prob.f(x, p) + cache.fuₙ = cache.f(cache.uₙ) + cache.fuₙ₋₁ = cache.f(cache.uₙ) + end + + cache.f₍ₙₒᵣₘ₎ₙ₋₁ = norm(cache.fuₙ₋₁)^cache.nₑₓₚ + cache.f₍ₙₒᵣₘ₎₀ = cache.f₍ₙₒᵣₘ₎ₙ₋₁ + fill!(cache.ℋ, cache.f₍ₙₒᵣₘ₎ₙ₋₁) + + T = eltype(cache.uₙ) + cache.σₙ = T(cache.alg.σ_1) + + cache.abstol = abstol + cache.maxiters = maxiters + cache.stats.nf = 1 + cache.stats.nsteps = 1 + cache.force_stop = false + cache.retcode = ReturnCode.Default + return cache end diff --git a/test/basictests.jl b/test/basictests.jl index 1e1ded563..0ad70d8eb 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -390,3 +390,148 @@ end end end end + + +# --- DFSane tests --- + +@testset "DFSane" begin + function benchmark_nlsolve_oop(f, u0, p=2.0) + prob = NonlinearProblem{false}(f, u0, p) + return solve(prob, DFSane(), abstol=1e-9) + end + + function benchmark_nlsolve_iip(f, u0, p=2.0) + prob = NonlinearProblem{true}(f, u0, p) + return solve(prob, DFSane(), abstol=1e-9) + end + + u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) + + @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s + sol = benchmark_nlsolve_oop(quadratic_f, u0) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + + cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), DFSane(), + abstol=1e-9) + @test (@ballocated solve!($cache)) < 200 + end + + @testset "[IIP] u0: $(typeof(u0))" for u0 in ([ + 1.0, 1.0],) + sol = benchmark_nlsolve_iip(quadratic_f!, u0) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + + cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), + DFSane(), abstol=1e-9) + @test (@ballocated solve!($cache)) ≤ 64 + end + + + @testset "[OOP] [Immutable AD]" begin + broken_forwarddiff = [1.6, 2.9, 3.0, 3.5, 4.0, 81.0] + for p in 1.1:0.1:100.0 + res = abs.(benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p).u) + + if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res) + @test_broken all(res .≈ sqrt(p)) + @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) + elseif p in broken_forwarddiff + @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) + else + @test all(res .≈ sqrt(p)) + @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p))) + end + end + end + + @testset "[OOP] [Scalar AD]" begin + broken_forwarddiff = [1.6, 2.9, 3.0, 3.5, 4.0, 81.0] + for p in 1.1:0.1:100.0 + res = abs(benchmark_nlsolve_oop(quadratic_f, 1.0, p).u) + + if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res) + @test_broken res ≈ sqrt(p) + @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) + elseif p in broken_forwarddiff + @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) + else + @test res ≈ sqrt(p) + @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)), 1 / (2 * sqrt(p))) + end + end + end + + t = (p) -> [sqrt(p[2] / p[1])] + p = [0.9, 50.0] + @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) + @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], + p) ≈ ForwardDiff.jacobian(t, p) + + # Iterator interface + function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip} + probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin]) + cache = init(probN, DFSane(); maxiters=100, abstol=1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, iip ? [cache.uₙ[1]] : cache.uₙ; p=p) + sol = solve!(cache) + sols[i] = iip ? sol.u[1] : sol.u + end + return sols + end + p = range(0.01, 2, length=200) + @test abs.(nlprob_iterator_interface(quadratic_f, p, Val(false))) ≈ sqrt.(p) + @test abs.(nlprob_iterator_interface(quadratic_f!, p, Val(true))) ≈ sqrt.(p) + + + # Test that `DFSane` passes a test that `NewtonRaphson` fails on. + @testset "Newton Raphson Fails" begin + u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0] + p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + sol = benchmark_nlsolve_oop(newton_fails, u0, p) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(newton_fails(sol.u, p)) .< 1e-9) + end + + # Test kwargs in `DFSane` + @testset "Keyword Arguments" begin + σ_min = [1e-10, 1e-5, 1e-4] + σ_max = [1e10, 1e5, 1e4] + σ_1 = [1.0, 0.5, 2.0] + M = [10, 1, 100] + γ = [1e-4, 1e-3, 1e-5] + τ_min = [0.1, 0.2, 0.3] + τ_max = [0.5, 0.8, 0.9] + nexp = [2, 1, 2] + η_strategy = [ + (f_1, k, x, F) -> f_1 / k^2, + (f_1, k, x, F) -> f_1 / k^3, + (f_1, k, x, F) -> f_1 / k^4, + ] + + list_of_options = zip(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp, + η_strategy) + for options in list_of_options + local probN, sol, alg + alg = DFSane(σ_min=options[1], + σ_max=options[2], + σ_1=options[3], + M=options[4], + γ=options[5], + τ_min=options[6], + τ_max=options[7], + n_exp=options[8], + η_strategy=options[9]) + + probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0) + sol = solve(probN, alg, abstol=1e-11) + println(abs.(quadratic_f(sol.u, 2.0))) + @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) + end + end +end