Skip to content

Commit

Permalink
block_minres is almost functional!
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 1, 2024
1 parent c18c377 commit ad85d45
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 93 deletions.
18 changes: 12 additions & 6 deletions src/block_krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@ mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
p :: Int
ΔX :: SM
X :: SM
W :: SM
P :: SM
Q :: SM
C :: SM
D :: SM
τ :: SV
Φ :: SM
tmp :: SM
Vₖ₋₁ :: SM
Vₖ :: SM
wₖ₋₂ :: SM
wₖ₋₁ :: SM
Hₖ₋₂ :: SM
Hₖ₋₁ :: SM
τₖ₋₂ :: SV
τₖ₋₁ :: SV
warm_start :: Bool
stats :: SimpleStats{T}
end
Expand All @@ -45,19 +48,22 @@ function BlockMinresSolver(m, n, p, SV, SM)
T = real(FC)
ΔX = SM(undef, 0, 0)
X = SM(undef, n, p)
W = SM(undef, n, p)
P = SM(undef, 0, 0)
Q = SM(undef, 0, 0)
Q = SM(undef, n, p)
C = SM(undef, p, p)
D = SM(undef, 2p, p)
τ = SV(undef, p)
Φ = SM(undef, p, p)
tmp = C isa Matrix ? SM(undef, 0, 0) : SM(undef, p, p)
Vₖ₋₁ = SM(undef, n, p)
Vₖ = SM(undef, n, p)
wₖ₋₂ = SM(undef, n, p)
wₖ₋₁ = SM(undef, n, p)
Hₖ₋₂ = SM(undef, 2p, p)
Hₖ₋₁ = SM(undef, 2p, p)
τₖ₋₂ = SV(undef, p)
τₖ₋₁ = SV(undef, p)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, τ, tmp, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, false, stats)
solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, P, Q, C, D, Φ, tmp, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, Hₖ₋₂, Hₖ₋₁, τₖ₋₂, τₖ₋₁, false, stats)
return solver
end

Expand Down
172 changes: 101 additions & 71 deletions src/block_minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,32 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his

# Set up workspace.
Vₖ₋₁, Vₖ = solver.Vₖ₋₁, solver.Vₖ
ΔX, X, W, Z = solver.ΔX, solver.X, solver.W, solver.Z
C, D, R, H, τ, stats = solver.C, solver.D, solver.R, solver.H, solver.τ, solver.stats
ΔX, X, Q, C = solver.ΔX, solver.X, solver.Q, solver.C
D, Φ, stats = solver.D, solver.Φ, solver.stats
wₖ₋₂, wₖ₋₁ = solver.wₖ₋₂, solver.wₖ₋₁
Hₖ₋₂, Hₖ₋₁ = solver.Hₖ₋₂, solver.Hₖ₋₁
τₖ₋₂, τₖ₋₁ = solver.τₖ₋₂, solver.τₖ₋₁
warm_start = solver.warm_start
RNorms = stats.residuals
reset!(stats)
R₀ = warm_start ? W : B

# Temporary buffers -- should be stored in the solver
Ψₖ₋₁ = zeros(p, p)
Ψₖ = zeros(p, p)
Ωₖ = zeros(p, p)
Ψₖ₊₁ = zeros(p, p)
Πₖ₋₂ = zeros(p, p)
Γbarₖ₋₁ = zeros(p, p)
Γₖ₋₁ = zeros(p, p)
Λbarₖ = zeros(p, p)
Λₖ = zeros(p, p)

# Define the blocks D1 and D2
D1 = view(D, 1:p, :)
D2 = view(D, p+1:2p, :)
trans = FC <: AbstractFloat ? 'T' : 'C'
Φbarₖ = Φₖ = Φbarₖ₊₁ = Φ

# Coefficients for mul!
α = -one(FC)
Expand All @@ -128,8 +142,8 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his

# Initial residual R₀.
if warm_start
mul!(W, A, ΔX)
W .= B .- W
mul!(Q, A, ΔX)
Q .= B .- Q
end
RNorm = norm(R₀) # ‖R₀‖_F
history && push!(RNorms, RNorm)
Expand All @@ -139,7 +153,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his

ε = atol + rtol * RNorm
(verbose > 0) && @printf(iostream, "%5s %7s %5s\n", "k", "‖Rₖ‖", "timer")
kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, RNorm, ktimer(start_time))
kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, RNorm, start_time |> ktimer)

# Stopping criterion
status = "unknown"
Expand All @@ -148,131 +162,147 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
user_requested_exit = false
overtimed = false

# Initial Ψ₁ and V₁
τ = τₖ₋₂
copyto!(Vₖ, R₀)
if C isa Matrix
householder!(Vₖ, Φbarₖ, τ)
else
householder!(Vₖ, Φbarₖ, τ, solver.tmp)
end

while !(solved || tired || user_requested_exit || overtimed)
# Update iteration index.
iter = iter + 1

# Initial Ψ₁ and V₁
copyto!(V, R₀)
if C isa Matrix
householder!(V, Z, τ)
else
householder!(V, Z, τ, solver.tmp)
end

# Continue the block-Lanczos process.
mul!(W, A, V) # Q ← AVₖ
for i = 1 : inner_iter
mul!(Ω, Vₖ', W) # Ωₖ = Vₖᴴ * Q
(iter ≥ 2) && mul!(Q, Vₖ₋₁, Ψ') # Q ← Q - Vₖ₋₁ * Ψₖᴴ
mul!(Q, Vₖ, Ω, α, β) # Q = Q - Vₖ * Ωₖ
end

# Vₖ₊₁ and Ψₖ₊₁ are stored in Q and C.
if C isa Matrix
householder!(Q, C, τ)
else
householder!(Q, C, τ, solver.tmp)
end
mul!(Q, A, Vₖ) # Q ← AVₖ
mul!(Ωₖ, Vₖ', Q) # Ωₖ = Vₖᴴ * Q
(iter ≥ 2) && mul!(Q, Vₖ₋₁, Ψₖ', α, β) # Q ← Q - Vₖ₋₁ * Ψₖᴴ
mul!(Q, Vₖ, Ωₖ, α, β) # Q = Q - Vₖ * Ωₖ

# Update the QR factorization of Tₖ₊₁.ₖ = Qₖ [ Rₖ ].
# [ Oᵀ ]
#
# [ Ω₁ Ψ₂ᴴ 0 • • • 0 ] [ Λ₁ Γ₁ Π₁ 0 • • 0 ]
# [ Ψ₂ Ω₂ • • • ] [ 0 Λ₂ Γ₂ • • • ]
# [ 0 • • • • • ] [ • • Λ₃ • • • • ]
# [ • • • • • • • ] = Qₖ [ • • • • • 0 ]
# [ • • • • • 0 ] [ • • • • Πₖ₋₂]
# [ • • • • Ψₖᴴ ] [ • • • Γₖ₋₁]
# [ • • Ψₖ Ωₖ ] [ 0 • • • • 0 Λₖ ]
# [ 0 • • • • 0 Ψₖ₊₁] [ 0 • • • • • 0 ]
# [ Ω₁ Ψ₂ᴴ 0 • • • 0 ] [ Λ₁ Γ₁ Π₁ 0 • • 0 ]
# [ Ψ₂ Ω₂ • • • ] [ 0 Λ₂ Γ₂ • • • ]
# [ 0 • • • • • ] [ • • Λ₃ • • • • ]
# [ • • • • • • • ] = Qₖ [ • • • • • 0 ]
# [ • • • • • 0 ] [ • • • • Πₖ₋₂]
# [ • • • • Ψₖᴴ ] [ • • • Γₖ₋₁]
# [ • • Ψₖ Ωₖ ] [ 0 • • • • 0 Λₖ ]
# [ 0 • • • • 0 Ψₖ₊₁] [ 0 • • • • • 0 ]
#
# If k = 1, we don't have any previous reflection.
# If k = 2, we apply the last reflection.
# If k ≥ 3, we only apply the two previous reflections.

# Apply previous Householder reflections Θₖ₋₂.
if k ≥ 3
D1 .= Rₖ₋₂.
D2 .= Rₖ₋₁.
kormqr!('L', trans, H[i-2], τ[i-2], D)
Rₖ₋₂. .= D1
Rₖ₋₁. .= D2
if iter ≥ 3
D1 .= zero(T)
D2 .= Ψₖ₋₁'
kormqr!('L', trans, Hₖ₋₂, τₖ₋₂, D)
Πₖ₋₂ .= D1
Γbarₖ₋₁ .= D2
end

# Apply previous Householder reflections Θₖ₋₁.
if k ≥ 2
D1 .= Rₖ₋₁.
D2 .= Rₖ.
kormqr!('L', trans, H[i-1], τ[i-1], D)
Rₖ₋₁.ₖ .= D1
Rₖ.ₖ .= D2
if iter ≥ 2
(iter == 2) && (Γbarₖ₋₁ .= Ψₖ₋₁')
D1 .= Γbarₖ₋₁
D2 .= Ωₖ
kormqr!('L', trans, Hₖ₋₁, τₖ₋₁, D)
Γₖ₋₁ .= D1
Λbarₖ .= D2
end

# Vₖ₊₁ and Ψₖ₊₁ are stored in Vₖ₋₁ and C.
τ = τₖ₋₂
copyto!(Vₖ₋₁, Q)
if C isa Matrix
householder!(Vₖ₋₁, C, τ)
else
householder!(Vₖ₋₁, C, τ, solver.tmp)
end

# Compute and apply current Householder reflection θₖ.
H[inner_iter][1:p,:] .= Rₖ.
H[inner_iter][p+1:2p,:] .= C
Ψₖ₊₁ = C
Hₖ = Hₖ₋₂
τₖ = τₖ₋₂
(iter == 1) && (Λbarₖ .= Ωₖ)
Hₖ[1:p,:] .= Λbarₖ
Hₖ[p+1:2p,:] .= Ψₖ₊₁
if C isa Matrix
householder!(H[i], Rₖ.ₖ, τ[i], compact=true)
householder!(Hₖ, Ψₖ₊₁, τₖ, compact=true)
else
householder!(H[i], Rₖ.ₖ, τ[i], solver.tmp, compact=true)
householder!(Hₖ, Ψₖ₊₁, τₖ, solver.tmp, compact=true)
end
Λₖ .= view(Hₖ, 1:p, 1:p)

# Update Zₖ = (Qₖ)ᴴΨ₁E₁ = (Φ₁, ..., Φₖ, Φbarₖ₊₁)
D1 .= Φbarₖ
D2 .= zero(FC)
kormqr!('L', trans, H[i], τ[i], D)
Φₖ = D1
kormqr!('L', trans, Hₖ, τₖ, D)
Φₖ .= D1

# Compute the directions Wₖ, the last columns of Wₖ = Vₖ(Rₖ)⁻¹ ⟷ (Rₖ)ᵀ(Wₖ)ᵀ = (Vₖ)ᵀ
# R₁₁w₁ = v₁
# (Λ₁)ᵀw₁ = v₁
if iter == 1
wₖ = wₖ₋₁
wₖ .+= vₖ
ldiv!(LowerTriangular(R₁₁), wₖ)
wₖ .+= Vₖ
ldiv!(LowerTriangular(Λₖ |> transpose), transpose(wₖ))
end
# R₂₂w₂ = (v₂ - R₂₁w₁)
# (Λ₂)ᵀw₂ = v₂ - (Γ₁)ᵀw₁
if iter == 2
wₖ = wₖ₋₂
wₖ .-= R₂₁ * wₖ₋₁
wₖ .+= vₖ
ldiv!(LowerTriangular(R₂₁), wₖ)
transpose(wₖ) .-= transpose(Γₖ₋₁) * transpose(wₖ₋₁)
wₖ .+= Vₖ
ldiv!(LowerTriangular(Λₖ |> transpose), transpose(wₖ))
end
# Rₖₖwₖ = (vₖ - Rₖₖ₋₁wₖ₋₁ - Rₖₖ₋₂wₖ₋₂)
# (Λₖ)ᵀwₖ = vₖ - (Γₖ₋₁)ᵀwₖ₋₁ - (Πₖ₋₂)ᵀwₖ₋₂
if iter ≥ 3
lmul!(UpperTriangular(Rₖₖ₋₂), wₖ₋₂)
wₖ₋₂ .= (wₖ₋₂ * Πₖ₋₂)
# lmul!(transpose(Πₖ₋₂), transpose(wₖ₋₂))
wₖ = wₖ₋₂
wₖ .-= Rₖₖ₋₁ * wₖ₋₁
wₖ .+= vₖ
ldiv!(LowerTriangular(Rₖₖ), wₖ)
transpose(wₖ) .-= transpose(Γₖ₋₁) * transpose(wₖ₋₁)
wₖ .+= Vₖ
ldiv!(LowerTriangular(Λₖ |> transpose), transpose(wₖ))
end

# Update Xₖ = VₖYₖ = WₖZₖ
# Xₖ = Xₖ₋₁ + Φₖ * wₖ
mul!(X, Φₖ, wₖ, γ, β)
# Xₖ = Xₖ₋₁ + wₖ * Φₖ
mul!(X, wₖ, Φₖ, γ, β)

# Update residual norm estimate.
# ‖ M(B - AXₖ) ‖_F = ‖Φbarₖ₊₁‖_F
C .= D2
RNorm = norm(C)
Φbarₖ₊₁ .= D2
RNorm = norm(Φbarₖ₊₁)
history && push!(RNorms, RNorm)

# Compute vₖ and vₖ₊₁
copyto!(Vₖ₋₁, Vₖ) # vₖ₋₁ ← vₖ
copyto!(Vₖ, Q) # vₖ ← vₖ₊₁
copyto!(Vₖ, Q) # vₖ ← vₖ₊₁

# Update directions for X
if iter ≥ 2
@kswap!(wₖ₋₂, wₖ₋₁)
end

# Update other variables...
if iter ≥ 2
@kswap!(Hₖ₋₂, Hₖ₋₁)
@kswap!(τₖ₋₂, τₖ₋₁)
copyto!(Ψₖ₋₁, Ψₖ)
end
copyto!(Ψₖ, Ψₖ₊₁)

# Update stopping criterion.
user_requested_exit = callback(solver) :: Bool
solved = RNorm ≤ ε
tired = iter ≥ itmax
timer = time_ns() - start_time
overtimed = timer > timemax_ns
kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, RNorm, ktimer(start_time))
kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, RNorm, start_time |> ktimer)
end
(verbose > 0) && @printf(iostream, "\n")

Expand All @@ -289,7 +319,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
# Update stats
stats.niter = iter
stats.solved = solved
stats.timer = ktimer(start_time)
stats.timer = start_time |> ktimer
stats.status = status
return solver
end
Expand Down
4 changes: 2 additions & 2 deletions src/krylov_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
@eval begin
function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $workspace(A, b)
solver = $workspace(A, B)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
$(krylov!)(solver, $(args...); $(kwargs...))
Expand All @@ -223,7 +223,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
if !isempty($optargs)
function $(krylov)($(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $workspace(A, b)
solver = $workspace(A, B)
warm_start!(solver, $(optargs...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/amd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ include("gpu.jl")
Krylov.kcopy!(n, y, x)
end

@testset "kswap -- $FC" begin
@testset "kswap! -- $FC" begin
Krylov.@kswap!(x, y)
end

Expand Down
22 changes: 12 additions & 10 deletions test/gpu/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
using SparseArrays, Random, Test
using LinearAlgebra, Krylov, KernelAbstractions

@kernel function copy_triangle_kernel!(dest, src)
i, j = @index(Global, NTuple)
if j >= i
@inbounds dest[i, j] = src[i, j]
if VERSION < v"1.11"
@kernel function copy_triangle_kernel!(dest, src)
i, j = @index(Global, NTuple)
if j >= i
@inbounds dest[i, j] = src[i, j]
end
end
end

function Krylov.copy_triangle(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, k::Int) where FC <: Krylov.FloatOrComplex
backend = get_backend(Q)
ndrange = (k, k)
copy_triangle_kernel!(backend)(R, Q; ndrange=ndrange)
KernelAbstractions.synchronize(backend)
function Krylov.copy_triangle(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, k::Int) where FC <: Krylov.FloatOrComplex
backend = get_backend(Q)
ndrange = (k, k)
copy_triangle_kernel!(backend)(R, Q; ndrange=ndrange)
KernelAbstractions.synchronize(backend)
end
end

Random.seed!(666)
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/intel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ include("gpu.jl")
Krylov.kcopy!(n, y, x)
end

@testset "kswap -- $FC" begin
@testset "kswap! -- $FC" begin
Krylov.@kswap!(x, y)
end

Expand Down
Loading

0 comments on commit ad85d45

Please sign in to comment.