Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2023
1 parent a9d884f commit 890ae75
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 23 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ julia = "1.9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Expand All @@ -71,4 +72,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim"]
9 changes: 6 additions & 3 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver
f! = FunctionWrapper{iip}(prob.f, prob.p)
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!,
J = prob.f.jac_prototype, alg.autodiff,
output_length = length(prob.f.resid_prototype))
resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
prob.f.resid_prototype

lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!,
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

return LeastSquaresOptimCache(prob, alg, allocated_prob,
Expand Down
28 changes: 9 additions & 19 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
import LeastSquaresOptim

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
Expand All @@ -25,22 +26,11 @@ prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

sol = solve(prob_oop, GaussNewton(; linsolve = NormalCholeskyFactorization());
maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6

sol = solve(prob_iip, GaussNewton(; linsolve = NormalCholeskyFactorization());
maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6

sol = solve(prob_oop, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization());
maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6

sol = solve(prob_iip, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization());
maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
nlls_problems = [prob_oop, prob_iip]
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 1000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
end

0 comments on commit 890ae75

Please sign in to comment.