Merge pull request #318 from SciML/inference
add JET tests and rely less on constant prop
ChrisRackauckas authored Jun 2, 2023
2 parents 208cbc6 + ab451fe commit 6d678a9
Showing 6 changed files with 68 additions and 73 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.1.0"
version = "2.1.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -60,6 +60,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
@@ -69,7 +70,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
25 changes: 24 additions & 1 deletion src/LinearSolve.jl
@@ -55,6 +55,29 @@ _isidentity_struct(::SciMLBase.DiffEqIdentity) = true

const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)

+EnumX.@enumx DefaultAlgorithmChoice begin
+    LUFactorization
+    QRFactorization
+    DiagonalFactorization
+    DirectLdiv!
+    SparspakFactorization
+    KLUFactorization
+    UMFPACKFactorization
+    KrylovJL_GMRES
+    GenericLUFactorization
+    RFLUFactorization
+    LDLtFactorization
+    BunchKaufmanFactorization
+    CHOLMODFactorization
+    SVDFactorization
+    CholeskyFactorization
+    NormalCholeskyFactorization
+end
+
+struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
+    alg::DefaultAlgorithmChoice.T
+end
+
include("common.jl")
include("factorization.jl")
include("simplelu.jl")
@@ -74,7 +97,7 @@ include("deprecated.jl")
cache.cacheval = fact
cache.isfresh = false
end
-y = _ldiv!(cache.u, get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
+y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
cache.b)

#=
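
The enum and wrapper struct above are moved here from src/default.jl (removed below), presumably so they are defined before common.jl and factorization.jl, which now reference them. For context, EnumX.@enumx makes every choice a value of the single concrete type DefaultAlgorithmChoice.T, so the alg field of DefaultLinearSolver is concretely typed. A minimal sketch of the pattern, using hypothetical names (Pick, Wrapper) not taken from this diff:

using EnumX

EnumX.@enumx Pick begin
    First
    Second
end

struct Wrapper
    choice::Pick.T          # one concrete field type covers every choice
end

w = Wrapper(Pick.Second)
w.choice === Pick.Second    # true; enum values compare by identity
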
2 changes: 1 addition & 1 deletion src/common.jl
@@ -121,7 +121,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
verbose::Bool = false,
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
-assumptions = OperatorAssumptions(Val(issquare(prob.A))),
+assumptions = OperatorAssumptions(issquare(prob.A)),
kwargs...)
@unpack A, b, u0, p = prob

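
Dropping the Val wrapper fits the same theme as the rest of the PR: Val of a runtime Bool has type Val{true} or Val{false}, which inference can only pin down through constant propagation. A small standalone illustration (hypothetical functions, not from this diff):

wrap(b::Bool) = Val(b)    # inferred as Union{Val{false}, Val{true}} for runtime b
plain(b::Bool) = b        # inferred as Bool

# @code_warntype wrap(rand() > 0.5)    # small Union unless `b` constant-folds
# @code_warntype plain(rand() > 0.5)   # always concrete
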
52 changes: 0 additions & 52 deletions src/default.jl
@@ -1,26 +1,3 @@
-EnumX.@enumx DefaultAlgorithmChoice begin
-    LUFactorization
-    QRFactorization
-    DiagonalFactorization
-    DirectLdiv!
-    SparspakFactorization
-    KLUFactorization
-    UMFPACKFactorization
-    KrylovJL_GMRES
-    GenericLUFactorization
-    RFLUFactorization
-    LDLtFactorization
-    BunchKaufmanFactorization
-    CHOLMODFactorization
-    SVDFactorization
-    CholeskyFactorization
-    NormalCholeskyFactorization
-end
-
-struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
-    alg::DefaultAlgorithmChoice.T
-end
-
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16}
LUFactorization::T1
@@ -313,35 +290,6 @@ cache.cacheval = NamedTuple(LUFactorization = cache of LUFactorization, ...)
Expr(:call, :DefaultLinearSolverInit, caches...)
end

"""
if algsym === :LUFactorization
cache.cacheval.LUFactorization = ...
else
...
end
"""
@generated function get_cacheval(cache::LinearCache, algsym::Symbol)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
ex = if ex == :()
Expr(:elseif, :(algsym === $(Meta.quot(alg))),
:(getfield(cache.cacheval, $(Meta.quot(alg)))))
else
Expr(:elseif, :(algsym === $(Meta.quot(alg))),
:(getfield(cache.cacheval, $(Meta.quot(alg)))), ex)
end
end
ex = Expr(:if, ex.args...)

quote
if cache.alg isa DefaultLinearSolver
$ex
else
cache.cacheval
end
end
end

function defaultalg_symbol(::Type{T}) where {T}
Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end])
end
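
The removed get_cacheval took the algorithm name as a runtime Symbol: each branch of the generated if/elseif chain is fine on its own, but the return type is the union over all sixteen cache fields unless the Symbol constant-propagates to select a single branch, which is exactly the reliance the commit message drops. A minimal illustration of the same effect, with a hypothetical two-field NamedTuple standing in for the cache:

nt = (lu = 1, qr = 2.0)

pick(x, s::Symbol) = s === :lu ? getfield(x, :lu) : getfield(x, :qr)
# pick(nt, s) infers Union{Float64, Int64} for a runtime `s`;
# it only narrows to Int64 when `s` constant-propagates as :lu.

pick_lu(x) = getfield(x, :lu)   # literal field name: always infers Int64
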
42 changes: 26 additions & 16 deletions src/factorization.jl
@@ -1,3 +1,13 @@
+macro get_cacheval(cache, algsym)
+    quote
+        if $(esc(cache)).alg isa DefaultLinearSolver
+            getfield($(esc(cache)).cacheval, $algsym)
+        else
+            $(esc(cache)).cacheval
+        end
+    end
+end
+
_ldiv!(x, A, b) = ldiv!(x, A, b)

function _ldiv!(x::Vector, A::Factorization, b::Vector)
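Because @get_cacheval is a macro, the field name is spliced into the expansion as a literal at parse time, so the DefaultLinearSolver branch contains a constant-symbol getfield with no dependence on constant propagation. A rough, illustrative sketch of what a call site expands to:

# @get_cacheval(cache, :LUFactorization) becomes approximately:
if cache.alg isa DefaultLinearSolver
    getfield(cache.cacheval, :LUFactorization)   # literal field name
else
    cache.cacheval
end
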
@@ -712,11 +722,11 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
cache.cacheval.colptr &&
SuiteSparse.decrement(SparseArrays.getrowval(A)) ==
-get_cacheval(cache, :UMFPACKFactorization).rowval)
+@get_cacheval(cache, :UMFPACKFactorization).rowval)
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
else
-fact = lu!(get_cacheval(cache, :UMFPACKFactorization),
+fact = lu!(@get_cacheval(cache, :UMFPACKFactorization),
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end
@@ -727,7 +737,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
cache.isfresh = false
end

-y = ldiv!(cache.u, get_cacheval(cache, :UMFPACKFactorization), cache.b)
+y = ldiv!(cache.u, @get_cacheval(cache, :UMFPACKFactorization), cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

@@ -782,7 +792,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
A = convert(AbstractMatrix, A)

if cache.isfresh
-cacheval = get_cacheval(cache, :KLUFactorization)
+cacheval = @get_cacheval(cache, :KLUFactorization)
if cacheval !== nothing && alg.reuse_symbolic
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
cacheval.colptr &&
@@ -811,7 +821,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
cache.isfresh = false
end

-y = ldiv!(cache.u, get_cacheval(cache, :KLUFactorization), cache.b)
+y = ldiv!(cache.u, @get_cacheval(cache, :KLUFactorization), cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

@@ -863,7 +873,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs.
A = convert(AbstractMatrix, A)

if cache.isfresh
-cacheval = get_cacheval(cache, :CHOLMODFactorization)
+cacheval = @get_cacheval(cache, :CHOLMODFactorization)
fact = cholesky(A; check = false)
if !LinearAlgebra.issuccess(fact)
ldlt!(fact, A; check = false)
@@ -872,7 +882,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs.
cache.isfresh = false
end

-cache.u .= get_cacheval(cache, :CHOLMODFactorization) \ cache.b
+cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

@@ -928,7 +938,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
kwargs...) where {P, T}
A = cache.A
A = convert(AbstractMatrix, A)
-fact, ipiv = get_cacheval(cache, :RFLUFactorization)
+fact, ipiv = @get_cacheval(cache, :RFLUFactorization)
if cache.isfresh
if length(ipiv) != min(size(A)...)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
@@ -937,7 +947,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
cache.cacheval = (fact, ipiv)
cache.isfresh = false
end
-y = ldiv!(cache.u, get_cacheval(cache, :RFLUFactorization)[1], cache.b)
+y = ldiv!(cache.u, @get_cacheval(cache, :RFLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

@@ -1025,10 +1035,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
cache.isfresh = false
end
if A isa SparseMatrixCSC
-cache.u .= get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
+cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
else
-y = ldiv!(cache.u, get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
+y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
end
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
@@ -1072,7 +1082,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalBunchKaufmanFactorizati
cache.cacheval = fact
cache.isfresh = false
end
-y = ldiv!(cache.u, get_cacheval(cache, :NormalBunchKaufmanFactorization), A' * cache.b)
+y = ldiv!(cache.u, @get_cacheval(cache, :NormalBunchKaufmanFactorization), A' * cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

@@ -1131,7 +1141,7 @@ end
function SciMLBase.solve!(cache::LinearCache, alg::FastLUFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
-ws_and_fact = get_cacheval(cache, :FastLUFactorization)
+ws_and_fact = @get_cacheval(cache, :FastLUFactorization)
if cache.isfresh
# we will fail here if A is a different *size* than in a previous version of the same cache.
# it may instead be desirable to resize the workspace.
@@ -1201,7 +1211,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
kwargs...) where {P}
A = cache.A
A = convert(AbstractMatrix, A)
-ws_and_fact = get_cacheval(cache, :FastQRFactorization)
+ws_and_fact = @get_cacheval(cache, :FastQRFactorization)
if cache.isfresh
# we will fail here if A is a different *size* than in a previous version of the same cache.
# it may instead be desirable to resize the workspace.
@@ -1281,7 +1291,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
A = cache.A
if cache.isfresh
if cache.cacheval !== nothing && alg.reuse_symbolic
-fact = sparspaklu!(get_cacheval(cache, :SparspakFactorization),
+fact = sparspaklu!(@get_cacheval(cache, :SparspakFactorization),
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
else
@@ -1291,6 +1301,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
cache.cacheval = fact
cache.isfresh = false
end
-y = ldiv!(cache.u, get_cacheval(cache, :SparspakFactorization), cache.b)
+y = ldiv!(cache.u, @get_cacheval(cache, :SparspakFactorization), cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
15 changes: 14 additions & 1 deletion test/default_algs.jl
@@ -1,4 +1,4 @@
-using LinearSolve, LinearAlgebra, SparseArrays, Test
+using LinearSolve, LinearAlgebra, SparseArrays, Test, JET
@test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
@test LinearSolve.defaultalg(nothing, zeros(50)).alg ===
@@ -22,6 +22,19 @@ using LinearSolve, LinearAlgebra, SparseArrays, Test
A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
+JET.@test_opt init(prob, nothing)
+JET.@test_opt solve(prob, LUFactorization())
+JET.@test_opt solve(prob, GenericLUFactorization())
+JET.@test_opt solve(prob, QRFactorization())
+JET.@test_opt solve(prob, DiagonalFactorization())
+#JET.@test_opt solve(prob, SVDFactorization())
+#JET.@test_opt solve(prob, KrylovJL_GMRES())
+
+prob = LinearProblem(sparse(A), b)
+#JET.@test_opt solve(prob, UMFPACKFactorization())
+#JET.@test_opt solve(prob, KLUFactorization())
+#JET.@test_opt solve(prob, SparspakFactorization())
+#JET.@test_opt solve(prob)
@inferred solve(prob)
@inferred init(prob, nothing)
end
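
JET.@test_opt fails the test if JET's optimization analysis finds runtime dispatch reachable from the given call, which is why the PR can assert type stability directly instead of only through @inferred. A self-contained example of the kind of problem it flags (toy types, not from this PR):

using JET, Test

struct Loose
    v::AbstractVector{Float64}   # abstractly typed field
end
struct Tight
    v::Vector{Float64}           # concretely typed field
end

total(x) = sum(x.v)

JET.@test_opt total(Tight([1.0, 2.0]))    # passes: the sum call is statically resolved
# JET.@test_opt total(Loose([1.0, 2.0])) # would report runtime dispatch on sum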
