Skip to content

Commit

Permalink
Use 5-arguments mul! in TRIMR
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Jun 29, 2021
1 parent 493173e commit 9a35a4e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 37 deletions.
10 changes: 5 additions & 5 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,43 +515,43 @@ mutable struct TrimrSolver{T,S} <: KrylovSolver{T,S}
yₖ :: S
N⁻¹uₖ₋₁ :: S
N⁻¹uₖ :: S
p :: S
gy₂ₖ₋₃ :: S
gy₂ₖ₋₂ :: S
gy₂ₖ₋₁ :: S
gy₂ₖ :: S
xₖ :: S
M⁻¹vₖ₋₁ :: S
M⁻¹vₖ :: S
q :: S
gx₂ₖ₋₃ :: S
gx₂ₖ₋₂ :: S
gx₂ₖ₋₁ :: S
gx₂ₖ :: S
uₖ :: S
uₖ₊₁ :: S
vₖ :: S
vₖ₊₁ :: S

function TrimrSolver(n, m, S)
T = eltype(S)
yₖ = S(undef, m)
N⁻¹uₖ₋₁ = S(undef, m)
N⁻¹uₖ = S(undef, m)
p = S(undef, m)
gy₂ₖ₋₃ = S(undef, m)
gy₂ₖ₋₂ = S(undef, m)
gy₂ₖ₋₁ = S(undef, m)
gy₂ₖ = S(undef, m)
xₖ = S(undef, n)
M⁻¹vₖ₋₁ = S(undef, n)
M⁻¹vₖ = S(undef, n)
q = S(undef, n)
gx₂ₖ₋₃ = S(undef, n)
gx₂ₖ₋₂ = S(undef, n)
gx₂ₖ₋₁ = S(undef, n)
gx₂ₖ = S(undef, n)
uₖ = S(undef, 0)
uₖ₊₁ = S(undef, 0)
vₖ = S(undef, 0)
solver = new{T,S}(yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, uₖ, vₖ)
vₖ₊₁ = S(undef, 0)
solver = new{T,S}(yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, uₖ, uₖ₊₁, vₖ, vₖ₊₁)
return solver
end

Expand Down
4 changes: 2 additions & 2 deletions src/tricg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,12 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, M⁻¹vₖ₋₁)) # βₖ₊₁ = ‖vₖ₊₁‖_E
γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, N⁻¹uₖ₋₁)) # γₖ₊₁ = ‖uₖ₊₁‖_F

if βₖ₊₁ ≠ zero(T)
if βₖ₊₁ ≠ 0
@kscal!(m, one(T) / βₖ₊₁, M⁻¹vₖ₋₁)
MisI || @kscal!(m, one(T) / βₖ₊₁, vₖ₊₁)
end

if γₖ₊₁ ≠ zero(T)
if γₖ₊₁ ≠ 0
@kscal!(n, one(T) / γₖ₊₁, N⁻¹uₖ₋₁)
NisI || @kscal!(n, one(T) / γₖ₊₁, uₖ₊₁)
end
Expand Down
53 changes: 26 additions & 27 deletions src/trimr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ function trimr!(solver :: TrimrSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
Aᵀ = A'

# Set up workspace.
allocate_if(!MisI, solver, :vₖ, S, m)
allocate_if(!NisI, solver, :uₖ, S, n)
yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q = solver.yₖ, solver.N⁻¹uₖ₋₁, solver.N⁻¹uₖ, solver.p, solver.xₖ, solver.M⁻¹vₖ₋₁, solver.M⁻¹vₖ, solver.q
allocate_if(!MisI, solver, :vₖ , S, m)
allocate_if(!MisI, solver, :vₖ₊₁ , S, m)
allocate_if(!NisI, solver, :uₖ , S, n)
allocate_if(!NisI, solver, :uₖ₊₁ , S, n)
yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ = solver.yₖ, solver.N⁻¹uₖ₋₁, solver.N⁻¹uₖ, solver.xₖ, solver.M⁻¹vₖ₋₁, solver.M⁻¹vₖ
gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ = solver.gy₂ₖ₋₃, solver.gy₂ₖ₋₂, solver.gy₂ₖ₋₁, solver.gy₂ₖ, solver.gx₂ₖ₋₃, solver.gx₂ₖ₋₂, solver.gx₂ₖ₋₁, solver.gx₂ₖ
vₖ = MisI ? M⁻¹vₖ : solver.vₖ
uₖ = NisI ? N⁻¹uₖ : solver.uₖ
vₖ₊₁ = MisI ? q : M⁻¹vₖ₋₁
uₖ₊₁ = NisI ? p : N⁻¹uₖ₋₁
vₖ₊₁ = MisI ? M⁻¹vₖ₋₁ : solver.vₖ₊₁
uₖ₊₁ = NisI ? N⁻¹uₖ₋₁ : solver.uₖ₊₁

# Initial solutions x₀ and y₀.
xₖ .= zero(T)
Expand Down Expand Up @@ -173,33 +175,28 @@ function trimr!(solver :: TrimrSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
# AUₖ = EVₖTₖ + βₖ₊₁Evₖ₊₁(eₖ)ᵀ = EVₖ₊₁Tₖ₊₁.ₖ
# AᵀVₖ = FUₖ(Tₖ)ᵀ + γₖ₊₁Fuₖ₊₁(eₖ)ᵀ = FUₖ₊₁(Tₖ.ₖ₊₁)ᵀ

mul!(q, A , uₖ) # Forms Evₖ₊₁ : q ← Auₖ
mul!(p, Aᵀ, vₖ) # Forms Fuₖ₊₁ : p ← Aᵀvₖ
mul!(M⁻¹vₖ₋₁, A , uₖ, one(T), -γₖ) # Forms Evₖ₊₁ : Evₖ₋₁ ← Auₖ - γₖEvₖ₋₁
mul!(N⁻¹uₖ₋₁, Aᵀ, vₖ, one(T), -βₖ) # Forms Fuₖ₊₁ : Fuₖ₋₁ ← Aᵀvₖ - βₖFuₖ₋₁

if iter ≥ 2
@kaxpy!(m, -γₖ, M⁻¹vₖ₋₁, q) # q ← q - γₖ * M⁻¹vₖ₋₁
@kaxpy!(n, -βₖ, N⁻¹uₖ₋₁, p) # p ← p - βₖ * N⁻¹uₖ₋₁
end

αₖ = @kdot(m, vₖ, q) # αₖ = qᵀvₖ
αₖ = @kdot(m, vₖ, M⁻¹vₖ₋₁) # αₖ = (Auₖ- γₖEvₖ₋₁)ᵀvₖ

@kaxpy!(m, -αₖ, M⁻¹vₖ, q) # q ← q - αₖ * M⁻¹vₖ
@kaxpy!(n, -αₖ, N⁻¹uₖ, p) # p ← p - αₖ * N⁻¹uₖ
@kaxpy!(m, -αₖ, M⁻¹vₖ, M⁻¹vₖ₋₁) # Evₖ₋₁ ← Evₖ₋₁ - αₖ * Evₖ
@kaxpy!(n, -αₖ, N⁻¹uₖ, N⁻¹uₖ₋₁) # Fuₖ₋₁ ← Fuₖ₋₁ - αₖ * Fuₖ

# Compute vₖ₊₁ and uₖ₊₁
MisI || mul!(vₖ₊₁, M, q) # βₖ₊₁vₖ₊₁ = MAuₖ - γₖvₖ₋₁ - αₖvₖ
NisI || mul!(uₖ₊₁, N, p) # γₖ₊₁uₖ₊₁ = NAᵀvₖ - βₖuₖ₋₁ - αₖuₖ
MisI || mul!(vₖ₊₁, M, M⁻¹vₖ₋₁) # βₖ₊₁vₖ₊₁ = MAuₖ - γₖvₖ₋₁ - αₖvₖ
NisI || mul!(uₖ₊₁, N, N⁻¹uₖ₋₁) # γₖ₊₁uₖ₊₁ = NAᵀvₖ - βₖuₖ₋₁ - αₖuₖ

βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, q)) # βₖ₊₁ = ‖vₖ₊₁‖_E
γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, p)) # γₖ₊₁ = ‖uₖ₊₁‖_F
βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, M⁻¹vₖ₋₁)) # βₖ₊₁ = ‖vₖ₊₁‖_E
γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, N⁻¹uₖ₋₁)) # γₖ₊₁ = ‖uₖ₊₁‖_F

if βₖ₊₁ ≠ 0
@kscal!(m, one(T) / βₖ₊₁, q)
@kscal!(m, one(T) / βₖ₊₁, M⁻¹vₖ₋₁)
MisI || @kscal!(m, one(T) / βₖ₊₁, vₖ₊₁)
end

if γₖ₊₁ ≠ 0
@kscal!(n, one(T) / γₖ₊₁, p)
@kscal!(n, one(T) / γₖ₊₁, N⁻¹uₖ₋₁)
NisI || @kscal!(n, one(T) / γₖ₊₁, uₖ₊₁)
end

Expand Down Expand Up @@ -371,13 +368,15 @@ function trimr!(solver :: TrimrSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst
MisI || (vₖ .= vₖ₊₁)
NisI || (uₖ .= uₖ₊₁)

# Update M⁻¹vₖ₋₁ and N⁻¹uₖ₋₁
@. M⁻¹vₖ₋₁ = M⁻¹vₖ
@. N⁻¹uₖ₋₁ = N⁻¹uₖ
# Update N⁻¹uₖ₋₁, M⁻¹vₖ₋₁, N⁻¹uₖ and M⁻¹vₖ.
@kswap(M⁻¹vₖ₋₁, M⁻¹vₖ)
@kswap(N⁻¹uₖ₋₁, N⁻¹uₖ)

# Update M⁻¹vₖ and N⁻¹uₖ
@. M⁻¹vₖ = q
@. N⁻¹uₖ = p
# Update pointers impacted by the swaps.
MisI && (vₖ = M⁻¹vₖ)
MisI && (vₖ₊₁ = M⁻¹vₖ₋₁)
NisI && (uₖ = N⁻¹uₖ)
NisI && (uₖ₊₁ = N⁻¹uₖ₋₁)

# Update cosines and sines
old_s₁ₖ = s₁ₖ
Expand Down
6 changes: 3 additions & 3 deletions test/test_alloc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,9 @@ function test_alloc()
@test (VERSION < v"1.5") || (inplace_tricg_bytes == 208)

# TriMR needs:
# - 8 n-vectors: yₖ, uₖ₋₁, uₖ, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, p
# - 8 m-vectors: xₖ, vₖ₋₁, vₖ, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, q
storage_trimr(n, m) = 8 * n + 8 * m
# - 7 n-vectors: yₖ, uₖ₋₁, uₖ, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ
# - 7 m-vectors: xₖ, vₖ₋₁, vₖ, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ
storage_trimr(n, m) = 7 * n + 7 * m
storage_trimr_bytes(n, m) = 8 * storage_trimr(n, m)

expected_trimr_bytes = storage_trimr_bytes(n, m)
Expand Down

0 comments on commit 9a35a4e

Please sign in to comment.