From 3d12bd8fb6ac830ce07731cdd7f71e52b9ba0b79 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 22 Jan 2022 05:09:06 -0500 Subject: [PATCH 1/4] add GenericLUFactorizations --- src/LinearSolve.jl | 2 +- src/factorization.jl | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 695f5ab96..6825fb055 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -45,7 +45,7 @@ const IS_OPENBLAS = Ref(true) isopenblas() = IS_OPENBLAS[] export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization, - RFLUFactorization, UMFPACKFactorization, KLUFactorization + GenericLUFactorization, RFLUFactorization, UMFPACKFactorization, KLUFactorization export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES, IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES, IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES diff --git a/src/factorization.jl b/src/factorization.jl index 7b144e05f..d0d3b3cfc 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -25,6 +25,10 @@ struct LUFactorization{P} <: AbstractFactorization pivot::P end +struct GenericLUFactorization{P} <: AbstractFactorization + pivot::P +end + function LUFactorization() pivot = @static if VERSION < v"1.7beta" Val(true) @@ -34,6 +38,15 @@ function LUFactorization() LUFactorization(pivot) end +function GenericLUFactorization() + pivot = @static if VERSION < v"1.7beta" + Val(true) + else + RowMaximum() + end + GenericLUFactorization(pivot) +end + function do_factorization(alg::LUFactorization, A, b, u) A = convert(AbstractMatrix,A) if A isa SparseMatrixCSC @@ -44,7 +57,13 @@ function do_factorization(alg::LUFactorization, A, b, u) return fact end -init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) +function do_factorization(alg::GenericLUFactorization, A, b, u) + A = convert(AbstractMatrix,A) + fact = LinearAlgebra.generic_lufact!(A, alg.pivot) + return fact +end + +init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) # This could be a GenericFactorization perhaps? Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization From f336a163e2629aa910f1a319fb01048959151d19 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 22 Jan 2022 05:36:50 -0500 Subject: [PATCH 2/4] improve defaults and add simplelu --- LICENSE | 5 +- src/LinearSolve.jl | 4 +- src/default.jl | 48 +++++++++++------ src/simplelu.jl | 132 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 18 deletions(-) create mode 100644 src/simplelu.jl diff --git a/LICENSE b/LICENSE index 5e8b648f5..5feb19e1a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 Jonathan and contributors +Copyright (c) 2021 SciML and contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -19,3 +19,6 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +SimpleLU.jl is derived from https://github.com/JuliaGNI/SimpleSolvers.jl under +an MIT license. diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 6825fb055..76b0f78a8 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -36,6 +36,7 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false include("common.jl") include("factorization.jl") +include("simplelu.jl") include("iterative_wrappers.jl") include("preconditioners.jl") include("default.jl") @@ -45,7 +46,8 @@ const IS_OPENBLAS = Ref(true) isopenblas() = IS_OPENBLAS[] export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization, - GenericLUFactorization, RFLUFactorization, UMFPACKFactorization, KLUFactorization + GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, + UMFPACKFactorization, KLUFactorization export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES, IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES, IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES diff --git a/src/default.jl b/src/default.jl index 3d97c7b82..3d0682c30 100644 --- a/src/default.jl +++ b/src/default.jl @@ -11,10 +11,14 @@ function defaultalg(A,b) # whether MKL or OpenBLAS is being used if (A === nothing && !isgpu(b)) || A isa Matrix if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && - ArrayInterface.can_setindex(b) && (length(b) <= 100 || - (isopenblas() && length(b) <= 500) - ) - alg = RFLUFactorization() + ArrayInterface.can_setindex(b) + if length(b) <= 10 + alg = GenericLUFactorization() + elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) + alg = RFLUFactorization() + else + alg = LUFactorization() + end else alg = LUFactorization() end @@ -58,12 +62,18 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing, # it makes sense according to the benchmarks, which is dependent on # whether MKL or OpenBLAS is being used if A isa Matrix - if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} && - ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 || - (isopenblas() && size(A,1) <= 500) - ) - alg = RFLUFactorization() - SciMLBase.solve(cache, alg, args...; kwargs...) + if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && + ArrayInterface.can_setindex(b) + if length(b) <= 10 + alg = GenericLUFactorization() + SciMLBase.solve(cache, alg, args...; kwargs...) + elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) + alg = RFLUFactorization() + SciMLBase.solve(cache, alg, args...; kwargs...) + else + alg = LUFactorization() + SciMLBase.solve(cache, alg, args...; kwargs...) + end else alg = LUFactorization() SciMLBase.solve(cache, alg, args...; kwargs...) @@ -110,12 +120,18 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, # it makes sense according to the benchmarks, which is dependent on # whether MKL or OpenBLAS is being used if A isa Matrix - if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} && - ArrayInterface.can_setindex(b) && (size(A,1) <= 100 || - (isopenblas() && size(A,1) <= 500) - ) - alg = RFLUFactorization() - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) + if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && + ArrayInterface.can_setindex(b) + if length(b) <= 10 + alg = GenericLUFactorization() + init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) + elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) + alg = RFLUFactorization() + init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) + else + alg = LUFactorization() + init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) + end else alg = LUFactorization() init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) diff --git a/src/simplelu.jl b/src/simplelu.jl new file mode 100644 index 000000000..b89af413a --- /dev/null +++ b/src/simplelu.jl @@ -0,0 +1,132 @@ +## From https://github.com/JuliaGNI/SimpleSolvers.jl/blob/master/src/linear/lu_solver.jl + +mutable struct LUSolver{T} + n::Int + A::Matrix{T} + b::Vector{T} + x::Vector{T} + pivots::Vector{Int} + perms::Vector{Int} + info::Int + + LUSolver{T}(n) where {T} = new(n, zeros(T, n, n), zeros(T, n), zeros(T, n), zeros(Int, n), zeros(Int, n), 0) +end + +function LUSolver(A::Matrix{T}) where {T} + n = LinearAlgebra.checksquare(A) + lu = LUSolver{eltype(A)}(n) + lu.A .= A + lu +end + +function LUSolver(A::Matrix{T}, b::Vector{T}) where {T} + n = LinearAlgebra.checksquare(A) + @assert n == length(b) + lu = LUSolver{eltype(A)}(n) + lu.A .= A + lu.b .= b + lu +end + +function simplelu_factorize!(lu::LUSolver{T}, pivot=true) where {T} + A = lu.A + + begin + @inbounds for i in eachindex(lu.perms) + lu.perms[i] = i + end + + @inbounds for k = 1:lu.n + # find index max + kp = k + if pivot + amax = real(zero(T)) + for i = k:lu.n + absi = abs(A[i,k]) + if absi > amax + kp = i + amax = absi + end + end + end + lu.pivots[k] = kp + lu.perms[k], lu.perms[kp] = lu.perms[kp], lu.perms[k] + + if A[kp,k] != 0 + if k != kp + # Interchange + for i = 1:lu.n + tmp = A[k,i] + A[k,i] = A[kp,i] + A[kp,i] = tmp + end + end + # Scale first column + Akkinv = inv(A[k,k]) + for i = k+1:lu.n + A[i,k] *= Akkinv + end + elseif lu.info == 0 + lu.info = k + end + # Update the rest + for j = k+1:lu.n + for i = k+1:lu.n + A[i,j] -= A[i,k]*A[k,j] + end + end + end + + lu.info + end +end + +function simplelu_solve!(lu::LUSolver{T}) where {T} + local s::T + + @inbounds for i = 1:lu.n + lu.x[i] = lu.b[lu.perms[i]] + end + + @inbounds for i = 2:lu.n + s = 0 + for j = 1:i-1 + s += lu.A[i,j] * lu.x[j] + end + lu.x[i] -= s + end + + lu.x[lu.n] /= lu.A[lu.n,lu.n] + @inbounds for i = lu.n-1:-1:1 + s = 0 + for j = i+1:lu.n + s += lu.A[i,j] * lu.x[j] + end + lu.x[i] -= s + lu.x[i] /= lu.A[i,i] + end + + lu.b .= lu.x + + lu.x +end + +### Wrapper + +struct SimpleLUFactorization <: AbstractFactorization + pivot::Bool + SimpleLUFactorization(pivot=true) = new(pivot) +end + +function SciMLBase.solve(cache::LinearCache, alg::SimpleLUFactorization; kwargs...) + if cache.isfresh + cache.cacheval.A = cache.A + simplelu_factorize!(cache.cacheval, alg.pivot) + end + cache.cacheval.b = cache.b + cache.cacheval.x = cache.u + y = simplelu_solve!(cache.cacheval) + SciMLBase.build_linear_solution(alg,y,nothing,cache) +end + +init_cacheval(alg::SimpleLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = LUSolver(convert(AbstractMatrix,A)) From 1996cae9a2d00ff832219785050618eb1b1eaef2 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 22 Jan 2022 06:10:48 -0500 Subject: [PATCH 3/4] simplify the simplelu code a bit --- src/simplelu.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/simplelu.jl b/src/simplelu.jl index b89af413a..3d372791a 100644 --- a/src/simplelu.jl +++ b/src/simplelu.jl @@ -82,14 +82,12 @@ function simplelu_factorize!(lu::LUSolver{T}, pivot=true) where {T} end function simplelu_solve!(lu::LUSolver{T}) where {T} - local s::T - @inbounds for i = 1:lu.n lu.x[i] = lu.b[lu.perms[i]] end @inbounds for i = 2:lu.n - s = 0 + s = zero(T) for j = 1:i-1 s += lu.A[i,j] * lu.x[j] end @@ -98,7 +96,7 @@ function simplelu_solve!(lu::LUSolver{T}) where {T} lu.x[lu.n] /= lu.A[lu.n,lu.n] @inbounds for i = lu.n-1:-1:1 - s = 0 + s = zero(T) for j = i+1:lu.n s += lu.A[i,j] * lu.x[j] end @@ -106,7 +104,7 @@ function simplelu_solve!(lu::LUSolver{T}) where {T} lu.x[i] /= lu.A[i,i] end - lu.b .= lu.x + copyto!(lu.b,lu.x) lu.x end From b16a674a934b6cd6e91418d4b9e8ff8a5b6307fe Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 22 Jan 2022 06:19:01 -0500 Subject: [PATCH 4/4] fix typo --- src/default.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/default.jl b/src/default.jl index 3d0682c30..98ca563f7 100644 --- a/src/default.jl +++ b/src/default.jl @@ -62,6 +62,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing, # it makes sense according to the benchmarks, which is dependent on # whether MKL or OpenBLAS is being used if A isa Matrix + b = cache.b if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) && ArrayInterface.can_setindex(b) if length(b) <= 10