Skip to content

Commit

Permalink
Merge pull request #99 from SciML/genericlu
Browse files Browse the repository at this point in the history
add GenericLUFactorizations
  • Loading branch information
ChrisRackauckas authored Jan 22, 2022
2 parents 991a1da + b16a674 commit 7ed66bd
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 19 deletions.
5 changes: 4 additions & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021 Jonathan <[email protected]> and contributors
Copyright (c) 2021 SciML and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand All @@ -19,3 +19,6 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

SimpleLU.jl is derived from https://github.com/JuliaGNI/SimpleSolvers.jl under
an MIT license.
4 changes: 3 additions & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false

include("common.jl")
include("factorization.jl")
include("simplelu.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("default.jl")
Expand All @@ -45,7 +46,8 @@ const IS_OPENBLAS = Ref(true)
isopenblas() = IS_OPENBLAS[]

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
RFLUFactorization, UMFPACKFactorization, KLUFactorization
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
UMFPACKFactorization, KLUFactorization
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
Expand Down
49 changes: 33 additions & 16 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ function defaultalg(A,b)
# whether MKL or OpenBLAS is being used
if (A === nothing && !isgpu(b)) || A isa Matrix
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterface.can_setindex(b) && (length(b) <= 100 ||
(isopenblas() && length(b) <= 500)
)
alg = RFLUFactorization()
ArrayInterface.can_setindex(b)
if length(b) <= 10
alg = GenericLUFactorization()
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
alg = RFLUFactorization()
else
alg = LUFactorization()
end
else
alg = LUFactorization()
end
Expand Down Expand Up @@ -58,12 +62,19 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
# it makes sense according to the benchmarks, which is dependent on
# whether MKL or OpenBLAS is being used
if A isa Matrix
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
(isopenblas() && size(A,1) <= 500)
)
alg = RFLUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
b = cache.b
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterface.can_setindex(b)
if length(b) <= 10
alg = GenericLUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
alg = RFLUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
else
alg = LUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
end
else
alg = LUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
Expand Down Expand Up @@ -110,12 +121,18 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
# it makes sense according to the benchmarks, which is dependent on
# whether MKL or OpenBLAS is being used
if A isa Matrix
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
ArrayInterface.can_setindex(b) && (size(A,1) <= 100 ||
(isopenblas() && size(A,1) <= 500)
)
alg = RFLUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
ArrayInterface.can_setindex(b)
if length(b) <= 10
alg = GenericLUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
alg = RFLUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
else
alg = LUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
end
else
alg = LUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
Expand Down
21 changes: 20 additions & 1 deletion src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ struct LUFactorization{P} <: AbstractFactorization
pivot::P
end

struct GenericLUFactorization{P} <: AbstractFactorization
pivot::P
end

function LUFactorization()
pivot = @static if VERSION < v"1.7beta"
Val(true)
Expand All @@ -34,6 +38,15 @@ function LUFactorization()
LUFactorization(pivot)
end

function GenericLUFactorization()
pivot = @static if VERSION < v"1.7beta"
Val(true)
else
RowMaximum()
end
GenericLUFactorization(pivot)
end

function do_factorization(alg::LUFactorization, A, b, u)
A = convert(AbstractMatrix,A)
if A isa SparseMatrixCSC
Expand All @@ -44,7 +57,13 @@ function do_factorization(alg::LUFactorization, A, b, u)
return fact
end

init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
function do_factorization(alg::GenericLUFactorization, A, b, u)
A = convert(AbstractMatrix,A)
fact = LinearAlgebra.generic_lufact!(A, alg.pivot)
return fact
end

init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))

# This could be a GenericFactorization perhaps?
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
Expand Down
130 changes: 130 additions & 0 deletions src/simplelu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
## From https://github.com/JuliaGNI/SimpleSolvers.jl/blob/master/src/linear/lu_solver.jl

mutable struct LUSolver{T}
n::Int
A::Matrix{T}
b::Vector{T}
x::Vector{T}
pivots::Vector{Int}
perms::Vector{Int}
info::Int

LUSolver{T}(n) where {T} = new(n, zeros(T, n, n), zeros(T, n), zeros(T, n), zeros(Int, n), zeros(Int, n), 0)
end

function LUSolver(A::Matrix{T}) where {T}
n = LinearAlgebra.checksquare(A)
lu = LUSolver{eltype(A)}(n)
lu.A .= A
lu
end

function LUSolver(A::Matrix{T}, b::Vector{T}) where {T}
n = LinearAlgebra.checksquare(A)
@assert n == length(b)
lu = LUSolver{eltype(A)}(n)
lu.A .= A
lu.b .= b
lu
end

function simplelu_factorize!(lu::LUSolver{T}, pivot=true) where {T}
A = lu.A

begin
@inbounds for i in eachindex(lu.perms)
lu.perms[i] = i
end

@inbounds for k = 1:lu.n
# find index max
kp = k
if pivot
amax = real(zero(T))
for i = k:lu.n
absi = abs(A[i,k])
if absi > amax
kp = i
amax = absi
end
end
end
lu.pivots[k] = kp
lu.perms[k], lu.perms[kp] = lu.perms[kp], lu.perms[k]

if A[kp,k] != 0
if k != kp
# Interchange
for i = 1:lu.n
tmp = A[k,i]
A[k,i] = A[kp,i]
A[kp,i] = tmp
end
end
# Scale first column
Akkinv = inv(A[k,k])
for i = k+1:lu.n
A[i,k] *= Akkinv
end
elseif lu.info == 0
lu.info = k
end
# Update the rest
for j = k+1:lu.n
for i = k+1:lu.n
A[i,j] -= A[i,k]*A[k,j]
end
end
end

lu.info
end
end

function simplelu_solve!(lu::LUSolver{T}) where {T}
@inbounds for i = 1:lu.n
lu.x[i] = lu.b[lu.perms[i]]
end

@inbounds for i = 2:lu.n
s = zero(T)
for j = 1:i-1
s += lu.A[i,j] * lu.x[j]
end
lu.x[i] -= s
end

lu.x[lu.n] /= lu.A[lu.n,lu.n]
@inbounds for i = lu.n-1:-1:1
s = zero(T)
for j = i+1:lu.n
s += lu.A[i,j] * lu.x[j]
end
lu.x[i] -= s
lu.x[i] /= lu.A[i,i]
end

copyto!(lu.b,lu.x)

lu.x
end

### Wrapper

struct SimpleLUFactorization <: AbstractFactorization
pivot::Bool
SimpleLUFactorization(pivot=true) = new(pivot)
end

function SciMLBase.solve(cache::LinearCache, alg::SimpleLUFactorization; kwargs...)
if cache.isfresh
cache.cacheval.A = cache.A
simplelu_factorize!(cache.cacheval, alg.pivot)
end
cache.cacheval.b = cache.b
cache.cacheval.x = cache.u
y = simplelu_solve!(cache.cacheval)
SciMLBase.build_linear_solution(alg,y,nothing,cache)
end

init_cacheval(alg::SimpleLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = LUSolver(convert(AbstractMatrix,A))

0 comments on commit 7ed66bd

Please sign in to comment.