From 890ae75efab759a55b4e54e6684a3633405aee03 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 14 Oct 2023 20:17:15 -0400 Subject: [PATCH] Add tests --- Project.toml | 3 ++- ext/NonlinearSolveLeastSquaresOptimExt.jl | 9 +++++--- test/nonlinear_least_squares.jl | 28 ++++++++--------------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index bd2fca7ae..b4f7c2553 100644 --- a/Project.toml +++ b/Project.toml @@ -58,6 +58,7 @@ julia = "1.9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141" @@ -71,4 +72,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"] +test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim"] diff --git a/ext/NonlinearSolveLeastSquaresOptimExt.jl b/ext/NonlinearSolveLeastSquaresOptimExt.jl index b0e562a3b..7514931d5 100644 --- a/ext/NonlinearSolveLeastSquaresOptimExt.jl +++ b/ext/NonlinearSolveLeastSquaresOptimExt.jl @@ -41,9 +41,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver f! = FunctionWrapper{iip}(prob.f, prob.p) g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p) - lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!, - J = prob.f.jac_prototype, alg.autodiff, - output_length = length(prob.f.resid_prototype)) + resid_prototype = prob.f.resid_prototype === nothing ? + (!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) : + prob.f.resid_prototype + + lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!, + J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype)) allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg)) return LeastSquaresOptimCache(prob, alg, allocated_prob, diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index 27775bc40..c9fda61ba 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -1,4 +1,5 @@ using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random +import LeastSquaresOptim true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]) true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])) @@ -25,22 +26,11 @@ prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x) prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x) -sol = solve(prob_oop, GaussNewton(; linsolve = NormalCholeskyFactorization()); - maxiters = 1000, abstol = 1e-8) -@test SciMLBase.successful_retcode(sol) -@test norm(sol.resid) < 1e-6 - -sol = solve(prob_iip, GaussNewton(; linsolve = NormalCholeskyFactorization()); - maxiters = 1000, abstol = 1e-8) -@test SciMLBase.successful_retcode(sol) -@test norm(sol.resid) < 1e-6 - -sol = solve(prob_oop, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization()); - maxiters = 1000, abstol = 1e-8) -@test SciMLBase.successful_retcode(sol) -@test norm(sol.resid) < 1e-6 - -sol = solve(prob_iip, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization()); - maxiters = 1000, abstol = 1e-8) -@test SciMLBase.successful_retcode(sol) -@test norm(sol.resid) < 1e-6 +nlls_problems = [prob_oop, prob_iip] +solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)] + +for prob in nlls_problems, solver in solvers + @time sol = solve(prob, solver; maxiters = 1000, abstol = 1e-8) + @test SciMLBase.successful_retcode(sol) + @test norm(sol.resid) < 1e-6 +end