diff --git a/src/common.jl b/src/common.jl index 3c222eca..cf91672a 100644 --- a/src/common.jl +++ b/src/common.jl @@ -158,8 +158,10 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, A = if alias_A || A isa SMatrix A - elseif A isa Array || A isa SparseMatrixCSC + elseif A isa Array copy(A) + elseif A isa AbstractSparseMatrixCSC + SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)) else deepcopy(A) end @@ -168,8 +170,10 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Array(b) # the solution to a linear solve will always be dense! elseif alias_b || b isa SVector b - elseif b isa Array || b isa SparseMatrixCSC + elseif b isa Array copy(b) + elseif b isa AbstractSparseMatrixCSC + SparseMatrixCSC(size(b)..., getcolptr(b), rowvals(b), nonzeros(b)) else deepcopy(b) end diff --git a/test/basictests.jl b/test/basictests.jl index d27c0a8d..2e69eedc 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -533,3 +533,23 @@ using BlockDiagonals @test solve(prob1, SimpleGMRES(; blocksize = 2)).u ≈ solve(prob2, SimpleGMRES()).u end + +@testset "AbstractSparseMatrixCSC" begin + struct MySparseMatrixCSC{Tv, Ti} <: SparseArrays.AbstractSparseMatrixCSC{Tv, Ti} + csc::SparseMatrixCSC{Tv, Ti} + end + + Base.size(m::MySparseMatrixCSC) = size(m.csc) + SparseArrays.getcolptr(m::MySparseMatrixCSC) = SparseArrays.getcolptr(m.csc) + SparseArrays.rowvals(m::MySparseMatrixCSC) = SparseArrays.rowvals(m.csc) + SparseArrays.nonzeros(m::MySparseMatrixCSC) = SparseArrays.nonzeros(m.csc) + + N = 10_000 + A = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1)) + u0 = ones(size(A, 2)) + b = A * u0 + B = MySparseMatrixCSC(A) + pr = LinearProblem(B, b) + @time "solve MySparseMatrixCSC" u=solve(pr) + @test norm(u - u0, Inf) < 1.0e-13 +end