Skip to content

Commit

Permalink
Special case for static arrays in FastLM
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 15, 2024
1 parent acb4737 commit 1663669
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
69 changes: 35 additions & 34 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastClosures: @closure
import FastLevenbergMarquardt as FastLM
import StaticArraysCore: StaticArray
import StaticArraysCore: SArray

@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
if linsolve === :cholesky
Expand All @@ -15,53 +15,54 @@ import StaticArraysCore: StaticArray
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
end
end
@inline _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, ::SArray) where {linsolve} = linsolve

Check warning on line 18 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L18

Added line #L18 was not covered by tests

# TODO: Implement reinit
@concrete struct FastLevenbergMarquardtJLCache
f!
J!
prob
alg
lmworkspace
solver
kwargs
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
function SciMLBase.__solve(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
reltol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...)
NonlinearSolve.__test_termination_condition(termination_condition,
:FastLevenbergMarquardt)
if prob.u0 isa StaticArray # FIXME
error("FastLevenbergMarquardtJL does not support StaticArrays yet.")
end

_f!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
f! = @closure (du, u, p) -> _f!(du, u)
fn, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0,
can_handle_oop = Val(prob.u0 isa SArray))
f = if prob.u0 isa SArray
@closure (u, p) -> fn(u)
else
@closure (du, u, p) -> fn(du, u)
end
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))

_J! = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff)
J! = @closure (J, u, p) -> _J!(J, u)
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
zero(prob.f.jac_prototype)
_jac_fn = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff,
can_handle_oop = Val(prob.u0 isa SArray))
jac_fn = if prob.u0 isa SArray
@closure (u, p) -> _jac_fn(u)
else
@closure (J, u, p) -> _jac_fn(J, u)
end

solver = _fast_lm_solver(alg, u)
LM = FastLM.LMWorkspace(u, resid, J)
solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
alg.factor, alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale,
alg.factorupdate, alg.minfactor, alg.maxfactor)

return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
(; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, alg.factor,
alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale,
alg.factorupdate, alg.minfactor, alg.maxfactor))
end
if prob.u0 isa SArray
res, fx, info, iter, nfev, njev = FastLM.lmsolve(f, jac_fn, prob.u0;
solver_kwargs...)
LM, solver = nothing, nothing
else
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
zero(prob.f.jac_prototype)
solver = _fast_lm_solver(alg, u)
LM = FastLM.LMWorkspace(u, resid, J)

res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(f, jac_fn, LM;
solver, solver_kwargs...)
end

function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
cache.lmworkspace; cache.solver, cache.kwargs...)
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
return SciMLBase.build_solution(prob, alg, res, fx; retcode,
original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
end

end
25 changes: 21 additions & 4 deletions test/wrappers/nlls.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NonlinearSolve, LinearAlgebra, Test, StableRNGs, Random, ForwardDiff, Zygote
using NonlinearSolve,
LinearAlgebra, Test, StableRNGs, StaticArrays, Random, ForwardDiff, Zygote
import FastLevenbergMarquardt, LeastSquaresOptim, MINPACK

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
Expand All @@ -8,7 +9,7 @@ true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]

x = [-1.0, -0.5, 0.0, 0.5, 1.0]

y_target = true_function(x, θ_true)
const y_target = true_function(x, θ_true)

function loss_function(θ, p)
= true_function(p, θ)
Expand All @@ -34,7 +35,7 @@ autodiff in (nothing, AutoForwardDiff(), AutoFiniteDiff(), :central, :forward)]
for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
@test norm(sol.resid, Inf) < 1e-6
end

function jac!(J, θ, p)
Expand Down Expand Up @@ -76,5 +77,21 @@ append!(solvers, [CMINPACK(; method) for method in (:auto, :lm, :lmdif)])

for solver in solvers, prob in probs
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test norm(sol.resid) < 1e-6
@test norm(sol.resid, Inf) < 1e-6
end

# Static Arrays -- Fast Levenberg-Marquardt
x_sa = SA[-1.0, -0.5, 0.0, 0.5, 1.0]

const y_target_sa = true_function(x_sa, θ_true)

function loss_function_sa(θ, p)
= true_function(p, θ)
return.- y_target_sa
end

θ_init_sa = SVector{4}(θ_init)
prob_sa = NonlinearLeastSquaresProblem{false}(loss_function_sa, θ_init_sa, x)

@time sol = solve(prob_sa, FastLevenbergMarquardtJL())
@test norm(sol.resid, Inf) < 1e-6

0 comments on commit 1663669

Please sign in to comment.