Skip to content

Commit

Permalink
feat: functional Klement
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 25, 2024
1 parent ca8e8bc commit 98ef805
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 25 deletions.
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
@add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem,
isinplace
isinplace, warn_paramtype
using StaticArraysCore: StaticArray

include("public.jl")
Expand Down
4 changes: 2 additions & 2 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ for name in (:Norm, :RelNorm, :AbsNorm)

@eval begin
"""
$($struct_name) <: AbstractSafeNonlinearTerminationMode
$($struct_name) <: AbstractNonlinearTerminationMode
Terminates if $($doctring).
Expand All @@ -63,7 +63,7 @@ for name in (:Norm, :RelNorm, :AbsNorm)
$($TERM_INTERNALNORM_DOCS).
"""
struct $(struct_name){F} <: AbstractSafeNonlinearTerminationMode
struct $(struct_name){F} <: AbstractNonlinearTerminationMode
internalnorm::F

function $(struct_name)(internalnorm::F) where {F}
Expand Down
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,6 @@ function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du,
T = promote_type(eltype(du), eltype(u))
abstol = get_tolerance(abstol, T)
reltol = get_tolerance(reltol, T)
cache = init(du, u, tc; abstol, reltol)
cache = SciMLBase.init(du, u, tc; abstol, reltol)
return abstol, reltol, cache
end
6 changes: 4 additions & 2 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff

using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
using NonlinearSolveBase: ImmutableNonlinearProblem, get_tolerance
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance

const DI = DifferentiationInterface

Expand All @@ -28,6 +28,8 @@ is_extension_loaded(::Val) = false

include("utils.jl")

include("klement.jl")

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function CommonSolve.solve(prob::NonlinearProblem,
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
Expand Down Expand Up @@ -69,7 +71,7 @@ function solve_adjoint_internal end
prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2))
prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2))

algs = []
algs = [SimpleKlement()]
algs_no_iip = []

@compile_workload begin
Expand Down
5 changes: 3 additions & 2 deletions lib/SimpleNonlinearSolve/src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = Utils.maybe_unaliased(prob.u0, alias_u0)
T = eltype(x)
fx = Utils.get_fx(prob, x)

abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
Expand All @@ -32,8 +33,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement,
fx = Utils.eval_f(prob, fx, x)

# Termination Checks
# tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
tc_sol !== nothing && return tc_sol
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)

@bb δx .*= -1
@bb @. δx² = δx^2 * J^2
Expand Down
57 changes: 42 additions & 15 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ using ArrayInterface: ArrayInterface
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra, I, diagind
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearLeastSquaresProblem,
NonlinearProblem, NonlinearFunction
NonlinearProblem, NonlinearFunction, ReturnCode
using StaticArraysCore: StaticArray, SArray, SMatrix, SVector

const DI = DifferentiationInterface
Expand Down Expand Up @@ -60,7 +63,7 @@ function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x)
end
function get_fx(f::NonlinearFunction, x, p)
if SciMLBase.isinplace(f)
f.resid_prototype === nothing && return eltype(x).(f.resid_prototype)
f.resid_prototype === nothing || return eltype(x).(f.resid_prototype)
return safe_similar(x)
end
return f(x, p)
Expand All @@ -77,18 +80,18 @@ function fixed_parameter_function(prob::AbstractNonlinearProblem)
return Base.Fix2(prob.f, prob.p)
end

# __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
# function __init_identity_jacobian(u, fu, α = true)
# J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
# fill!(J, zero(eltype(J)))
# J[diagind(J)] .= eltype(J)(α)
# return J
# end
# function __init_identity_jacobian(u::StaticArray, fu, α = true)
# S1, S2 = length(fu), length(u)
# J = SMatrix{S1, S2, eltype(u)}(I * α)
# return J
# end
function identity_jacobian(u::Number, fu::Number, α = true)
return convert(promote_type(eltype(u), eltype(fu)), α)
end
function identity_jacobian(u, fu, α = true)
J = safe_similar(u, promote_type(eltype(u), eltype(fu)))
fill!(J, zero(eltype(J)))
J[diagind(J)] .= eltype(J)(α)
return J
end
function identity_jacobian(u::StaticArray, fu, α = true)
return SMatrix{length(fu), length(u), eltype(u)}(I * α)
end

identity_jacobian!!(J::Number) = one(J)
function identity_jacobian!!(J::AbstractVector)
Expand All @@ -104,4 +107,28 @@ end
identity_jacobian!!(::SMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(I)
identity_jacobian!!(::SVector{S1, T}) where {S1, T} = ones(SVector{S1, T})

# Termination Conditions
function check_termination(cache, fx, x, xo, prob)
return check_termination(cache, fx, x, xo, prob, cache.mode)
end

function check_termination(cache, fx, x, xo, _, ::AbstractNonlinearTerminationMode)
return cache(fx, x, xo), ReturnCode.Success, fx, x
end
function check_termination(cache, fx, x, xo, _, ::AbstractSafeNonlinearTerminationMode)
return cache(fx, x, xo), cache.retcode, fx, x
end
function check_termination(cache, fx, x, xo, prob, ::AbstractSafeBestNonlinearTerminationMode)
if cache(fx, x, xo)
x = cache.u
if SciMLBase.isinplace(prob)
prob.f(fx, x, prob.p)
else
fx = prob.f(x, prob.p)
end
return true, cache.retcode, fx, x
end
return false, ReturnCode.Default, fx, x
end

end
4 changes: 2 additions & 2 deletions src/internal/termination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ function update_from_termination_cache!(tc_cache, cache, u = get_u(cache))
end

function update_from_termination_cache!(
tc_cache, cache, mode::AbstractNonlinearTerminationMode, u = get_u(cache))
_, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache))
evaluate_f!(cache, u, cache.p)
end

function update_from_termination_cache!(
tc_cache, cache, mode::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache))
tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache))
if isinplace(cache)
copyto!(get_u(cache), tc_cache.u)
else
Expand Down

0 comments on commit 98ef805

Please sign in to comment.