Skip to content

Commit

Permalink
Fix handling of AbstractSparseMatrixCSC
Browse files Browse the repository at this point in the history
* `SciMLBase.init` did check only for `SparseMatrixCSC`, thus leading to  dispatches for AbstractSparseMatrixCSC with awful timings.
* Now, if A (resp. b) is an `AbstractSparseMatrixCSC`, instead of making a `copy` (or even `deepcopy`), a `SparseMatrixCSC` is constructed using `size`, `getcolptr`, `rowvals` and `nonzeros`.
  • Loading branch information
j-fu committed Aug 18, 2024
1 parent 920eba0 commit c8175a9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,10 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,

A = if alias_A || A isa SMatrix
A
elseif A isa Array || A isa SparseMatrixCSC
elseif A isa Array
copy(A)
elseif A isa AbstractSparseMatrixCSC
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A))
else
deepcopy(A)
end
Expand All @@ -168,8 +170,10 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Array(b) # the solution to a linear solve will always be dense!
elseif alias_b || b isa SVector
b
elseif b isa Array || b isa SparseMatrixCSC
elseif b isa Array
copy(b)
elseif b isa AbstractSparseMatrixCSC
SparseMatrixCSC(size(b)..., getcolptr(b), rowvals(b), nonzeros(b))
else
deepcopy(b)
end
Expand Down
20 changes: 20 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,23 @@ using BlockDiagonals

@test solve(prob1, SimpleGMRES(; blocksize = 2)).u solve(prob2, SimpleGMRES()).u
end

@testset "AbstractSparseMatrixCSC" begin
struct MySparseMatrixCSC{Tv, Ti} <: SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}
csc::SparseMatrixCSC{Tv, Ti}
end

Base.size(m::MySparseMatrixCSC) = size(m.csc)
SparseArrays.getcolptr(m::MySparseMatrixCSC) = SparseArrays.getcolptr(m.csc)
SparseArrays.rowvals(m::MySparseMatrixCSC) = SparseArrays.rowvals(m.csc)
SparseArrays.nonzeros(m::MySparseMatrixCSC) = SparseArrays.nonzeros(m.csc)

N = 10_000
A = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1))
u0 = ones(size(A, 2))
b = A * u0
B = MySparseMatrixCSC(A)
pr = LinearProblem(B, b)
@time "solve MySparseMatrixCSC" u=solve(pr)
@test norm(u - u0, Inf) < 1.0e-13
end

0 comments on commit c8175a9

Please sign in to comment.