Skip to content

Commit

Permalink
make DFSane conform to the current code style
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2023
1 parent 3eab77d commit 03d1e85
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 126 deletions.
146 changes: 48 additions & 98 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,97 +57,51 @@ struct DFSane{T, F} <: AbstractNonlinearSolveAlgorithm
max_inner_iterations::Int
end

function DFSane(; σ_min = 1e-10,
σ_max = 1e+10,
σ_1 = 1.0,
M = 10,
γ = 1e-4,
τ_min = 0.1,
τ_max = 0.5,
n_exp = 2,
η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations = 1000)
return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min,
σ_max,
σ_1,
M,
γ,
τ_min,
τ_max,
n_exp,
η_strategy,
max_inner_iterations)
function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4, τ_min = 0.1,

Check warning on line 60 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L60

Added line #L60 was not covered by tests
τ_max = 0.5, n_exp = 2, η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations = 1000)
return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max,

Check warning on line 63 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L63

Added line #L63 was not covered by tests
n_exp, η_strategy, max_inner_iterations)
end
mutable struct DFSaneCache{iip, algType, uType, resType, T, pType,
INType,
tolType,
probType}
f::Function
alg::algType
uₙ::uType
uₙ₋₁::uType
fuₙ::resType
fuₙ₋₁::resType
𝒹::uType
::Vector{T}
f₍ₙₒᵣₘ₎ₙ₋₁::T
f₍ₙₒᵣₘ₎₀::T
M::Int
σₙ::T
σₘᵢₙ::T
σₘₐₓ::T
α₁::T
γ::T
τₘᵢₙ::T
τₘₐₓ::T
@concrete mutable struct DFSaneCache{iip}
f
alg
uₙ
uₙ₋₁
fuₙ
fuₙ₋₁
𝒹
f₍ₙₒᵣₘ₎ₙ₋₁
f₍ₙₒᵣₘ₎₀
M
σₙ
σₘᵢₙ
σₘₐₓ
α₁
γ
τₘᵢₙ
τₘₐₓ
nₑₓₚ::Int
p::pType
p
force_stop::Bool
maxiters::Int
internalnorm::INType
internalnorm
retcode::SciMLBase.ReturnCode.T
abstol::tolType
prob::probType
abstol
prob
stats::NLStats
function DFSaneCache{iip}(f::Function, alg::algType, uₙ::uType, uₙ₋₁::uType,
fuₙ::resType, fuₙ₋₁::resType, 𝒹::uType, ℋ::Vector{T},
f₍ₙₒᵣₘ₎ₙ₋₁::T, f₍ₙₒᵣₘ₎₀::T, M::Int, σₙ::T, σₘᵢₙ::T, σₘₐₓ::T,
α₁::T, γ::T, τₘᵢₙ::T, τₘₐₓ::T, nₑₓₚ::Int, p::pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType,
stats::NLStats) where {iip, algType, uType,
resType, T, pType, INType,
tolType,
probType
}
new{iip, algType, uType, resType, T, pType, INType, tolType,
probType
}(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, M, σₙ,
σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ,
τₘₐₓ, nₑₓₚ, p, force_stop, maxiters, internalnorm,
retcode,
abstol, prob, stats)
end
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane,
args...;
alias_u0 = false,
maxiters = 1000,
abstol = 1e-6,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
if alias_u0
uₙ = prob.u0
else
uₙ = deepcopy(prob.u0)
end
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;

Check warning on line 96 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L96

Added line #L96 was not covered by tests
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)

Check warning on line 99 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L99

Added line #L99 was not covered by tests

p = prob.p
T = eltype(uₙ)
σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min),
T(alg.τ_max)
T(alg.τ_max)
α₁ = one(T)
γ = T(alg.γ)
f₍ₙₒᵣₘ₎ₙ₋₁ = α₁
Expand All @@ -169,10 +123,10 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane,

= fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)
return DFSaneCache{iip}(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ,
τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0))
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ,
τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::DFSaneCache{true})
Expand Down Expand Up @@ -202,10 +156,9 @@ function perform_step!(cache::DFSaneCache{true})

f₍ₙₒᵣₘ₎ₙ 𝒸 && break

α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ /
(f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),

Check warning on line 159 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L159

Added line #L159 was not covered by tests
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
@. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹

f(cache.fuₙ, cache.uₙ)
Expand All @@ -214,8 +167,8 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break

α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)
τₘᵢₙ * α₋,
τₘₐₓ * α₋)

@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹
f(cache.fuₙ, cache.uₙ)
Expand Down Expand Up @@ -279,10 +232,8 @@ function perform_step!(cache::DFSaneCache{false})

f₍ₙₒᵣₘ₎ₙ 𝒸 && break

α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ /
(f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),

Check warning on line 235 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L235

Added line #L235 was not covered by tests
τₘᵢₙ * α₊, τₘₐₓ * α₊)
cache.uₙ = @. cache.uₙ₋₁ - α₋ * cache.𝒹

cache.fuₙ = f(cache.uₙ)
Expand All @@ -291,8 +242,7 @@ function perform_step!(cache::DFSaneCache{false})
f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break

α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)
τₘᵢₙ * α₋, τₘₐₓ * α₋)

cache.uₙ = @. cache.uₙ₋₁ + α₊ * cache.𝒹
cache.fuₙ = f(cache.uₙ)
Expand Down Expand Up @@ -341,12 +291,12 @@ function SciMLBase.solve!(cache::DFSaneCache)
cache.retcode = ReturnCode.Success
end

SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ;
retcode = cache.retcode, stats = cache.stats)
return SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ;
retcode = cache.retcode, stats = cache.stats)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
Expand Down
63 changes: 35 additions & 28 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,18 +391,17 @@ end
end
end


# --- DFSane tests ---

@testset "DFSane" begin
function benchmark_nlsolve_oop(f, u0, p=2.0)
function benchmark_nlsolve_oop(f, u0, p = 2.0)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, DFSane(), abstol=1e-9)
return solve(prob, DFSane(), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p=2.0)
function benchmark_nlsolve_iip(f, u0, p = 2.0)
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, DFSane(), abstol=1e-9)
return solve(prob, DFSane(), abstol = 1e-9)
end

u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
Expand All @@ -413,7 +412,7 @@ end
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), DFSane(),
abstol=1e-9)
abstol = 1e-9)
@test (@ballocated solve!($cache)) < 200
end

Expand All @@ -424,11 +423,10 @@ end
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
DFSane(), abstol=1e-9)
DFSane(), abstol = 1e-9)
@test (@ballocated solve!($cache)) 64
end


@testset "[OOP] [Immutable AD]" begin
broken_forwarddiff = [1.6, 2.9, 3.0, 3.5, 4.0, 81.0]
for p in 1.1:0.1:100.0
Expand All @@ -437,14 +435,14 @@ end
if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken all(res .≈ sqrt(p))
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
elseif p in broken_forwarddiff
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
else
@test all(res .≈ sqrt(p))
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p)))
@SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p)))
end
end
end
Expand All @@ -456,12 +454,22 @@ end

if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken res sqrt(p)
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) 1 / (2 * sqrt(p))
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0,
p).u,
p)) 1 / (2 * sqrt(p))
elseif p in broken_forwarddiff
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) 1 / (2 * sqrt(p))
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0,
p).u,
p)) 1 / (2 * sqrt(p))
else
@test res sqrt(p)
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)), 1 / (2 * sqrt(p)))
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0,
p).u,
p)),
1 / (2 * sqrt(p)))
end
end
end
Expand All @@ -475,20 +483,19 @@ end
# Iterator interface
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN, DFSane(); maxiters=100, abstol=1e-10)
cache = init(probN, DFSane(); maxiters = 100, abstol = 1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.uₙ[1]] : cache.uₙ; p=p)
reinit!(cache, iip ? [cache.uₙ[1]] : cache.uₙ; p = p)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
return sols
end
p = range(0.01, 2, length=200)
p = range(0.01, 2, length = 200)
@test abs.(nlprob_iterator_interface(quadratic_f, p, Val(false))) sqrt.(p)
@test abs.(nlprob_iterator_interface(quadratic_f!, p, Val(true))) sqrt.(p)


# Test that `DFSane` passes a test that `NewtonRaphson` fails on.
@testset "Newton Raphson Fails" begin
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
Expand Down Expand Up @@ -518,18 +525,18 @@ end
η_strategy)
for options in list_of_options
local probN, sol, alg
alg = DFSane(σ_min=options[1],
σ_max=options[2],
σ_1=options[3],
M=options[4],
γ=options[5],
τ_min=options[6],
τ_max=options[7],
n_exp=options[8],
η_strategy=options[9])
alg = DFSane(σ_min = options[1],
σ_max = options[2],
σ_1 = options[3],
M = options[4],
γ = options[5],
τ_min = options[6],
τ_max = options[7],
n_exp = options[8],
η_strategy = options[9])

probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
sol = solve(probN, alg, abstol=1e-11)
sol = solve(probN, alg, abstol = 1e-11)
println(abs.(quadratic_f(sol.u, 2.0)))
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
end
Expand Down

0 comments on commit 03d1e85

Please sign in to comment.