Skip to content

Commit

Permalink
feat: add partial SimpleKlement Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 17, 2024
1 parent 4f970f7 commit fe515f2
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 3 deletions.
16 changes: 16 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -36,12 +40,24 @@ ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
DiffEqBase = "6.155"
DifferentiationInterface = "0.5.17"
FastClosures = "0.3.2"
FiniteDiff = "2.24.0"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
MaybeInplace = "0.1.4"
NonlinearSolveBase = "1"
PrecompileTools = "1.2"
Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.50"
StaticArraysCore = "1.4.3"
Tracker = "0.2.35"
julia = "1.10"

[extras]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["InteractiveUtils", "Test", "TestItemRunner"]
10 changes: 7 additions & 3 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,33 @@
module SimpleNonlinearSolve

using CommonSolve: CommonSolve, solve
using FastClosures: @closure
using MaybeInplace: @bb
using PrecompileTools: @compile_workload, @setup_workload
using Reexport: @reexport
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
using StaticArraysCore: StaticArray

# AD Dependencies
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
using DifferentiationInterface: DifferentiationInterface
# TODO: move these to extensions in a breaking change. These are not even used in the
# package, but are used to trigger the extension loading in DI.jl
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff

using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
using NonlinearSolveBase: ImmutableNonlinearProblem
using NonlinearSolveBase: ImmutableNonlinearProblem, get_tolerance

const DI = DifferentiationInterface

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

is_extension_loaded(::Val) = false

include("utils.jl")

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function CommonSolve.solve(prob::NonlinearProblem,
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
Expand Down
47 changes: 47 additions & 0 deletions lib/SimpleNonlinearSolve/src/klement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
SimpleKlement()
A low-overhead implementation of `Klement` [klement2014using](@citep). This
method is non-allocating on scalar and static array problems.
"""
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = Utils.maybe_unaliased(prob.u0, alias_u0)
T = eltype(x)

abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

@bb δx = copy(x)
@bb fprev = copy(fx)
@bb xo = copy(x)
@bb d = copy(x)

J = one.(x)
@bb δx² = similar(x)

for _ in 1:maxiters
any(iszero, J) && (J = Utils.identity_jacobian!!(J))

@bb @. δx = fprev / J

@bb @. x = xo - δx
fx = Utils.eval_f(prob, fx, x)

# Termination Checks
# tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
tc_sol !== nothing && return tc_sol

@bb δx .*= -1
@bb @. δx² = δx^2 * J^2
@bb @. J += (fx - fprev - J * δx) / ifelse(iszero(δx²), T(1e-5), δx²) * δx * (J^2)

@bb copyto!(fprev, fx)
@bb copyto!(xo, x)
end

return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end
107 changes: 107 additions & 0 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
module Utils

using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra, I, diagind
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem,
NonlinearFunction
using StaticArraysCore: StaticArray, SArray, SMatrix, SVector

const DI = DifferentiationInterface

const safe_similar = NonlinearSolveBase.Utils.safe_similar

pickchunksize(n::Int) = min(n, 12)

can_dual(::Type{<:Real}) = true
can_dual(::Type) = false

maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x
function maybe_unaliased(x::T, alias::Bool) where {T <: AbstractArray}
(alias || !ArrayInterface.can_setindex(T)) && return x
return copy(x)
end

function get_concrete_autodiff(_, ad::AbstractADType)
DI.check_available(ad) && return ad
error("AD Backend $(ad) is not available. This could be because you haven't loaded the \
actual backend (See [Differentiation Inferface Docs](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/) \

Check warning on line 31 in lib/SimpleNonlinearSolve/src/utils.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Inferface" should be "Interface".
for more details) or the backend might not be supported by DifferentiationInferface.jl.")

Check warning on line 32 in lib/SimpleNonlinearSolve/src/utils.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Inferface" should be "Interface".
end
function get_concrete_autodiff(
prob, ad::Union{AutoForwardDiff{nothing}, AutoPolyesterForwardDiff{nothing}})
return get_concrete_autodiff(prob,
ArrayInterface.parameterless_type(ad)(;
chunksize = pickchunksize(length(prob.u0)), ad.tag))
end
function get_concrete_autodiff(prob, ::Nothing)
if can_dual(eltype(prob.u0)) && DI.check_available(AutoForwardDiff())
return AutoForwardDiff(; chunksize = pickchunksize(length(prob.u0)))
end
DI.check_available(AutoFiniteDiff()) && return AutoFiniteDiff()
error("Default AD backends are not available. Please load either FiniteDiff or \
ForwardDiff for default AD selection to work. Else provide a specific AD \
backend (instead of `nothing`) to the solver.")
end

# NOTE: This doesn't initialize the `f(x)` but just returns a buffer of the same size
function get_fx(prob::NonlinearLeastSquaresProblem, x)
if SciMLBase.isinplace(prob) && prob.f.resid_prototype === nothing
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype` to be \
specified.")
end
return get_fx(prob.f, x, prob.p)
end
function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x)
return get_fx(prob.f, x, prob.p)
end
function get_fx(f::NonlinearFunction, x, p)
if SciMLBase.isinplace(f)
f.resid_prototype === nothing && return eltype(x).(f.resid_prototype)
return safe_similar(x)
end
return f(x, p)
end

function eval_f(prob, fx, x)
SciMLBase.isinplace(prob) || return prob.f(x, prob.p)
prob.f(fx, x, prob.p)
return fx
end

function fixed_parameter_function(prob::AbstractNonlinearProblem)
SciMLBase.isinplace(prob) && return @closure (du, u) -> prob.f(du, u, prob.p)
return Base.Fix2(prob.f, prob.p)
end

# __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
# function __init_identity_jacobian(u, fu, α = true)
# J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
# fill!(J, zero(eltype(J)))
# J[diagind(J)] .= eltype(J)(α)
# return J
# end
# function __init_identity_jacobian(u::StaticArray, fu, α = true)
# S1, S2 = length(fu), length(u)
# J = SMatrix{S1, S2, eltype(u)}(I * α)
# return J
# end

identity_jacobian!!(J::Number) = one(J)
function identity_jacobian!!(J::AbstractVector)
ArrayInterface.can_setindex(J) || return one.(J)
fill!(J, true)
return J
end
function identity_jacobian!!(J::AbstractMatrix)
ArrayInterface.can_setindex(J) || return convert(typeof(J), I)
J[diagind(J)] .= true
return J
end
identity_jacobian!!(::SMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(I)
identity_jacobian!!(::SVector{S1, T}) where {S1, T} = ones(SVector{S1, T})

end
4 changes: 4 additions & 0 deletions lib/SimpleNonlinearSolve/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
using TestItemRunner, InteractiveUtils

@info sprint(InteractiveUtils.versioninfo)

@run_package_tests

0 comments on commit fe515f2

Please sign in to comment.