Skip to content

Commit

Permalink
Merge pull request #418 from SciML/simplegmres
Browse files Browse the repository at this point in the history
For SimpleGMRES we need to reinitialize some cache when `b` is set again
  • Loading branch information
ChrisRackauckas authored Nov 6, 2023
2 parents 7ee3fa2 + fc36891 commit d914caa
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.16.0"
version = "2.16.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
8 changes: 8 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,21 @@ end
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
setfield!(cache, :isfresh, true)
elseif name === :b
# In case there is something that needs to be done when b is updated
update_cacheval!(cache, :b, x)
elseif name === :cacheval && cache.alg isa DefaultLinearSolver
@assert cache.cacheval isa DefaultLinearSolverInit
return setfield!(cache.cacheval, Symbol(cache.alg.alg), x)
end
setfield!(cache, name, x)
end

function update_cacheval!(cache::LinearCache, name::Symbol, x)
return update_cacheval!(cache, cache.cacheval, name, x)
end
update_cacheval!(cache, cacheval, name::Symbol, x) = cacheval

init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing

function SciMLBase.init(prob::LinearProblem, args...; kwargs...)
Expand Down
7 changes: 7 additions & 0 deletions src/simplegmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ end
warm_start::Bool
end

function update_cacheval!(cache::LinearCache, cacheval::SimpleGMRESCache, name::Symbol, x)
(name != :b || cache.isfresh) && return cacheval
vec(cacheval.w) .= vec(x)
fill!(cacheval.x, 0)
return cacheval
end

"""
(c, s, ρ) = _sym_givens(a, b)
Expand Down
14 changes: 8 additions & 6 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using LinearSolve, CUDA, LinearAlgebra, SparseArrays
using Test

CUDA.allowscalar(false)

n = 8
A = Matrix(I, n, n)
b = ones(n)
Expand All @@ -25,19 +27,19 @@ function test_interface(alg, prob1, prob2)
x2 = prob2.u0

y = solve(prob1, alg; cache_kwargs...)
@test A1 * y b1
@test CUDA.@allowscalar(Array(A1 * y) Array(b1))

cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
solve!(cache)
@test A1 * cache.u b1
@test CUDA.@allowscalar(Array(A1 * cache.u) Array(b1))

cache.A = copy(A2)
solve!(cache)
@test A2 * cache.u b1
@test CUDA.@allowscalar(Array(A2 * cache.u) Array(b1))

cache.b = copy(b2)
solve!(cache)
@test A2 * cache.u b2
@test CUDA.@allowscalar(Array(A2 * cache.u) Array(b2))

return
end
Expand All @@ -62,8 +64,8 @@ using BlockDiagonals
A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu
b = rand(size(A, 1)) |> cu

x1 = zero(b)
x2 = zero(b)
x1 = zero(b) |> cu
x2 = zero(b) |> cu
prob1 = LinearProblem(A, b, x1)
prob2 = LinearProblem(A, b, x2)

Expand Down

0 comments on commit d914caa

Please sign in to comment.