From 8ac858d2c3128f5947f2b7784f9dd4fe3ec06e94 Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 7 Jun 2024 01:08:16 +0200 Subject: [PATCH] Fix returncodes and stats for iterative solvers and add test coverage. --- src/iterative_wrappers.jl | 2 +- test/retcodes.jl | 46 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 test/retcodes.jl diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 3c03ea7d6..bb93ba632 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -306,5 +306,5 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) end return SciMLBase.build_linear_solution(alg, cache.u, resid, cache; - iters = stats.niter) + iters = stats.niter, retcode, stats) end diff --git a/test/retcodes.jl b/test/retcodes.jl new file mode 100644 index 000000000..ab9eeaa15 --- /dev/null +++ b/test/retcodes.jl @@ -0,0 +1,46 @@ +@testset "Return codes" begin + alglist = ( + LUFactorization, + QRFactorization, + DiagonalFactorization, + DirectLdiv!, + SparspakFactorization, + KLUFactorization, + UMFPACKFactorization, + KrylovJL_GMRES, + GenericLUFactorization, + RFLUFactorization, + LDLtFactorization, + BunchKaufmanFactorization, + CHOLMODFactorization, + SVDFactorization, + CholeskyFactorization, + NormalCholeskyFactorization, + AppleAccelerateLUFactorization, + MKLLUFactorization, + KrylovJL_CRAIGMR, + KrylovJL_LSMR, + ) + + @testset "Success" begin + for alg in alglist + A = [2.0 1.0; -1.0 1.0] + b = [-1.0, 1.0] + prob = LinearProblem(A, b) + linsolve = init(prob, alg) + sol = solve!(linsolve) + @test SciMLBase.successful_retcode(sol.retcode) || sol.retcode == ReturnCode.Default # The latter seems off... + end + end + + @testset "Failure" begin + for alg in alglist + A = [1.0 1.0; 1.0 1.0] + b = [-1.0, 1.0] + prob = LinearProblem(A, b) + linsolve = init(prob, alg) + sol = solve!(linsolve) + @test !SciMLBase.successful_retcode(sol.retcode) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 4994eba23..8d7b626fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension) if GROUP == "All" || GROUP == "Core" @time @safetestset "Quality Assurance" include("qa.jl") @time @safetestset "Basic Tests" include("basictests.jl") + @time @safetestset "Return codes" include("retcodes.jl") @time @safetestset "Re-solve" include("resolve.jl") @time @safetestset "Zero Initialization Tests" include("zeroinittests.jl") @time @safetestset "Non-Square Tests" include("nonsquare.jl")