Skip to content

Commit

Permalink
Use 5-arguments mul! in TRICG
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored and abelsiqueira committed Jul 26, 2021
1 parent bfa8573 commit 6d9e9ad
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 37 deletions.
6 changes: 1 addition & 5 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,11 @@ mutable struct TricgSolver{T,S} <: KrylovSolver{T,S}
yₖ :: S
N⁻¹uₖ₋₁ :: S
N⁻¹uₖ :: S
p :: S
gy₂ₖ₋₁ :: S
gy₂ₖ :: S
xₖ :: S
M⁻¹vₖ₋₁ :: S
M⁻¹vₖ :: S
q :: S
gx₂ₖ₋₁ :: S
gx₂ₖ :: S
uₖ :: S
Expand All @@ -485,18 +483,16 @@ mutable struct TricgSolver{T,S} <: KrylovSolver{T,S}
yₖ = S(undef, m)
N⁻¹uₖ₋₁ = S(undef, m)
N⁻¹uₖ = S(undef, m)
p = S(undef, m)
gy₂ₖ₋₁ = S(undef, m)
gy₂ₖ = S(undef, m)
xₖ = S(undef, n)
M⁻¹vₖ₋₁ = S(undef, n)
M⁻¹vₖ = S(undef, n)
q = S(undef, n)
gx₂ₖ₋₁ = S(undef, n)
gx₂ₖ = S(undef, n)
uₖ = S(undef, 0)
vₖ = S(undef, 0)
solver = new{T,S}(yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, uₖ, vₖ)
solver = new{T,S}(yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, gy₂ₖ₋₁, gy₂ₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, gx₂ₖ₋₁, gx₂ₖ, uₖ, vₖ)
return solver
end

Expand Down
55 changes: 26 additions & 29 deletions src/tricg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
# Set up workspace.
allocate_if(!MisI, solver, :vₖ, S, m)
allocate_if(!NisI, solver, :uₖ, S, n)
yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p, xₖ, M⁻¹vₖ₋₁ = solver.yₖ, solver.N⁻¹uₖ₋₁, solver.p, solver.N⁻¹uₖ, solver.xₖ, solver.M⁻¹vₖ₋₁
M⁻¹vₖ, q, gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₁, gx₂ₖ = solver.M⁻¹vₖ, solver.q, solver.gy₂ₖ₋₁, solver.gy₂ₖ, solver.gx₂ₖ₋₁, solver.gx₂ₖ
yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, xₖ, M⁻¹vₖ₋₁ = solver.yₖ, solver.N⁻¹uₖ₋₁, solver.N⁻¹uₖ, solver.xₖ, solver.M⁻¹vₖ₋₁
M⁻¹vₖ, gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₁, gx₂ₖ = solver.M⁻¹vₖ, solver.gy₂ₖ₋₁, solver.gy₂ₖ, solver.gx₂ₖ₋₁, solver.gx₂ₖ
vₖ = MisI ? M⁻¹vₖ : solver.vₖ
uₖ = NisI ? N⁻¹uₖ : solver.uₖ
vₖ₊₁ = MisI ? q : vₖ
uₖ₊₁ = NisI ? p : uₖ
vₖ₊₁ = MisI ? M⁻¹vₖ₋₁ : vₖ
uₖ₊₁ = NisI ? N⁻¹uₖ₋₁ : uₖ

# Initial solutions x₀ and y₀.
xₖ .= zero(T)
Expand Down Expand Up @@ -160,22 +160,13 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
# AUₖ = EVₖTₖ + βₖ₊₁Evₖ₊₁(eₖ)ᵀ = EVₖ₊₁Tₖ₊₁.ₖ
# AᵀVₖ = FUₖ(Tₖ)ᵀ + γₖ₊₁Fuₖ₊₁(eₖ)ᵀ = FUₖ₊₁(Tₖ.ₖ₊₁)ᵀ

mul!(q, A , uₖ) # Forms Evₖ₊₁ : q ← Auₖ
mul!(p, Aᵀ, vₖ) # Forms Fuₖ₊₁ : p ← Aᵀvₖ
mul!(M⁻¹vₖ₋₁, A , uₖ, one(T), -γₖ) # Forms Evₖ₊₁ : Evₖ₋₁ ← Auₖ - γₖEvₖ₋₁
mul!(N⁻¹uₖ₋₁, Aᵀ, vₖ, one(T), -βₖ) # Forms Fuₖ₊₁ : Fuₖ₋₁ ← Aᵀvₖ - βₖFuₖ₋₁

if iter ≥ 2
@kaxpy!(m, -γₖ, M⁻¹vₖ₋₁, q) # q ← q - γₖ * M⁻¹vₖ₋₁
@kaxpy!(n, -βₖ, N⁻¹uₖ₋₁, p) # p ← p - βₖ * N⁻¹uₖ₋₁
end

αₖ = @kdot(m, vₖ, q) # αₖ = qᵀvₖ
αₖ = @kdot(m, vₖ, M⁻¹vₖ₋₁) # αₖ = (Auₖ- γₖEvₖ₋₁)ᵀvₖ

@kaxpy!(m, -αₖ, M⁻¹vₖ, q) # q ← q - αₖ * M⁻¹vₖ
@kaxpy!(n, -αₖ, N⁻¹uₖ, p) # p ← p - αₖ * N⁻¹uₖ

# Update M⁻¹vₖ₋₁ and N⁻¹uₖ₋₁
@. M⁻¹vₖ₋₁ = M⁻¹vₖ
@. N⁻¹uₖ₋₁ = N⁻¹uₖ
@kaxpy!(m, -αₖ, M⁻¹vₖ, M⁻¹vₖ₋₁) # Evₖ₋₁ ← Evₖ₋₁ - αₖ * Evₖ
@kaxpy!(n, -αₖ, N⁻¹uₖ, N⁻¹uₖ₋₁) # Fuₖ₋₁ ← Fuₖ₋₁ - αₖ * Fuₖ

# Notations : Wₖ = [w₁ ••• wₖ] = [v₁ 0 ••• vₖ 0 ]
# [0 u₁ ••• 0 uₖ]
Expand Down Expand Up @@ -266,25 +257,31 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
@. yₖ += π₂ₖ₋₁ * gy₂ₖ₋₁ + π₂ₖ * gy₂ₖ

# Compute vₖ₊₁ and uₖ₊₁
MisI || mul!(vₖ₊₁, M, q) # βₖ₊₁vₖ₊₁ = MAuₖ - γₖvₖ₋₁ - αₖvₖ
NisI || mul!(uₖ₊₁, N, p) # γₖ₊₁uₖ₊₁ = NAᵀvₖ - βₖuₖ₋₁ - αₖuₖ
MisI || mul!(vₖ₊₁, M, M⁻¹vₖ₋₁) # βₖ₊₁vₖ₊₁ = MAuₖ - γₖvₖ₋₁ - αₖvₖ
NisI || mul!(uₖ₊₁, N, N⁻¹uₖ₋₁) # γₖ₊₁uₖ₊₁ = NAᵀvₖ - βₖuₖ₋₁ - αₖuₖ

βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, q)) # βₖ₊₁ = ‖vₖ₊₁‖_E
γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, p)) # γₖ₊₁ = ‖uₖ₊₁‖_F
βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, M⁻¹vₖ₋₁)) # βₖ₊₁ = ‖vₖ₊₁‖_E
γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, N⁻¹uₖ₋₁)) # γₖ₊₁ = ‖uₖ₊₁‖_F

if βₖ₊₁ ≠ 0
@kscal!(m, one(T) / βₖ₊₁, q)
if βₖ₊₁ ≠ zero(T)
@kscal!(m, one(T) / βₖ₊₁, M⁻¹vₖ₋₁)
MisI || @kscal!(m, one(T) / βₖ₊₁, vₖ₊₁)
end

if γₖ₊₁ ≠ 0
@kscal!(n, one(T) / γₖ₊₁, p)
if γₖ₊₁ ≠ zero(T)
@kscal!(n, one(T) / γₖ₊₁, N⁻¹uₖ₋₁)
NisI || @kscal!(n, one(T) / γₖ₊₁, uₖ₊₁)
end

# Update M⁻¹vₖ and N⁻¹uₖ
@. M⁻¹vₖ = q
@. N⁻¹uₖ = p
# Update N⁻¹uₖ₋₁, M⁻¹vₖ₋₁, N⁻¹uₖ and M⁻¹vₖ.
@kswap(N⁻¹uₖ₋₁, N⁻¹uₖ)
@kswap(M⁻¹vₖ₋₁, M⁻¹vₖ)

# Update pointers impacted by the swaps.
NisI && (uₖ = N⁻¹uₖ)
NisI && (uₖ₊₁ = N⁻¹uₖ₋₁)
MisI && (vₖ = M⁻¹vₖ)
MisI && (vₖ₊₁ = M⁻¹vₖ₋₁)

# Compute ‖rₖ‖² = (γₖ₊₁ζ₂ₖ₋₁)² + (βₖ₊₁ζ₂ₖ)²
rNorm = sqrt((γₖ₊₁ * (π₂ₖ₋₁ - δₖ*π₂ₖ))^2 + (βₖ₊₁ * π₂ₖ)^2)
Expand Down
6 changes: 3 additions & 3 deletions test/test_alloc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ function test_alloc()
@test (VERSION < v"1.5") || (inplace_usymlq_bytes == 208)

# TriCG needs:
# - 6 n-vectors: yₖ, uₖ₋₁, uₖ, gy₂ₖ₋₁, gy₂ₖ, p
# - 6 m-vectors: xₖ, vₖ₋₁, vₖ, gx₂ₖ₋₁, gx₂ₖ, q
storage_tricg(n, m) = 6 * n + 6 * m
# - 5 n-vectors: yₖ, uₖ₋₁, uₖ, gy₂ₖ₋₁, gy₂ₖ
# - 5 m-vectors: xₖ, vₖ₋₁, vₖ, gx₂ₖ₋₁, gx₂ₖ
storage_tricg(n, m) = 5 * n + 5 * m
storage_tricg_bytes(n, m) = 8 * storage_tricg(n, m)

expected_tricg_bytes = storage_tricg_bytes(n, m)
Expand Down

0 comments on commit 6d9e9ad

Please sign in to comment.