From 09a6384bb7484e7264196d33f5ec9d30268ac387 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Dec 2023 10:04:19 -0500 Subject: [PATCH] Add inference check --- src/common.jl | 4 +--- src/iterative_wrappers.jl | 4 +++- test/static_arrays.jl | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/common.jl b/src/common.jl index 765382f81..3f59b3da2 100644 --- a/src/common.jl +++ b/src/common.jl @@ -124,9 +124,7 @@ function __init_u0_from_Ab(A, b) fill!(u0, false) return u0 end -function __init_u0_from_Ab(A::SMatrix{S1, S2}, b) where {S1, S2} - return zeros(SVector{S2, eltype(b)}) -end +__init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)}) function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index b37571cb5..294cfe7f1 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -284,8 +284,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) # Copy the solution to the allocated output vector cacheval = @get_cacheval(cache, :KrylovJL_GMRES) - if cache.u !== cacheval.x + if cache.u !== cacheval.x && ArrayInterface.can_setindex(cache.u) cache.u .= cacheval.x + else + cache.u = convert(typeof(cache.u), cacheval.x) end return SciMLBase.build_linear_solution(alg, cache.u, resid, cache; diff --git a/test/static_arrays.jl b/test/static_arrays.jl index 6fc614e67..55158947e 100644 --- a/test/static_arrays.jl +++ b/test/static_arrays.jl @@ -6,6 +6,7 @@ b = SVector{5}(rand(5)) for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(), KrylovJL_GMRES()) sol = solve(LinearProblem(A, b), alg) + @inferred solve(LinearProblem(A, b), alg) @test norm(A * sol .- b) < 1e-10 end @@ -13,6 +14,7 @@ A = SMatrix{7, 5}(rand(7, 5)) b = SVector{7}(rand(7)) for alg in (nothing, SVDFactorization(), KrylovJL_LSMR()) + @inferred solve(LinearProblem(A, b), alg) @test_nowarn solve(LinearProblem(A, b), alg) end @@ -20,5 +22,6 @@ A = SMatrix{5, 7}(rand(5, 7)) b = SVector{5}(rand(5)) for alg in (nothing, SVDFactorization(), KrylovJL_LSMR()) + @inferred solve(LinearProblem(A, b), alg) @test_nowarn solve(LinearProblem(A, b), alg) end