From 98ef80535b304e2b12bb356787f50b9122e713d8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Sep 2024 16:50:29 -0400 Subject: [PATCH] feat: functional Klement --- .../src/NonlinearSolveBase.jl | 2 +- lib/NonlinearSolveBase/src/public.jl | 4 +- .../src/termination_conditions.jl | 2 +- .../src/SimpleNonlinearSolve.jl | 6 +- lib/SimpleNonlinearSolve/src/klement.jl | 5 +- lib/SimpleNonlinearSolve/src/utils.jl | 57 ++++++++++++++----- src/internal/termination.jl | 4 +- 7 files changed, 55 insertions(+), 25 deletions(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 1ba7a0cc3..63f4b697c 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -10,7 +10,7 @@ using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction, @add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem, - isinplace + isinplace, warn_paramtype using StaticArraysCore: StaticArray include("public.jl") diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index db8c389e2..d9014d71e 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -51,7 +51,7 @@ for name in (:Norm, :RelNorm, :AbsNorm) @eval begin """ - $($struct_name) <: AbstractSafeNonlinearTerminationMode + $($struct_name) <: AbstractNonlinearTerminationMode Terminates if $($doctring). @@ -63,7 +63,7 @@ for name in (:Norm, :RelNorm, :AbsNorm) $($TERM_INTERNALNORM_DOCS). """ - struct $(struct_name){F} <: AbstractSafeNonlinearTerminationMode + struct $(struct_name){F} <: AbstractNonlinearTerminationMode internalnorm::F function $(struct_name)(internalnorm::F) where {F} diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index a278861a8..4403e12c3 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -276,6 +276,6 @@ function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du, T = promote_type(eltype(du), eltype(u)) abstol = get_tolerance(abstol, T) reltol = get_tolerance(reltol, T) - cache = init(du, u, tc; abstol, reltol) + cache = SciMLBase.init(du, u, tc; abstol, reltol) return abstol, reltol, cache end diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 4a9f369f1..4b524e4bf 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -18,7 +18,7 @@ using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder -using NonlinearSolveBase: ImmutableNonlinearProblem, get_tolerance +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance const DI = DifferentiationInterface @@ -28,6 +28,8 @@ is_extension_loaded(::Val) = false include("utils.jl") +include("klement.jl") + # By Pass the highlevel checks for NonlinearProblem for Simple Algorithms function CommonSolve.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) @@ -69,7 +71,7 @@ function solve_adjoint_internal end prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2)) prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2)) - algs = [] + algs = [SimpleKlement()] algs_no_iip = [] @compile_workload begin diff --git a/lib/SimpleNonlinearSolve/src/klement.jl b/lib/SimpleNonlinearSolve/src/klement.jl index feb93d423..055f65bc3 100644 --- a/lib/SimpleNonlinearSolve/src/klement.jl +++ b/lib/SimpleNonlinearSolve/src/klement.jl @@ -11,6 +11,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, alias_u0 = false, termination_condition = nothing, kwargs...) x = Utils.maybe_unaliased(prob.u0, alias_u0) T = eltype(x) + fx = Utils.get_fx(prob, x) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) @@ -32,8 +33,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, fx = Utils.eval_f(prob, fx, x) # Termination Checks - # tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg) - tc_sol !== nothing && return tc_sol + solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob) + solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode) @bb δx .*= -1 @bb @. δx² = δx^2 * J^2 diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index a12c90f78..9008171d3 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -5,9 +5,12 @@ using ArrayInterface: ArrayInterface using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure using LinearAlgebra: LinearAlgebra, I, diagind -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, + AbstractNonlinearTerminationMode, + AbstractSafeNonlinearTerminationMode, + AbstractSafeBestNonlinearTerminationMode using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearLeastSquaresProblem, - NonlinearProblem, NonlinearFunction + NonlinearProblem, NonlinearFunction, ReturnCode using StaticArraysCore: StaticArray, SArray, SMatrix, SVector const DI = DifferentiationInterface @@ -60,7 +63,7 @@ function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x) end function get_fx(f::NonlinearFunction, x, p) if SciMLBase.isinplace(f) - f.resid_prototype === nothing && return eltype(x).(f.resid_prototype) + f.resid_prototype === nothing || return eltype(x).(f.resid_prototype) return safe_similar(x) end return f(x, p) @@ -77,18 +80,18 @@ function fixed_parameter_function(prob::AbstractNonlinearProblem) return Base.Fix2(prob.f, prob.p) end -# __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α) -# function __init_identity_jacobian(u, fu, α = true) -# J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u)) -# fill!(J, zero(eltype(J))) -# J[diagind(J)] .= eltype(J)(α) -# return J -# end -# function __init_identity_jacobian(u::StaticArray, fu, α = true) -# S1, S2 = length(fu), length(u) -# J = SMatrix{S1, S2, eltype(u)}(I * α) -# return J -# end +function identity_jacobian(u::Number, fu::Number, α = true) + return convert(promote_type(eltype(u), eltype(fu)), α) +end +function identity_jacobian(u, fu, α = true) + J = safe_similar(u, promote_type(eltype(u), eltype(fu))) + fill!(J, zero(eltype(J))) + J[diagind(J)] .= eltype(J)(α) + return J +end +function identity_jacobian(u::StaticArray, fu, α = true) + return SMatrix{length(fu), length(u), eltype(u)}(I * α) +end identity_jacobian!!(J::Number) = one(J) function identity_jacobian!!(J::AbstractVector) @@ -104,4 +107,28 @@ end identity_jacobian!!(::SMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(I) identity_jacobian!!(::SVector{S1, T}) where {S1, T} = ones(SVector{S1, T}) +# Termination Conditions +function check_termination(cache, fx, x, xo, prob) + return check_termination(cache, fx, x, xo, prob, cache.mode) +end + +function check_termination(cache, fx, x, xo, _, ::AbstractNonlinearTerminationMode) + return cache(fx, x, xo), ReturnCode.Success, fx, x +end +function check_termination(cache, fx, x, xo, _, ::AbstractSafeNonlinearTerminationMode) + return cache(fx, x, xo), cache.retcode, fx, x +end +function check_termination(cache, fx, x, xo, prob, ::AbstractSafeBestNonlinearTerminationMode) + if cache(fx, x, xo) + x = cache.u + if SciMLBase.isinplace(prob) + prob.f(fx, x, prob.p) + else + fx = prob.f(x, prob.p) + end + return true, cache.retcode, fx, x + end + return false, ReturnCode.Default, fx, x +end + end diff --git a/src/internal/termination.jl b/src/internal/termination.jl index 570bae110..e09cfcdda 100644 --- a/src/internal/termination.jl +++ b/src/internal/termination.jl @@ -45,12 +45,12 @@ function update_from_termination_cache!(tc_cache, cache, u = get_u(cache)) end function update_from_termination_cache!( - tc_cache, cache, mode::AbstractNonlinearTerminationMode, u = get_u(cache)) + _, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache)) evaluate_f!(cache, u, cache.p) end function update_from_termination_cache!( - tc_cache, cache, mode::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache)) + tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache)) if isinplace(cache) copyto!(get_u(cache), tc_cache.u) else