From ab451fe5a4d1a9bae8a603b2bbbfd6ef54d5bea3 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Fri, 2 Jun 2023 08:02:21 -0700
Subject: [PATCH] add JET tests and rely less on constant prop

---
 Project.toml         |  5 +++--
 src/LinearSolve.jl   | 25 ++++++++++++++++++++-
 src/common.jl        |  2 +-
 src/default.jl       | 52 --------------------------------------------
 src/factorization.jl | 42 +++++++++++++++++++++--------------
 test/default_algs.jl | 15 ++++++++++++-
 6 files changed, 68 insertions(+), 73 deletions(-)

diff --git a/Project.toml b/Project.toml
index 24004dd9e..1a587ca84 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LinearSolve"
 uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 authors = ["SciML"]
-version = "2.1.0"
+version = "2.1.1"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -60,6 +60,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
@@ -69,7 +70,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
+test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
 
 [weakdeps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index 5c25f52c0..09513c6de 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -55,6 +55,29 @@ _isidentity_struct(::SciMLBase.DiffEqIdentity) = true
 
 const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
 
+EnumX.@enumx DefaultAlgorithmChoice begin
+    LUFactorization
+    QRFactorization
+    DiagonalFactorization
+    DirectLdiv!
+    SparspakFactorization
+    KLUFactorization
+    UMFPACKFactorization
+    KrylovJL_GMRES
+    GenericLUFactorization
+    RFLUFactorization
+    LDLtFactorization
+    BunchKaufmanFactorization
+    CHOLMODFactorization
+    SVDFactorization
+    CholeskyFactorization
+    NormalCholeskyFactorization
+end
+
+struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
+    alg::DefaultAlgorithmChoice.T
+end
+
 include("common.jl")
 include("factorization.jl")
 include("simplelu.jl")
@@ -74,7 +97,7 @@ include("deprecated.jl")
             cache.cacheval = fact
             cache.isfresh = false
         end
-        y = _ldiv!(cache.u, get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
+        y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
             cache.b)
 
         #=
diff --git a/src/common.jl b/src/common.jl
index 210e5c52e..262fa2011 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -121,7 +121,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
     verbose::Bool = false,
     Pl = IdentityOperator(size(prob.A)[1]),
     Pr = IdentityOperator(size(prob.A)[2]),
-    assumptions = OperatorAssumptions(Val(issquare(prob.A))),
+    assumptions = OperatorAssumptions(issquare(prob.A)),
     kwargs...)
     @unpack A, b, u0, p = prob
 
diff --git a/src/default.jl b/src/default.jl
index a4ab615aa..4724ce747 100644
--- a/src/default.jl
+++ b/src/default.jl
@@ -1,26 +1,3 @@
-EnumX.@enumx DefaultAlgorithmChoice begin
-    LUFactorization
-    QRFactorization
-    DiagonalFactorization
-    DirectLdiv!
-    SparspakFactorization
-    KLUFactorization
-    UMFPACKFactorization
-    KrylovJL_GMRES
-    GenericLUFactorization
-    RFLUFactorization
-    LDLtFactorization
-    BunchKaufmanFactorization
-    CHOLMODFactorization
-    SVDFactorization
-    CholeskyFactorization
-    NormalCholeskyFactorization
-end
-
-struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
-    alg::DefaultAlgorithmChoice.T
-end
-
 mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
     T11, T12, T13, T14, T15, T16}
     LUFactorization::T1
@@ -313,35 +290,6 @@ cache.cacheval = NamedTuple(LUFactorization = cache of LUFactorization, ...)
     Expr(:call, :DefaultLinearSolverInit, caches...)
 end
 
-"""
-if algsym === :LUFactorization
-    cache.cacheval.LUFactorization = ...
-else
-    ...
-end
-"""
-@generated function get_cacheval(cache::LinearCache, algsym::Symbol)
-    ex = :()
-    for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
-        ex = if ex == :()
-            Expr(:elseif, :(algsym === $(Meta.quot(alg))),
-                :(getfield(cache.cacheval, $(Meta.quot(alg)))))
-        else
-            Expr(:elseif, :(algsym === $(Meta.quot(alg))),
-                :(getfield(cache.cacheval, $(Meta.quot(alg)))), ex)
-        end
-    end
-    ex = Expr(:if, ex.args...)
-
-    quote
-        if cache.alg isa DefaultLinearSolver
-            $ex
-        else
-            cache.cacheval
-        end
-    end
-end
-
 function defaultalg_symbol(::Type{T}) where {T}
     Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end])
 end
diff --git a/src/factorization.jl b/src/factorization.jl
index b95c7ded7..c79566ec0 100644
--- a/src/factorization.jl
+++ b/src/factorization.jl
@@ -1,3 +1,13 @@
+macro get_cacheval(cache, algsym)
+    quote
+        if $(esc(cache)).alg isa DefaultLinearSolver
+            getfield($(esc(cache)).cacheval, $algsym)
+        else
+            $(esc(cache)).cacheval
+        end
+    end
+end
+
 _ldiv!(x, A, b) = ldiv!(x, A, b)
 
 function _ldiv!(x::Vector, A::Factorization, b::Vector)
@@ -712,11 +722,11 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
             if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
                  cache.cacheval.colptr &&
                  SuiteSparse.decrement(SparseArrays.getrowval(A)) ==
-                 get_cacheval(cache, :UMFPACKFactorization).rowval)
+                 @get_cacheval(cache, :UMFPACKFactorization).rowval)
                fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                    nonzeros(A)))
            else
-               fact = lu!(get_cacheval(cache, :UMFPACKFactorization),
+               fact = lu!(@get_cacheval(cache, :UMFPACKFactorization),
                    SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                        nonzeros(A)))
            end
@@ -727,7 +737,7 @@
         cache.isfresh = false
     end
 
-    y = ldiv!(cache.u, get_cacheval(cache, :UMFPACKFactorization), cache.b)
+    y = ldiv!(cache.u, @get_cacheval(cache, :UMFPACKFactorization), cache.b)
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
@@ -782,7 +792,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
     A = convert(AbstractMatrix, A)
 
     if cache.isfresh
-        cacheval = get_cacheval(cache, :KLUFactorization)
+        cacheval = @get_cacheval(cache, :KLUFactorization)
         if cacheval !== nothing && alg.reuse_symbolic
             if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
                  cacheval.colptr &&
@@ -811,7 +821,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
         cache.isfresh = false
     end
 
-    y = ldiv!(cache.u, get_cacheval(cache, :KLUFactorization), cache.b)
+    y = ldiv!(cache.u, @get_cacheval(cache, :KLUFactorization), cache.b)
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
@@ -863,7 +873,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs.
     A = convert(AbstractMatrix, A)
 
     if cache.isfresh
-        cacheval = get_cacheval(cache, :CHOLMODFactorization)
+        cacheval = @get_cacheval(cache, :CHOLMODFactorization)
         fact = cholesky(A; check = false)
         if !LinearAlgebra.issuccess(fact)
             ldlt!(fact, A; check = false)
@@ -872,7 +882,7 @@
         cache.isfresh = false
     end
 
-    cache.u .= get_cacheval(cache, :CHOLMODFactorization) \ cache.b
+    cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
     SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
 end
 
@@ -928,7 +938,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
     kwargs...) where {P, T}
     A = cache.A
     A = convert(AbstractMatrix, A)
-    fact, ipiv = get_cacheval(cache, :RFLUFactorization)
+    fact, ipiv = @get_cacheval(cache, :RFLUFactorization)
     if cache.isfresh
         if length(ipiv) != min(size(A)...)
             ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
@@ -937,7 +947,7 @@
         cache.cacheval = (fact, ipiv)
         cache.isfresh = false
     end
-    y = ldiv!(cache.u, get_cacheval(cache, :RFLUFactorization)[1], cache.b)
+    y = ldiv!(cache.u, @get_cacheval(cache, :RFLUFactorization)[1], cache.b)
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
@@ -1025,10 +1035,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
         cache.isfresh = false
     end
     if A isa SparseMatrixCSC
-        cache.u .= get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
+        cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
         y = cache.u
     else
-        y = ldiv!(cache.u, get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
+        y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
     end
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
@@ -1072,7 +1082,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalBunchKaufmanFactorizati
        cache.cacheval = fact
        cache.isfresh = false
    end
-   y = ldiv!(cache.u, get_cacheval(cache, :NormalBunchKaufmanFactorization), A' * cache.b)
+   y = ldiv!(cache.u, @get_cacheval(cache, :NormalBunchKaufmanFactorization), A' * cache.b)
    SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
@@ -1131,7 +1141,7 @@ end
 function SciMLBase.solve!(cache::LinearCache, alg::FastLUFactorization; kwargs...)
     A = cache.A
     A = convert(AbstractMatrix, A)
-    ws_and_fact = get_cacheval(cache, :FastLUFactorization)
+    ws_and_fact = @get_cacheval(cache, :FastLUFactorization)
     if cache.isfresh
         # we will fail here if A is a different *size* than in a previous version of the same cache.
         # it may instead be desirable to resize the workspace.
@@ -1201,7 +1211,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
     kwargs...) where {P}
     A = cache.A
     A = convert(AbstractMatrix, A)
-    ws_and_fact = get_cacheval(cache, :FastQRFactorization)
+    ws_and_fact = @get_cacheval(cache, :FastQRFactorization)
     if cache.isfresh
         # we will fail here if A is a different *size* than in a previous version of the same cache.
         # it may instead be desirable to resize the workspace.
@@ -1281,7 +1291,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
     A = cache.A
     if cache.isfresh
         if cache.cacheval !== nothing && alg.reuse_symbolic
-            fact = sparspaklu!(get_cacheval(cache, :SparspakFactorization),
+            fact = sparspaklu!(@get_cacheval(cache, :SparspakFactorization),
                 SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                     nonzeros(A)))
         else
@@ -1291,6 +1301,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
         cache.cacheval = fact
         cache.isfresh = false
     end
-    y = ldiv!(cache.u, get_cacheval(cache, :SparspakFactorization), cache.b)
+    y = ldiv!(cache.u, @get_cacheval(cache, :SparspakFactorization), cache.b)
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
diff --git a/test/default_algs.jl b/test/default_algs.jl
index 0d42293f6..0735976e1 100644
--- a/test/default_algs.jl
+++ b/test/default_algs.jl
@@ -1,4 +1,4 @@
-using LinearSolve, LinearAlgebra, SparseArrays, Test
+using LinearSolve, LinearAlgebra, SparseArrays, Test, JET
 @test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
       LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
 @test LinearSolve.defaultalg(nothing, zeros(50)).alg ===
@@ -22,6 +22,19 @@
     A = rand(4, 4)
     b = rand(4)
     prob = LinearProblem(A, b)
+    JET.@test_opt init(prob, nothing)
+    JET.@test_opt solve(prob, LUFactorization())
+    JET.@test_opt solve(prob, GenericLUFactorization())
+    JET.@test_opt solve(prob, QRFactorization())
+    JET.@test_opt solve(prob, DiagonalFactorization())
+    #JET.@test_opt solve(prob, SVDFactorization())
+    #JET.@test_opt solve(prob, KrylovJL_GMRES())
+
+    prob = LinearProblem(sparse(A), b)
+    #JET.@test_opt solve(prob, UMFPACKFactorization())
+    #JET.@test_opt solve(prob, KLUFactorization())
+    #JET.@test_opt solve(prob, SparspakFactorization())
+    #JET.@test_opt solve(prob)
     @inferred solve(prob)
     @inferred init(prob, nothing)
 end
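
Note on the core change: the deleted @generated get_cacheval(cache, ::Symbol) in
src/default.jl selected the cache slot by branching on a runtime Symbol argument,
so concrete inference depended on constant propagation of that Symbol through the
call. The @get_cacheval macro that replaces it splices the algorithm name into
the expansion as a literal, so inference sees a constant getfield on the cacheval
NamedTuple. A minimal standalone sketch of the pattern, using simplified stand-in
types (FakeDefault/FakeCache are illustrative only, not the actual LinearSolve
internals):

    # Stand-ins for DefaultLinearSolver and LinearCache from the patch above.
    struct FakeDefault end

    mutable struct FakeCache{A, C}
        alg::A
        cacheval::C  # NamedTuple with one slot per algorithm choice
    end

    macro get_cacheval(cache, algsym)
        quote
            if $(esc(cache)).alg isa FakeDefault
                # `$algsym` splices the literal QuoteNode from the call site,
                # so this is a constant-field `getfield` as far as inference
                # is concerned -- no constant propagation required.
                getfield($(esc(cache)).cacheval, $algsym)
            else
                $(esc(cache)).cacheval
            end
        end
    end

    cache = FakeCache(FakeDefault(), (LUFactorization = 1.0, QRFactorization = 2))
    @get_cacheval(cache, :LUFactorization)  # returns 1.0, inferred as Float64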
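
On the test side, JET.@test_opt runs JET's optimization analysis on the given
call and fails (as a Test failure) if any runtime dispatch is detected, which is
why the solvers that do not yet infer cleanly are left commented out in
test/default_algs.jl. A minimal sketch of its behavior, assuming JET.jl is
installed:

    using JET, Test

    const SLOPE = 2.0          # a const global, so reads of it infer as Float64
    f_stable(x) = SLOPE * x

    JET.@test_opt f_stable(1.0)  # passes: no runtime dispatch in the call graph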