Skip to content

Commit

Permalink
make DFSane conform to the current code style
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2023
1 parent 3eab77d commit 7b8ca58
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 141 deletions.
175 changes: 62 additions & 113 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,97 +57,51 @@ struct DFSane{T, F} <: AbstractNonlinearSolveAlgorithm
max_inner_iterations::Int
end

function DFSane(; σ_min = 1e-10,
σ_max = 1e+10,
σ_1 = 1.0,
M = 10,
γ = 1e-4,
τ_min = 0.1,
τ_max = 0.5,
n_exp = 2,
η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations = 1000)
return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min,
σ_max,
σ_1,
M,
γ,
τ_min,
τ_max,
n_exp,
η_strategy,
max_inner_iterations)
function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4, τ_min = 0.1,

Check warning on line 60 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L60

Added line #L60 was not covered by tests
τ_max = 0.5, n_exp = 2, η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations = 1000)
return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max,

Check warning on line 63 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L63

Added line #L63 was not covered by tests
n_exp, η_strategy, max_inner_iterations)
end
mutable struct DFSaneCache{iip, algType, uType, resType, T, pType,
INType,
tolType,
probType}
f::Function
alg::algType
uₙ::uType
uₙ₋₁::uType
fuₙ::resType
fuₙ₋₁::resType
𝒹::uType
::Vector{T}
f₍ₙₒᵣₘ₎ₙ₋₁::T
f₍ₙₒᵣₘ₎₀::T
M::Int
σₙ::T
σₘᵢₙ::T
σₘₐₓ::T
α₁::T
γ::T
τₘᵢₙ::T
τₘₐₓ::T

@concrete mutable struct DFSaneCache{iip}
alg
uₙ
uₙ₋₁
fuₙ
fuₙ₋₁
𝒹
f₍ₙₒᵣₘ₎ₙ₋₁
f₍ₙₒᵣₘ₎₀
M
σₙ
σₘᵢₙ
σₘₐₓ
α₁
γ
τₘᵢₙ
τₘₐₓ
nₑₓₚ::Int
p::pType
p
force_stop::Bool
maxiters::Int
internalnorm::INType
internalnorm
retcode::SciMLBase.ReturnCode.T
abstol::tolType
prob::probType
abstol
prob
stats::NLStats
function DFSaneCache{iip}(f::Function, alg::algType, uₙ::uType, uₙ₋₁::uType,
fuₙ::resType, fuₙ₋₁::resType, 𝒹::uType, ℋ::Vector{T},
f₍ₙₒᵣₘ₎ₙ₋₁::T, f₍ₙₒᵣₘ₎₀::T, M::Int, σₙ::T, σₘᵢₙ::T, σₘₐₓ::T,
α₁::T, γ::T, τₘᵢₙ::T, τₘₐₓ::T, nₑₓₚ::Int, p::pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType,
stats::NLStats) where {iip, algType, uType,
resType, T, pType, INType,
tolType,
probType
}
new{iip, algType, uType, resType, T, pType, INType, tolType,
probType
}(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, M, σₙ,
σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ,
τₘₐₓ, nₑₓₚ, p, force_stop, maxiters, internalnorm,
retcode,
abstol, prob, stats)
end
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane,
args...;
alias_u0 = false,
maxiters = 1000,
abstol = 1e-6,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
if alias_u0
uₙ = prob.u0
else
uₙ = deepcopy(prob.u0)
end
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;

Check warning on line 96 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L96

Added line #L96 was not covered by tests
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)

Check warning on line 99 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L99

Added line #L99 was not covered by tests

p = prob.p
T = eltype(uₙ)
σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min),
T(alg.τ_max)
T(alg.τ_max)
α₁ = one(T)
γ = T(alg.γ)
f₍ₙₒᵣₘ₎ₙ₋₁ = α₁
Expand All @@ -157,27 +111,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane,
𝒹, uₙ₋₁, fuₙ, fuₙ₋₁ = copy(uₙ), copy(uₙ), copy(uₙ), copy(uₙ)

if iip
f = (dx, x) -> prob.f(dx, x, p)
f(fuₙ₋₁, uₙ₋₁)
# f = (dx, x) -> prob.f(dx, x, p)
# f(fuₙ₋₁, uₙ₋₁)
prob.f(fuₙ₋₁, uₙ₋₁, p)

Check warning on line 116 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L116

Added line #L116 was not covered by tests
else
f = (x) -> prob.f(x, p)
fuₙ₋₁ = f(uₙ₋₁)
# f = (x) -> prob.f(x, p)
fuₙ₋₁ = prob.f(uₙ₋₁, p) # f(uₙ₋₁)

Check warning on line 119 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L119

Added line #L119 was not covered by tests
end

f₍ₙₒᵣₘ₎ₙ₋₁ = norm(fuₙ₋₁)^nₑₓₚ
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁

= fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)
return DFSaneCache{iip}(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ,
τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0))
return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,

Check warning on line 126 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L126

Added line #L126 was not covered by tests
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::DFSaneCache{true})
@unpack f, alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

Check warning on line 132 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L132

Added line #L132 was not covered by tests

f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)

Check warning on line 134 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L134

Added line #L134 was not covered by tests

T = eltype(cache.uₙ)
n = cache.stats.nsteps
Expand All @@ -202,10 +156,9 @@ function perform_step!(cache::DFSaneCache{true})

f₍ₙₒᵣₘ₎ₙ 𝒸 && break

α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ /
(f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),

Check warning on line 159 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L159

Added line #L159 was not covered by tests
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
@. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹

f(cache.fuₙ, cache.uₙ)
Expand All @@ -214,8 +167,8 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break

α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)
τₘᵢₙ * α₋,
τₘₐₓ * α₋)

@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹
f(cache.fuₙ, cache.uₙ)
Expand Down Expand Up @@ -253,8 +206,9 @@ function perform_step!(cache::DFSaneCache{true})
end

function perform_step!(cache::DFSaneCache{false})
@unpack f, alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

Check warning on line 209 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L209

Added line #L209 was not covered by tests

f = iip ? (dx, x) -> cache.prob.f(dx, x, cache.p) : (x) -> cache.prob.f(x, cache.p)

Check warning on line 211 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L211

Added line #L211 was not covered by tests

T = eltype(cache.uₙ)
n = cache.stats.nsteps
Expand All @@ -279,10 +233,8 @@ function perform_step!(cache::DFSaneCache{false})

f₍ₙₒᵣₘ₎ₙ 𝒸 && break

α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ /
(f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),

Check warning on line 236 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L236

Added line #L236 was not covered by tests
τₘᵢₙ * α₊, τₘₐₓ * α₊)
cache.uₙ = @. cache.uₙ₋₁ - α₋ * cache.𝒹

cache.fuₙ = f(cache.uₙ)
Expand All @@ -291,8 +243,7 @@ function perform_step!(cache::DFSaneCache{false})
f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break

α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)
τₘᵢₙ * α₋, τₘₐₓ * α₋)

cache.uₙ = @. cache.uₙ₋₁ + α₊ * cache.𝒹
cache.fuₙ = f(cache.uₙ)
Expand Down Expand Up @@ -341,25 +292,23 @@ function SciMLBase.solve!(cache::DFSaneCache)
cache.retcode = ReturnCode.Success
end

SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ;
retcode = cache.retcode, stats = cache.stats)
return SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ;
retcode = cache.retcode, stats = cache.stats)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
recursivecopy!(cache.uₙ₋₁, u0)
cache.f = (dx, x) -> cache.prob.f(dx, x, p)
cache.f(cache.fuₙ, cache.uₙ)
cache.f(cache.fuₙ₋₁, cache.uₙ)
cache.prob.f(cache.fuₙ, cache.uₙ, p)
cache.prob.f(cache.fuₙ₋₁, cache.uₙ, p)
else
cache.uₙ = u0
cache.uₙ₋₁ = u0
cache.f = (x) -> cache.prob.f(x, p)
cache.fuₙ = cache.f(cache.uₙ)
cache.fuₙ₋₁ = cache.f(cache.uₙ)
cache.fuₙ = cache.prob.f(cache.uₙ, p)
cache.fuₙ₋₁ = cache.prob.f(cache.uₙ, p)
end

cache.f₍ₙₒᵣₘ₎ₙ₋₁ = norm(cache.fuₙ₋₁)^cache.nₑₓₚ
Expand Down
Loading

0 comments on commit 7b8ca58

Please sign in to comment.