Skip to content

Commit

Permalink
Merge branch 'main' into bug/fix-tol-eltype
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Oct 29, 2023
2 parents 6315b55 + 1c679c6 commit f4232c0
Show file tree
Hide file tree
Showing 15 changed files with 296 additions and 37 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ jobs:
group: 'LinearSolveHYPRE'
- version: '1'
group: 'LinearSolvePardiso'
- version: '1'
group: 'LinearSolveBandedMatrices'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.11.1"
version = "2.13.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -67,7 +67,7 @@ EnzymeCore = "0.5, 0.6"
FastLapackInterface = "1, 2"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
IterativeSolvers = "0.9.2"
IterativeSolvers = "0.9.3"
KLU = "0.3.0, 0.4"
KernelAbstractions = "0.9"
Krylov = "0.9"
Expand Down Expand Up @@ -107,4 +107,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"]
1 change: 1 addition & 0 deletions docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ IterativeSolversJL_CG
IterativeSolversJL_GMRES
IterativeSolversJL_BICGSTAB
IterativeSolversJL_MINRES
IterativeSolversJL_IDRS
IterativeSolversJL
```

Expand Down
10 changes: 8 additions & 2 deletions ext/LinearSolveBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ import LinearSolve: defaultalg,
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice

# Defaults for BandedMatrices
function defaultalg(A::BandedMatrix, b, ::OperatorAssumptions)
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions)
if oa.issq
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
elseif LinearSolve.is_underdetermined(A)
error("No solver for underdetermined `A::BandedMatrix` is currently implemented!")
else
return DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
end
end

function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions)
Expand Down
22 changes: 12 additions & 10 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
d_b .= 0
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob_d_A[i]
prob_d_A[i] .+= d_A[i]
d_A[i] .= 0
for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
if _d_A !== _prob_d_A
_prob_d_A .+= _d_A
_d_A .= 0
end
if d_b !== prob_d_b[i]
prob_d_b[i] .+= d_b[i]
d_b[i] .= 0
if _d_b !== _prob_d_b
_prob_d_b .+= _d_b
_d_b .= 0
end
end
end
Expand Down Expand Up @@ -144,13 +144,15 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
_linsolve.cacheval' \ dy
elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
_linsolve.cacheval[1]' \ dy
elseif _linsolve.alg isa AbstractKrylovSubspaceMethod
elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
# Doesn't modify `A`, so it's safe to just reuse it
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
solve(invprob;
solve(invprob, _linearsolve.alg;
abstol = _linsolve.val.abstol,
reltol = _linsolve.val.reltol,
verbose = _linsolve.val.verbose)
elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
else
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
Expand All @@ -163,4 +165,4 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
return (nothing,)
end

end
end
17 changes: 16 additions & 1 deletion ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ function LinearSolve.IterativeSolversJL_GMRES(args...; kwargs...)
generate_iterator = IterativeSolvers.gmres_iterable!,
kwargs...)
end
function LinearSolve.IterativeSolversJL_IDRS(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.idrs_iterable!,
kwargs...)
end

function LinearSolve.IterativeSolversJL_BICGSTAB(args...; kwargs...)
IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.bicgstabl_iterator!,
Expand All @@ -47,6 +53,7 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
restart = (alg.gmres_restart == 0) ? min(20, size(A, 1)) : alg.gmres_restart
s = :idrs_s in keys(alg.kwargs) ? alg.kwargs.idrs_s : 4 # shadow space

kwargs = (abstol = abstol, reltol = reltol, maxiter = maxiters,
alg.kwargs...)
Expand All @@ -59,6 +66,14 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.idrs_iterable!
!!LinearSolve._isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
history = IterativeSolvers.ConvergenceHistory(partial=true)
history[:abstol] = abstol
history[:reltol] = reltol
IterativeSolvers.idrs_iterable!(history, u, A, b, s, Pl, abstol, reltol, maxiters;
alg.kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
!!LinearSolve._isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
Expand Down Expand Up @@ -95,7 +110,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...
end
cache.verbose && println()

resid = cache.cacheval.residual
resid = cache.cacheval isa IterativeSolvers.IDRSIterable ? cache.cacheval.R : cache.cacheval.residual
if resid isa IterativeSolvers.Residual
resid = resid.current
end
Expand Down
31 changes: 30 additions & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util
is_underdetermined(x) = false
is_underdetermined(A::AbstractMatrix) = size(A, 1) < size(A, 2)
is_underdetermined(A::AbstractSciMLOperator) = size(A, 1) < size(A, 2)

_isidentity_struct(A) = false
_isidentity_struct::Number) = isone(λ)
Expand Down Expand Up @@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
QRFactorizationPivoted
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand Down Expand Up @@ -143,6 +147,31 @@ end
include("factorization_sparse.jl")
end

# Solver Specific Traits
## Needs Square Matrix
"""
needs_square_A(alg)
Returns `true` if the algorithm requires a square matrix.
"""
needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg!
needs_square_A(alg::SciMLLinearSolveAlgorithm) = true
for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization,
:NormalBunchKaufmanFactorization)
@eval needs_square_A(::$(alg)) = false
end
for kralg in (Krylov.lsmr!, Krylov.craigmr!)
@eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
end
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
:MKLLUFactorization, :MetalLUFactorization)
@eval needs_square_A(::$(alg)) = true
end

const IS_OPENBLAS = Ref(true)
isopenblas() = IS_OPENBLAS[]

Expand Down Expand Up @@ -188,7 +217,7 @@ export LinearSolveFunction, DirectLdiv!
export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES, IterativeSolversJL_IDRS,
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES

export SimpleGMRES
Expand Down
91 changes: 86 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18}
T13, T14, T15, T16, T17, T18, T19}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
QRFactorizationPivoted::T19
end

# Legacy fallback
Expand Down Expand Up @@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.MKLLUFactorization
Expand Down Expand Up @@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif assump.condition === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
else
Expand Down Expand Up @@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
NormalCholeskyFactorization()
elseif alg === :AppleAccelerateLUFactorization
AppleAccelerateLUFactorization()
elseif alg === :QRFactorizationPivoted
@static if VERSION v"1.7beta"
QRFactorization(ColumnNorm())
else
QRFactorization(Val(true))
end
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
end
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization

@static if VERSION >= v"1.7"
defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted
else
defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted
end

"""
if alg.alg === DefaultAlgorithmChoice.LUFactorization
SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))
Expand Down Expand Up @@ -339,3 +362,61 @@ end
end
ex = Expr(:if, ex.args...)
end

"""
```
elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
(cache.cacheval.LUFactorization)' \\ dy
else
...
end
```
"""
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization,
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
DefaultAlgorithmChoice.RFLUFactorization))
quote
getproperty(cache.cacheval,$(Meta.quot(alg)))[1]' \ dy
end
elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
DefaultAlgorithmChoice.QRFactorization,
DefaultAlgorithmChoice.KLUFactorization,
DefaultAlgorithmChoice.UMFPACKFactorization,
DefaultAlgorithmChoice.LDLtFactorization,
DefaultAlgorithmChoice.SparspakFactorization,
DefaultAlgorithmChoice.BunchKaufmanFactorization,
DefaultAlgorithmChoice.CHOLMODFactorization,
DefaultAlgorithmChoice.SVDFactorization,
DefaultAlgorithmChoice.CholeskyFactorization,
DefaultAlgorithmChoice.NormalCholeskyFactorization,
DefaultAlgorithmChoice.QRFactorizationPivoted,
DefaultAlgorithmChoice.GenericLUFactorization))
quote
getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy
end
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,))
quote
invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
solve(invprob, cache.alg;
abstol = cache.val.abstol,
reltol = cache.val.reltol,
verbose = cache.val.verbose)
end
else
quote
error("Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
end

ex = if ex == :()
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex,
:(error("Algorithm Choice not Allowed")))
else
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, ex)
end
end
ex = Expr(:if, ex.args...)
end
15 changes: 15 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,21 @@ A wrapper over the IterativeSolvers.jl GMRES.
"""
function IterativeSolversJL_GMRES end

"""
```julia
IterativeSolversJL_IDRS(args...; Pl = nothing, kwargs...)
```
A wrapper over the IterativeSolvers.jl IDR(S).
!!! note
Using this solver requires adding the package IterativeSolvers.jl, i.e. `using IterativeSolvers`
"""
function IterativeSolversJL_IDRS end

"""
```julia
IterativeSolversJL_BICGSTAB(args...; Pl = nothing, Pr = nothing, kwargs...)
Expand Down
10 changes: 10 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ function QRFactorization(inplace = true)
QRFactorization(pivot, 16, inplace)
end

@static if VERSION v"1.7beta"
function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)
end
else
function QRFactorization(pivot::Val, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)
end
end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
Expand Down
Loading

0 comments on commit f4232c0

Please sign in to comment.