Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
tmigot committed Feb 6, 2024
1 parent 62bc634 commit 3f9e200
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 68 deletions.
2 changes: 1 addition & 1 deletion docs/src/storage.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Each table summarizes the storage requirements of Krylov methods recommended to

| Methods | [`USYMQR`](@ref usymqr) | [`CGLS`](@ref cgls) | [`CG-LANCZOS-SHIFT`](@ref cg_lanczos_shift) | [`CRLS`](@ref crls) | [`LSLQ`](@ref lslq) | [`LSQR`](@ref lsqr) | [`LSMR`](@ref lsmr) |
|:-------:|:-----------------------:|:-------------------:|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
| Storage | $6n + 3m$ | $3n + 2m$ | $3n + 2m + Xp$ | $4n + 3m$ | $4n + 2m$ | $4n + 2m$ | $5n + 2m$ |
| Storage | $6n + 3m$ | $3n + 2m$ | $3n + 2m + 5p + 2np$ | $4n + 3m$ | $4n + 2m$ | $4n + 2m$ | $5n + 2m$ |

#### Adjoint systems

Expand Down
92 changes: 40 additions & 52 deletions src/cgls_lanczos_shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,51 +73,47 @@ See [`CglsLanczosShiftSolver`](@ref) for more details about the `solver`.
"""
function cgls_lanczos_shift! end

def_args_cg_lanczos_shift = (:(A ),
:(b::AbstractVector{FC} ),
:(shifts::AbstractVector{T}))

def_kwargs_cg_lanczos_shift = (:(; M = I ),
:(; ldiv::Bool = false ),
:(; check_curvature::Bool = false),
:(; atol::T = eps(T) ),
:(; rtol::T = eps(T) ),
:(; itmax::Int = 0 ),
:(; timemax::Float64 = Inf ),
:(; verbose::Int = 0 ),
:(; history::Bool = false ),
:(; callback = solver -> false ),
:(; iostream::IO = kstdout ))

def_kwargs_cg_lanczos_shift = mapreduce(extract_parameters, vcat, def_kwargs_cg_lanczos_shift)

args_cg_lanczos_shift = (:A, :b, :shifts)
kwargs_cg_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

solver :: CglsLanczosShiftSolver{T,FC,S}, A, b :: AbstractVector{FC}, shifts :: AbstractVector{T};
M=I, atol :: T=eps(T), rtol :: T=eps(T),
itmax :: Int=0, verbose :: Int=0, history :: Bool=false,
callback = solver -> false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: DenseVector{FC}}
def_args_cgls_lanczos_shift = (:(A ),
:(b::AbstractVector{FC} ),
:(shifts::AbstractVector{T}))

def_kwargs_cgls_lanczos_shift = (:(; M = I ),
:(; ldiv::Bool = false ),
:(; check_curvature::Bool = false),
:(; atol::T = eps(T) ),
:(; rtol::T = eps(T) ),
:(; itmax::Int = 0 ),
:(; timemax::Float64 = Inf ),
:(; verbose::Int = 0 ),
:(; history::Bool = false ),
:(; callback = solver -> false ),
:(; iostream::IO = kstdout ))

def_kwargs_cgls_lanczos_shift = mapreduce(extract_parameters, vcat, def_kwargs_cgls_lanczos_shift)

args_cgls_lanczos_shift = (:A, :b, :shifts)
kwargs_cgls_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cgls_lanczos_shift($(def_args_cg_lanczos_shift...); $(def_kwargs_cg_lanczos_shift...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
function cgls_lanczos_shift($(def_args_cgls_lanczos_shift...); $(def_kwargs_cgls_lanczos_shift...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
nshifts = length(shifts)
solver = CglsLanczosShiftSolver(A, b, nshifts)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cgls_lanczos_shift!(solver, $(args_cg_lanczos_shift...); $(kwargs_cg_lanczos_shift...))
cgls_lanczos_shift!(solver, $(args_cgls_lanczos_shift...); $(kwargs_cgls_lanczos_shift...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cgls_lanczos_shift!(solver :: CgLanczosShiftSolver{T,FC,S}, $(def_args_cgls_lanczos_shift...); $(def_kwargs_cgls_lanczos_shift...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
function cgls_lanczos_shift!(solver :: CglsLanczosShiftSolver{T,FC,S}, $(def_args_cgls_lanczos_shift...); $(def_kwargs_cgls_lanczos_shift...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
start_time = time_ns()
timemax_ns = 1e9 * timemax

m, n = size(A)
(m == solver.m && n == solver.n) || error("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($m, $n)")
length(b) == m || error("Inconsistent problem size")

nshifts = length(shifts)
Expand All @@ -144,9 +140,9 @@ solver :: CglsLanczosShiftSolver{T,FC,S}, A, b :: AbstractVector{FC}, shifts ::
x, p, σ, δhat = solver.x, solver.p, solver.σ, solver.δhat
ω, γ, rNorms, converged = solver.ω, solver.γ, solver.rNorms, solver.converged
not_cv, stats = solver.not_cv, solver.stats
rNorms_history, indefinite, status = stats.residuals, stats.indefinite, stats.status
rNorms_history, status = stats.residuals, stats.status
reset!(stats)
v = solver.v # v = MisI ? Mv : solver.v
v = solver.v

# Initial state.
## Distribute x similarly to shifts.
Expand All @@ -165,14 +161,11 @@ solver :: CglsLanczosShiftSolver{T,FC,S}, A, b :: AbstractVector{FC}, shifts ::
end
end

# Keep track of shifted systems with negative curvature if required.
indefinite .= false

if β == 0
stats.niter = 0
stats.solved = true
stats.timer = ktimer(start_time)
status .= "x = 0 is a zero-residual solution"
status = "x = 0 is a zero-residual solution"
return solver
end

Expand All @@ -184,8 +177,6 @@ solver :: CglsLanczosShiftSolver{T,FC,S}, A, b :: AbstractVector{FC}, shifts ::
# Initialize Lanczos process.
# β₁v₁ = b
@kscal!(n, one(FC) / β, v) # v₁ ← v₁ / β₁
# MisI || @kscal!(n, one(FC) / β, Mv) # Mv₁ ← Mv₁ / β₁
# Mv_prev .= Mv
@kscal!(m, one(FC) / β, u)

# Initialize some constants used in recursions below.
Expand All @@ -212,22 +203,22 @@ solver :: CglsLanczosShiftSolver{T,FC,S}, A, b :: AbstractVector{FC}, shifts ::

solved = !reduce(|, not_cv) # ArNorm ≤ ε
tired = iter itmax
status .= "unknown"
status = "unknown"
user_requested_exit = false
overtimed = false

# Main loop.
while ! (solved || tired || user_requested_exit || overtimed)

# Form next Lanczos vector.
mul!(utilde, A, v) # utildeₖ ← Avₖ
δ = @kdotr(m, utilde, utilde) # δₖ = vₖᵀAᵀAvₖ
@kaxpy!(m, -δ, u, utilde) # uₖ₊₁ = utildeₖ - δₖuₖ - βₖuₖ₋₁
mul!(utilde, A, v) # utildeₖ ← Avₖ
δ = @kdotr(m, utilde, utilde) # δₖ = vₖᵀAᵀAvₖ
@kaxpy!(m, -δ, u, utilde) # uₖ₊₁ = utildeₖ - δₖuₖ - βₖuₖ₋₁
@kaxpy!(m, -β, u_prev, utilde)
mul!(v, A', utilde) # vₖ₊₁ = Aᵀuₖ₊₁
β = sqrt(@kdotr(n, v, v)) # βₖ₊₁ = vₖ₊₁ᵀ M vₖ₊₁
@kscal!(n, one(FC) / β, v) # vₖ₊₁ ← vₖ₊₁ / βₖ₊₁
@kscal!(m, one(FC) / β, utilde) # uₖ₊₁ = uₖ₊₁ / βₖ₊₁
mul!(v, A', utilde) # vₖ₊₁ = Aᵀuₖ₊₁
β = sqrt(@kdotr(n, v, v)) # βₖ₊₁ = vₖ₊₁ᵀ M vₖ₊₁
@kscal!(n, one(FC) / β, v) # vₖ₊₁ ← vₖ₊₁ / βₖ₊₁
@kscal!(m, one(FC) / β, utilde) # uₖ₊₁ = uₖ₊₁ / βₖ₊₁
u_prev .= u
u .= utilde

Expand Down Expand Up @@ -275,18 +266,15 @@ solver :: CglsLanczosShiftSolver{T,FC,S}, A, b :: AbstractVector{FC}, shifts ::
(verbose > 0) && @printf(iostream, "\n")

# Termination status
overtimed && (status = "time limit exceeded")
for i = 1 : nshifts
tired && (stats.status[i] = "maximum number of iterations exceeded")
converged[i] && (stats.status[i] = "solution good enough given atol and rtol")
end
user_requested_exit && (status .= "user-requested exit")
tired && (status = "maximum number of iterations exceeded")
solved && (status = "solution good enough given atol and rtol")
user_requested_exit && (status = "user-requested exit")
overtimed && (status = "time limit exceeded")

# Update stats
# Update stats
stats.niter = iter
stats.solved = solved
stats.timer = ktimer(start_time)
stats.inconsistent .= false
return solver
end
end
11 changes: 7 additions & 4 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,9 @@ The outer constructors
may be used in order to create these vectors.
"""
mutable struct CglsLanczosShiftSolver{T,FC,S} <: KrylovSolver{T,FC,S}
m :: Int
n :: Int
nshifts :: Int
Mv :: S
Mv_prev :: S
Mv_next :: S
Expand All @@ -1205,7 +1208,7 @@ mutable struct CglsLanczosShiftSolver{T,FC,S} <: KrylovSolver{T,FC,S}
rNorms :: Vector{T}
converged :: BitVector
not_cv :: BitVector
stats :: SimpleShiftsStats{T}
stats :: LanczosShiftStats{T}
end

function CglsLanczosShiftSolver(m, n, nshifts, S)
Expand All @@ -1226,8 +1229,8 @@ function CglsLanczosShiftSolver(m, n, nshifts, S)
indefinite = BitVector(undef, nshifts)
converged = BitVector(undef, nshifts)
not_cv = BitVector(undef, nshifts)
stats = SimpleShiftsStats(0, false, zeros(Bool, nshifts), indefinite, [T[] for i = 1 : nshifts], [T[] for i = 1 : nshifts], zeros(T, nshifts), ["unknown" for i = 1 : nshifts])
solver = new{T,FC,S}(Mv, Mv_prev, Mv_next, u, v, x, p, σ, δhat, ω, γ, rNorms, converged, not_cv, stats)
stats = LanczosShiftStats(0, false, Vector{T}[T[] for i = 1 : nshifts], indefinite, T(NaN), T(NaN), 0.0, "unknown")
solver = CglsLanczosShiftSolver{T,FC,S}(m, n, nshifts, Mv, Mv_prev, Mv_next, u, v, x, p, σ, δhat, ω, γ, rNorms, converged, not_cv, stats)
return solver
end

Expand Down Expand Up @@ -1974,7 +1977,7 @@ for (KS, fun, nsol, nA, nAt, warm_start) in [
(:CgSolver , :cg! , 1, 1, 0, true )
(:CgLanczosShiftSolver, :cg_lanczos_shift!, 1, 1, 0, false)
(:CglsSolver , :cgls! , 1, 1, 1, false)
(:CglsLanczosShiftSolver, :cgls_lanczos_shift!, 1, 1, 0, false)
(:CglsLanczosShiftSolver, :cgls_lanczos_shift!, 1, 1, 1, false)
(:CgLanczosSolver , :cg_lanczos! , 1, 1, 0, true )
(:BilqSolver , :bilq! , 1, 1, 1, true )
(:MinresQlpSolver , :minres_qlp! , 1, 1, 0, true )
Expand Down
13 changes: 13 additions & 0 deletions test/callback_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,16 @@ function restarted_gmres_callback_n2(solver::GmresSolver, A, b, stor, N, storage
storage_vec .-= b
return (norm(storage_vec) tol)
end

mutable struct TestCallbackN2LSShifts{T, S, M}
A::M
b::S
shifts::Vector{T}
tol::T
end
TestCallbackN2LSShifts(A, b, shifts; tol = 0.1) = TestCallbackN2LSShifts(A, b, shifts, tol)

function (cb_n2::TestCallbackN2LSShifts)(solver)
r = residuals_ls(cb_n2.A, cb_n2.b, cb_n2.shifts, solver.x)
return all(map(norm, r) .≤ cb_n2.tol)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ include("test_minres_qlp.jl")
include("test_symmlq.jl")
include("test_bilq.jl")
include("test_cgls.jl")

include("test_cgls_lanczos_shift.jl")
include("test_crls.jl")
include("test_cgne.jl")
include("test_crmr.jl")
Expand Down
19 changes: 8 additions & 11 deletions test/test_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,25 +396,22 @@

@testset "CGLS-LANCZOS-SHIFT" begin
# CGLS-LANCZOS-SHIFT needs:
# - 3 n-vectors: x, p, s
# - 2 m-vectors: r, q

# - 2 n-vectors: Mv, v
# - 3 m-vectors: Mv_prev, Mv_next, u
# - 2 (n*nshifts)-matrices: x, p
# - 5 nshifts-vectors: σ, δhat, ω, γ, rNorms
# - 2 nshifts-bitVector: converged, not_cv
# - 3 nshifts-bitVector: converged, indefinite, not_cv

storage_cgls_lanczos_shift_bytes(m, n) = nbits_FC * (3 * n + 2 * m)
storage_cgls_lanczos_shift_bytes(m, n, nshifts) = nbits_FC * (2 * n + 3 * m + 2 * n * nshifts) + nbits_T * (5 * nshifts) + (3 * nshifts)

expected_cgls_lanczos_shift_bytes = storage_cgls_lanczos_shift_bytes(m, k)
(x, stats) = cgls_lanczos_shift(Ao, b) # warmup
actual_cgls_lanczos_shift_bytes = @allocated cgls_lanczos_shift(Ao, b)
expected_cgls_lanczos_shift_bytes = storage_cgls_lanczos_shift_bytes(m, k, nshifts)
(x, stats) = cgls_lanczos_shift(Ao, b, shifts) # warmup
actual_cgls_lanczos_shift_bytes = @allocated cgls_lanczos_shift(Ao, b, shifts)
@test expected_cgls_lanczos_shift_bytes actual_cgls_lanczos_shift_bytes 1.02 * expected_cgls_lanczos_shift_bytes

solver = CglsSolver(Ao, b)
cgls_lanczos_shift!(solver, Ao, b) # warmup
inplace_cgls_lanczos_shift_bytes = @allocated cgls_lanczos_shift!(solver, Ao, b)
solver = CglsLanczosShiftSolver(Ao, b)
cgls_lanczos_shift!(solver, Ao, b, shifts) # warmup
inplace_cgls_lanczos_shift_bytes = @allocated cgls_lanczos_shift!(solver, Ao, b, shifts)
@test inplace_cgls_lanczos_shift_bytes == 0
end

Expand Down
56 changes: 56 additions & 0 deletions test/test_cgls_lanczos_shift.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
function residuals_ls(A, b, shifts, x)
nshifts = length(shifts)
r = [ (A' * (b - A * x[i]) - shifts[i] * x[i]) for i = 1 : nshifts ]
return r
end

@testset "cgls_lanczos_shift" begin
cgls_lanczos_shift_tol = 1.0e-6

for FC in (Float64, ComplexF64)
@testset "Data Type: $FC" begin

for npower = 1 : 4
(b, A, D, HY, HZ, Acond, rnorm) = test(40, 40, 4, npower, 0) # No regularization.
shifts = [1.0; 2.0; 3.0; 4.0; 5.0; 6.0]
(x, stats) = cgls_lanczos_shift(A, b, shifts)
r = residuals_ls(A, b, shifts, x)
resids = map(norm, r) / norm(A' * b)
@test all(resids .≤ cgls_lanczos_shift_tol)
@test(stats.solved)
end

# Test b == 0
A, b = zero_rhs(FC=FC)
(x, stats) = cgls_lanczos_shift(A, b, shifts)
for xi x
@test norm(xi) == 0
end
@test status == "x = 0 is a zero-residual solution"

#=
# Not implemented
# Test with Jacobi (or diagonal) preconditioner
A, b, M = square_preconditioned(FC=FC)
shifts = [1.0; 2.0; 3.0; 4.0; 5.0; 6.0]
(x, stats) = cgls_lanczos_shift(A, b, shifts, M=M)
r = residuals(A, b, shifts, x)
resids = map(norm, r) / norm(b)
@test(all(resids .≤ cgls_lanczos_shift_tol))
@test(stats.solved)
=#

# test callback function
A, b = symmetric_definite(FC=FC)
shifts = [1.0; 2.0; 3.0; 4.0; 5.0; 6.0]
solver = CglsLanczosShiftSolver(A, b, length(shifts))
tol = 1.0e-1
cb_n2 = TestCallbackN2LSShifts(A, b, shifts, tol = tol)
cgls_lanczos_shift!(solver, A, b, shifts, atol = 0.0, rtol = 0.0, callback = cb_n2)
@test solver.stats.status == "user-requested exit"
@test cb_n2(solver)

@test_throws TypeError cg_lanczos_shift(A, b, shifts, callback = solver -> "string", history = true)
end
end
end

0 comments on commit 3f9e200

Please sign in to comment.