From c8175a9109ea70e1b5b6159ab850e483d3b6b9d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Fuhrmann?= Date: Sun, 18 Aug 2024 13:52:44 +0200 Subject: [PATCH] Fix handling of AbstractSparseMatrixCSC * `SciMLBase.init` did check only for `SparseMatrixCSC`, thus leading to dispatches for AbstractSparseMatrixCSC with awful timings. * Now, if A (resp. b) is an `AbstractSparseMatrixCSC`, instead of making a `copy` (or even `deepcopy`), a `SparseMatrixCSC` is constructed using `size`, `getcolptr`, `rowvals` and `nonzeros`. --- src/common.jl | 8 ++++++-- test/basictests.jl | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/common.jl b/src/common.jl index 3c222eca..cf91672a 100644 --- a/src/common.jl +++ b/src/common.jl @@ -158,8 +158,10 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, A = if alias_A || A isa SMatrix A - elseif A isa Array || A isa SparseMatrixCSC + elseif A isa Array copy(A) + elseif A isa AbstractSparseMatrixCSC + SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)) else deepcopy(A) end @@ -168,8 +170,10 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Array(b) # the solution to a linear solve will always be dense! elseif alias_b || b isa SVector b - elseif b isa Array || b isa SparseMatrixCSC + elseif b isa Array copy(b) + elseif b isa AbstractSparseMatrixCSC + SparseMatrixCSC(size(b)..., getcolptr(b), rowvals(b), nonzeros(b)) else deepcopy(b) end diff --git a/test/basictests.jl b/test/basictests.jl index d27c0a8d..2e69eedc 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -533,3 +533,23 @@ using BlockDiagonals @test solve(prob1, SimpleGMRES(; blocksize = 2)).u ≈ solve(prob2, SimpleGMRES()).u end + +@testset "AbstractSparseMatrixCSC" begin + struct MySparseMatrixCSC{Tv, Ti} <: SparseArrays.AbstractSparseMatrixCSC{Tv, Ti} + csc::SparseMatrixCSC{Tv, Ti} + end + + Base.size(m::MySparseMatrixCSC) = size(m.csc) + SparseArrays.getcolptr(m::MySparseMatrixCSC) = SparseArrays.getcolptr(m.csc) + SparseArrays.rowvals(m::MySparseMatrixCSC) = SparseArrays.rowvals(m.csc) + SparseArrays.nonzeros(m::MySparseMatrixCSC) = SparseArrays.nonzeros(m.csc) + + N = 10_000 + A = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1)) + u0 = ones(size(A, 2)) + b = A * u0 + B = MySparseMatrixCSC(A) + pr = LinearProblem(B, b) + @time "solve MySparseMatrixCSC" u=solve(pr) + @test norm(u - u0, Inf) < 1.0e-13 +end