From 6d9e9ad2227fe41cfbe60103c86c6e2a1789d7bc Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Tue, 29 Jun 2021 00:30:07 -0400 Subject: [PATCH] Use 5-arguments mul! in TRICG --- src/krylov_solvers.jl | 6 +---- src/tricg.jl | 55 ++++++++++++++++++++----------------------- test/test_alloc.jl | 6 ++--- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/src/krylov_solvers.jl b/src/krylov_solvers.jl index 8b27588cc..000ffd616 100644 --- a/src/krylov_solvers.jl +++ b/src/krylov_solvers.jl @@ -468,13 +468,11 @@ mutable struct TricgSolver{T,S} <: KrylovSolver{T,S} yₖ :: S N⁻¹uₖ₋₁ :: S N⁻¹uₖ :: S - p :: S gy₂ₖ₋₁ :: S gy₂ₖ :: S xₖ :: S M⁻¹vₖ₋₁ :: S M⁻¹vₖ :: S - q :: S gx₂ₖ₋₁ :: S gx₂ₖ :: S uₖ :: S @@ -485,18 +483,16 @@ mutable struct TricgSolver{T,S} <: KrylovSolver{T,S} yₖ = S(undef, m) N⁻¹uₖ₋₁ = S(undef, m) N⁻¹uₖ = S(undef, m) - p = S(undef, m) gy₂ₖ₋₁ = S(undef, m) gy₂ₖ = S(undef, m) xₖ = S(undef, n) M⁻¹vₖ₋₁ = S(undef, n) M⁻¹vₖ = S(undef, n) - q = S(undef, n) gx₂ₖ₋₁ = S(undef, n) gx₂ₖ = S(undef, n) uₖ = S(undef, 0) vₖ = S(undef, 0) - solver = new{T,S}(yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, uₖ, vₖ) + solver = new{T,S}(yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, gy₂ₖ₋₁, gy₂ₖ, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, gx₂ₖ₋₁, gx₂ₖ, uₖ, vₖ) return solver end diff --git a/src/tricg.jl b/src/tricg.jl index 09491fe05..fd968536b 100644 --- a/src/tricg.jl +++ b/src/tricg.jl @@ -87,12 +87,12 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst # Set up workspace. allocate_if(!MisI, solver, :vₖ, S, m) allocate_if(!NisI, solver, :uₖ, S, n) - yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p, xₖ, M⁻¹vₖ₋₁ = solver.yₖ, solver.N⁻¹uₖ₋₁, solver.p, solver.N⁻¹uₖ, solver.xₖ, solver.M⁻¹vₖ₋₁ - M⁻¹vₖ, q, gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₁, gx₂ₖ = solver.M⁻¹vₖ, solver.q, solver.gy₂ₖ₋₁, solver.gy₂ₖ, solver.gx₂ₖ₋₁, solver.gx₂ₖ + yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, xₖ, M⁻¹vₖ₋₁ = solver.yₖ, solver.N⁻¹uₖ₋₁, solver.N⁻¹uₖ, solver.xₖ, solver.M⁻¹vₖ₋₁ + M⁻¹vₖ, gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₁, gx₂ₖ = solver.M⁻¹vₖ, solver.gy₂ₖ₋₁, solver.gy₂ₖ, solver.gx₂ₖ₋₁, solver.gx₂ₖ vₖ = MisI ? M⁻¹vₖ : solver.vₖ uₖ = NisI ? N⁻¹uₖ : solver.uₖ - vₖ₊₁ = MisI ? q : vₖ - uₖ₊₁ = NisI ? p : uₖ + vₖ₊₁ = MisI ? M⁻¹vₖ₋₁ : vₖ + uₖ₊₁ = NisI ? N⁻¹uₖ₋₁ : uₖ # Initial solutions x₀ and y₀. xₖ .= zero(T) @@ -160,22 +160,13 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst # AUₖ = EVₖTₖ + βₖ₊₁Evₖ₊₁(eₖ)ᵀ = EVₖ₊₁Tₖ₊₁.ₖ # AᵀVₖ = FUₖ(Tₖ)ᵀ + γₖ₊₁Fuₖ₊₁(eₖ)ᵀ = FUₖ₊₁(Tₖ.ₖ₊₁)ᵀ - mul!(q, A , uₖ) # Forms Evₖ₊₁ : q ← Auₖ - mul!(p, Aᵀ, vₖ) # Forms Fuₖ₊₁ : p ← Aᵀvₖ + mul!(M⁻¹vₖ₋₁, A , uₖ, one(T), -γₖ) # Forms Evₖ₊₁ : Evₖ₋₁ ← Auₖ - γₖEvₖ₋₁ + mul!(N⁻¹uₖ₋₁, Aᵀ, vₖ, one(T), -βₖ) # Forms Fuₖ₊₁ : Fuₖ₋₁ ← Aᵀvₖ - βₖFuₖ₋₁ - if iter ≥ 2 - @kaxpy!(m, -γₖ, M⁻¹vₖ₋₁, q) # q ← q - γₖ * M⁻¹vₖ₋₁ - @kaxpy!(n, -βₖ, N⁻¹uₖ₋₁, p) # p ← p - βₖ * N⁻¹uₖ₋₁ - end - - αₖ = @kdot(m, vₖ, q) # αₖ = qᵀvₖ + αₖ = @kdot(m, vₖ, M⁻¹vₖ₋₁) # αₖ = (Auₖ- γₖEvₖ₋₁)ᵀvₖ - @kaxpy!(m, -αₖ, M⁻¹vₖ, q) # q ← q - αₖ * M⁻¹vₖ - @kaxpy!(n, -αₖ, N⁻¹uₖ, p) # p ← p - αₖ * N⁻¹uₖ - - # Update M⁻¹vₖ₋₁ and N⁻¹uₖ₋₁ - @. M⁻¹vₖ₋₁ = M⁻¹vₖ - @. N⁻¹uₖ₋₁ = N⁻¹uₖ + @kaxpy!(m, -αₖ, M⁻¹vₖ, M⁻¹vₖ₋₁) # Evₖ₋₁ ← Evₖ₋₁ - αₖ * Evₖ + @kaxpy!(n, -αₖ, N⁻¹uₖ, N⁻¹uₖ₋₁) # Fuₖ₋₁ ← Fuₖ₋₁ - αₖ * Fuₖ # Notations : Wₖ = [w₁ ••• wₖ] = [v₁ 0 ••• vₖ 0 ] # [0 u₁ ••• 0 uₖ] @@ -266,25 +257,31 @@ function tricg!(solver :: TricgSolver{T,S}, A, b :: AbstractVector{T}, c :: Abst @. yₖ += π₂ₖ₋₁ * gy₂ₖ₋₁ + π₂ₖ * gy₂ₖ # Compute vₖ₊₁ and uₖ₊₁ - MisI || mul!(vₖ₊₁, M, q) # βₖ₊₁vₖ₊₁ = MAuₖ - γₖvₖ₋₁ - αₖvₖ - NisI || mul!(uₖ₊₁, N, p) # γₖ₊₁uₖ₊₁ = NAᵀvₖ - βₖuₖ₋₁ - αₖuₖ + MisI || mul!(vₖ₊₁, M, M⁻¹vₖ₋₁) # βₖ₊₁vₖ₊₁ = MAuₖ - γₖvₖ₋₁ - αₖvₖ + NisI || mul!(uₖ₊₁, N, N⁻¹uₖ₋₁) # γₖ₊₁uₖ₊₁ = NAᵀvₖ - βₖuₖ₋₁ - αₖuₖ - βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, q)) # βₖ₊₁ = ‖vₖ₊₁‖_E - γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, p)) # γₖ₊₁ = ‖uₖ₊₁‖_F + βₖ₊₁ = sqrt(@kdot(m, vₖ₊₁, M⁻¹vₖ₋₁)) # βₖ₊₁ = ‖vₖ₊₁‖_E + γₖ₊₁ = sqrt(@kdot(n, uₖ₊₁, N⁻¹uₖ₋₁)) # γₖ₊₁ = ‖uₖ₊₁‖_F - if βₖ₊₁ ≠ 0 - @kscal!(m, one(T) / βₖ₊₁, q) + if βₖ₊₁ ≠ zero(T) + @kscal!(m, one(T) / βₖ₊₁, M⁻¹vₖ₋₁) MisI || @kscal!(m, one(T) / βₖ₊₁, vₖ₊₁) end - if γₖ₊₁ ≠ 0 - @kscal!(n, one(T) / γₖ₊₁, p) + if γₖ₊₁ ≠ zero(T) + @kscal!(n, one(T) / γₖ₊₁, N⁻¹uₖ₋₁) NisI || @kscal!(n, one(T) / γₖ₊₁, uₖ₊₁) end - # Update M⁻¹vₖ and N⁻¹uₖ - @. M⁻¹vₖ = q - @. N⁻¹uₖ = p + # Update N⁻¹uₖ₋₁, M⁻¹vₖ₋₁, N⁻¹uₖ and M⁻¹vₖ. + @kswap(N⁻¹uₖ₋₁, N⁻¹uₖ) + @kswap(M⁻¹vₖ₋₁, M⁻¹vₖ) + + # Update pointers impacted by the swaps. + NisI && (uₖ = N⁻¹uₖ) + NisI && (uₖ₊₁ = N⁻¹uₖ₋₁) + MisI && (vₖ = M⁻¹vₖ) + MisI && (vₖ₊₁ = M⁻¹vₖ₋₁) # Compute ‖rₖ‖² = (γₖ₊₁ζ₂ₖ₋₁)² + (βₖ₊₁ζ₂ₖ)² rNorm = sqrt((γₖ₊₁ * (π₂ₖ₋₁ - δₖ*π₂ₖ))^2 + (βₖ₊₁ * π₂ₖ)^2) diff --git a/test/test_alloc.jl b/test/test_alloc.jl index 7c99b3577..f2019c4cc 100644 --- a/test/test_alloc.jl +++ b/test/test_alloc.jl @@ -440,9 +440,9 @@ function test_alloc() @test (VERSION < v"1.5") || (inplace_usymlq_bytes == 208) # TriCG needs: - # - 6 n-vectors: yₖ, uₖ₋₁, uₖ, gy₂ₖ₋₁, gy₂ₖ, p - # - 6 m-vectors: xₖ, vₖ₋₁, vₖ, gx₂ₖ₋₁, gx₂ₖ, q - storage_tricg(n, m) = 6 * n + 6 * m + # - 5 n-vectors: yₖ, uₖ₋₁, uₖ, gy₂ₖ₋₁, gy₂ₖ + # - 5 m-vectors: xₖ, vₖ₋₁, vₖ, gx₂ₖ₋₁, gx₂ₖ + storage_tricg(n, m) = 5 * n + 5 * m storage_tricg_bytes(n, m) = 8 * storage_tricg(n, m) expected_tricg_bytes = storage_tricg_bytes(n, m)