Skip to content

Commit

Permalink
Fix check checking on previous Julia versions (#332)
Browse files Browse the repository at this point in the history
* Fix check checking on previous Julia versions

* flip check

* fix precompile for posdef

* fix test definition for symmetric

* Fix symmetric test

* only precompile sparse if v1.9

* Fix tests

* only do check on sparse for v1.9 up

* fix cholesky pivots

* try different pivots... again..

* ugh

* wrong version

* require some version bumping

We just need to drop the older versions soon, this is crazy.

* Diagonal is v1.9

* another versioning fix

* a few more versionings...

* one more

* only JET on v1.9

* format
  • Loading branch information
ChrisRackauckas authored Jun 19, 2023
1 parent c0bed7b commit b7fe5c0
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 73 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ jobs:
version:
- '1'
- '1.6'
- '1.7'
- '1.8'
include:
- version: '1'
group: 'LinearSolveHYPRE'
Expand Down
14 changes: 8 additions & 6 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,14 @@ end
end
end

PrecompileTools.@compile_workload begin
A = sprand(4, 4, 0.3) + I
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob) # in case sparspak is used as default
sol = solve(prob, SparspakFactorization())
@static if VERSION > v"1.9-"
PrecompileTools.@compile_workload begin
A = sprand(4, 4, 0.3) + I
b = rand(4)
prob = LinearProblem(A * A', b)
sol = solve(prob) # in case sparspak is used as default
sol = solve(prob, SparspakFactorization())
end
end

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
Expand Down
154 changes: 108 additions & 46 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
nothing
end

@static if VERSION < v"1.7-"
@static if VERSION < v"1.9-"
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
A::Union{Diagonal, SymTridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand Down Expand Up @@ -188,7 +188,7 @@ function init_cacheval(alg::QRFactorization, A::AbstractSciMLOperator, b, u, Pl,
nothing
end

@static if VERSION < v"1.7-"
@static if VERSION < v"1.9-"
function init_cacheval(alg::QRFactorization,
A::Union{Diagonal, SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand Down Expand Up @@ -220,7 +220,7 @@ end

function CholeskyFactorization(; pivot = nothing, tol = 0.0, shift = 0.0, perm = nothing)
if pivot === nothing
pivot = @static if VERSION < v"1.7beta"
pivot = @static if VERSION < v"1.8beta"
Val(false)
else
NoPivot()
Expand All @@ -229,16 +229,30 @@ function CholeskyFactorization(; pivot = nothing, tol = 0.0, shift = 0.0, perm =
CholeskyFactorization(pivot, 16, shift, perm)
end

function do_factorization(alg::CholeskyFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky!(A; shift = alg.shift, check = false, perm = alg.perm)
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot; check = false)
else
fact = cholesky!(A, alg.pivot; tol = alg.tol, check = false)
@static if VERSION > v"1.8-"
function do_factorization(alg::CholeskyFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky!(A; shift = alg.shift, check = false, perm = alg.perm)
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot; check = false)
else
fact = cholesky!(A, alg.pivot; tol = alg.tol, check = false)
end
return fact
end
else
function do_factorization(alg::CholeskyFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky!(A; shift = alg.shift, perm = alg.perm)
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot)
else
fact = cholesky!(A, alg.pivot; tol = alg.tol)
end
return fact
end
return fact
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
Expand All @@ -247,7 +261,7 @@ function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
end

@static if VERSION < v"1.7beta"
@static if VERSION < v"1.8beta"
cholpivot = Val(false)
else
cholpivot = NoPivot()
Expand All @@ -268,7 +282,7 @@ function init_cacheval(alg::CholeskyFactorization,
nothing
end

@static if VERSION < v"1.7beta"
@static if VERSION < v"1.9beta"
function init_cacheval(alg::CholeskyFactorization,
A::Union{SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand Down Expand Up @@ -361,7 +375,7 @@ function init_cacheval(alg::SVDFactorization, A, b, u, Pl, Pr,
nothing
end

@static if VERSION < v"1.7-"
@static if VERSION < v"1.9-"
function init_cacheval(alg::SVDFactorization,
A::Union{Diagonal, SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand Down Expand Up @@ -852,22 +866,42 @@ function init_cacheval(alg::CHOLMODFactorization,
PREALLOCATED_CHOLMOD
end

function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
@static if VERSION > v"1.8-"
function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)

if cache.isfresh
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
fact = cholesky(A; check = false)
if !LinearAlgebra.issuccess(fact)
ldlt!(fact, A; check = false)
if cache.isfresh
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
fact = cholesky(A; check = false)
if !LinearAlgebra.issuccess(fact)
ldlt!(fact, A; check = false)
end
cache.cacheval = fact
cache.isfresh = false
end
cache.cacheval = fact
cache.isfresh = false

cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
else
function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)

if cache.isfresh
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
fact = cholesky(A)
if !LinearAlgebra.issuccess(fact)
ldlt!(fact, A)
end
cache.cacheval = fact
cache.isfresh = false
end

cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
end

## RFLUFactorization
Expand Down Expand Up @@ -909,7 +943,7 @@ function init_cacheval(alg::RFLUFactorization,
nothing, nothing
end

@static if VERSION < v"1.7-"
@static if VERSION < v"1.9-"
function init_cacheval(alg::RFLUFactorization,
A::Union{Diagonal, SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
maxiters::Int,
Expand Down Expand Up @@ -954,7 +988,7 @@ end

function NormalCholeskyFactorization(; pivot = nothing)
if pivot === nothing
pivot = @static if VERSION < v"1.7beta"
pivot = @static if VERSION < v"1.8beta"
Val(false)
else
NoPivot()
Expand All @@ -966,7 +1000,7 @@ end
default_alias_A(::NormalCholeskyFactorization, ::Any, ::Any) = true
default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true

@static if VERSION < v"1.7beta"
@static if VERSION < v"1.8beta"
normcholpivot = Val(false)
else
normcholpivot = NoPivot()
Expand Down Expand Up @@ -996,7 +1030,7 @@ function init_cacheval(alg::NormalCholeskyFactorization,
nothing
end

@static if VERSION < v"1.7-"
@static if VERSION < v"1.9-"
function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{Tridiagonal, SymTridiagonal, Adjoint}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand All @@ -1005,26 +1039,54 @@ end
end
end

function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
@static if VERSION > v"1.8-"
function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC
fact = cholesky(Symmetric((A)' * A, :L); check = false)
else
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
end
cache.cacheval = fact
cache.isfresh = false
end
if A isa SparseMatrixCSC
fact = cholesky(Symmetric((A)' * A, :L))
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
else
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot)
y = ldiv!(cache.u,
@get_cacheval(cache, :NormalCholeskyFactorization),
A' * cache.b)
end
cache.cacheval = fact
cache.isfresh = false
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
if A isa SparseMatrixCSC
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
else
y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
else
function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC
fact = cholesky(Symmetric((A)' * A, :L))
else
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot)
end
cache.cacheval = fact
cache.isfresh = false
end
if A isa SparseMatrixCSC
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
else
y = ldiv!(cache.u,
@get_cacheval(cache, :NormalCholeskyFactorization),
A' * cache.b)
end
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

## NormalBunchKaufmanFactorization
Expand Down
30 changes: 17 additions & 13 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
@testset "Default Linear Solver" begin
test_interface(nothing, prob1, prob2)

A1 = prob1.A
A1 = prob1.A * prob1.A'
b1 = prob1.b
x1 = prob1.u0
y = solve(prob1)
Expand All @@ -77,9 +77,11 @@ end
y = solve(_prob)
@test A1 * y b1

_prob = LinearProblem(sparse(A1), b1; u0 = x1)
y = solve(_prob)
@test A1 * y b1
if VERSION > v"1.9-"
_prob = LinearProblem(sparse(A1), b1; u0 = x1)
y = solve(_prob)
@test A1 * y b1
end
end

@testset "UMFPACK Factorization" begin
Expand Down Expand Up @@ -258,19 +260,21 @@ end
end
end

@testset "KrylovKit" begin
kwargs = (; gmres_restart = 5)
for alg in (("Default", KrylovKitJL(kwargs...)),
("CG", KrylovKitJL_CG(kwargs...)),
("GMRES", KrylovKitJL_GMRES(kwargs...)))
@testset "$(alg[1])" begin
test_interface(alg[2], prob1, prob2)
if VERSION > v"1.9-"
@testset "KrylovKit" begin
kwargs = (; gmres_restart = 5)
for alg in (("Default", KrylovKitJL(kwargs...)),
("CG", KrylovKitJL_CG(kwargs...)),
("GMRES", KrylovKitJL_GMRES(kwargs...)))
@testset "$(alg[1])" begin
test_interface(alg[2], prob1, prob2)
end
@test alg[2] isa KrylovKitJL
end
@test alg[2] isa KrylovKitJL
end
end

if VERSION > v"1.7-"
if VERSION > v"1.9-"
@testset "CHOLMOD" begin
# Create a posdef symmetric matrix
A = sprand(100, 100, 0.01)
Expand Down
2 changes: 1 addition & 1 deletion test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ solve(prob)
prob = LinearProblem(sprand(11000, 11000, 0.5), zeros(11000))
solve(prob)

@static if VERSION >= v"v1.7-"
@static if VERSION >= v"v1.9-"
# Test inference
A = rand(4, 4)
b = rand(4)
Expand Down
16 changes: 9 additions & 7 deletions test/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ A = Symmetric([1.0 2.0
linsolve.A = A
@test solve!(linsolve).u [1.0, 0.0]

A = Symmetric([1.0 2.0
2.0 1.0])
A = [1.0 2.0
2.0 1.0]
A = Symmetric(A * A')
b = [1.0, 2.0]
prob = LinearProblem(A, b)
linsolve = init(prob, CholeskyFactorization(), alias_A = false, alias_b = false)
@test solve!(linsolve).u [1.0, 0.0]
@test solve!(linsolve).u [1.0, 0.0]
A = Symmetric([1.0 2.0
2.0 1.0])
@test solve!(linsolve).u [-1 / 3, 2 / 3]
@test solve!(linsolve).u [-1 / 3, 2 / 3]
A = [1.0 2.0
2.0 1.0]
A = Symmetric(A * A')
b = [1.0, 2.0]
@test solve!(linsolve).u [1.0, 0.0]
@test solve!(linsolve).u [-1 / 3, 2 / 3]

0 comments on commit b7fe5c0

Please sign in to comment.