Skip to content

Commit

Permalink
Delegate common operations to an abstract type
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2023
1 parent 4392343 commit d0d3db4
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 47 deletions.
20 changes: 20 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,32 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end

abstract type AbstractNonlinearSolveCache{iip} end

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

include("utils.jl")
include("linesearch.jl")
include("raphson.jl")
Expand Down
21 changes: 2 additions & 19 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end

@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType}
@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <:
AbstractNonlinearSolveCache{iip}
f
alg
u::uType
Expand Down Expand Up @@ -138,8 +139,6 @@ end
stats::NLStats
end

isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
Expand Down Expand Up @@ -313,19 +312,3 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.λ_factor = cache.damping_increase_factor
return nothing
end

function SciMLBase.solve!(cache::LevenbergMarquardtCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end
26 changes: 1 addition & 25 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@ for large-scale and numerically-difficult nonlinear systems.
linesearch
end

concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
end

@concrete mutable struct NewtonRaphsonCache{iip}
@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
Expand All @@ -67,8 +65,6 @@ end
lscache
end

isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
Expand Down Expand Up @@ -146,23 +142,3 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
5 changes: 2 additions & 3 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU
expand_threshold, shrink_factor, expand_factor, max_shrink_times)
end

@concrete mutable struct TrustRegionCache{iip, trustType, floatType}
@concrete mutable struct TrustRegionCache{iip, trustType, floatType} <:
AbstractNonlinearSolveCache{iip}
f
alg
u_prev
Expand Down Expand Up @@ -299,8 +300,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
NLStats(1, 0, 0, 0, 0))
end

isinplace(::TrustRegionCache{iip}) where {iip} = iip

function perform_step!(cache::TrustRegionCache{true})
@unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
if cache.make_new_J
Expand Down

0 comments on commit d0d3db4

Please sign in to comment.