Skip to content

Commit

Permalink
Swap the arguments in the macro kcopy!
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 16, 2024
1 parent 2cd76f7 commit 8621db3
Show file tree
Hide file tree
Showing 20 changed files with 37 additions and 41 deletions.
4 changes: 2 additions & 2 deletions src/bicgstab.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ kwargs_bicgstab = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose,
mul!(q, A, y) # qₖ = Ayₖ
mulorldiv!(v, M, q, ldiv) # vₖ = M⁻¹qₖ
α = ρ / @kdot(n, c, v) # αₖ = ⟨r̅₀,rₖ₋₁⟩ / ⟨r̅₀,vₖ⟩
@kcopy!(n, r, s) # sₖ = rₖ₋₁
@kcopy!(n, s, r) # sₖ = rₖ₋₁
@kaxpy!(n, -α, v, s) # sₖ = sₖ - αₖvₖ
@kaxpy!(n, α, y, x) # xₐᵤₓ = xₖ₋₁ + αₖyₖ
NisI || mulorldiv!(z, N, s, ldiv) # zₖ = N⁻¹sₖ
mul!(d, A, z) # dₖ = Azₖ
MisI || mulorldiv!(t, M, d, ldiv) # tₖ = M⁻¹dₖ
ω = @kdot(n, t, s) / @kdot(n, t, t) # ⟨tₖ,sₖ⟩ / ⟨tₖ,tₖ⟩
@kaxpy!(n, ω, z, x) # xₖ = xₐᵤₓ + ωₖzₖ
@kcopy!(n, s, r) # rₖ = sₖ
@kcopy!(n, r, s) # rₖ = sₖ
@kaxpy!(n, -ω, t, r) # rₖ = rₖ - ωₖtₖ
next_ρ = @kdot(n, c, r) # ρₖ₊₁ = ⟨r̅₀,rₖ⟩
β = (next_ρ / ρ) */ ω) # βₖ₊₁ = (ρₖ₊₁ / ρₖ) * (αₖ / ωₖ)
Expand Down
6 changes: 3 additions & 3 deletions src/bilq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,15 @@ kwargs_bilq = (:c, :transfer_to_bicg, :M, :N, :ldiv, :atol, :rtol, :itmax, :time
# Compute d̅ₖ.
if iter == 1
# d̅₁ = v₁
@kcopy!(n, vₖ, d̅) # d̅ ← vₖ
@kcopy!(n, d̅, vₖ) # d̅ ← vₖ
else
# d̅ₖ = s̄ₖ * d̅ₖ₋₁ - cₖ * vₖ
@kaxpby!(n, -cₖ, vₖ, conj(sₖ), d̅)
end

# Compute vₖ₊₁ and uₖ₊₁.
@kcopy!(n, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(n, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ

if pᴴq 0
vₖ .= q ./ βₖ₊₁ # βₖ₊₁vₖ₊₁ = q
Expand Down
6 changes: 3 additions & 3 deletions src/bilqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ kwargs_bilqr = (:transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose, :hi
# Compute d̅ₖ.
if iter == 1
# d̅₁ = v₁
@kcopy!(n, vₖ, d̅) # d̅ ← vₖ
@kcopy!(n, d̅, vₖ) # d̅ ← vₖ
else
# d̅ₖ = s̄ₖ * d̅ₖ₋₁ - cₖ * vₖ
@kaxpby!(n, -cₖ, vₖ, conj(sₖ), d̅)
Expand Down Expand Up @@ -400,8 +400,8 @@ kwargs_bilqr = (:transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose, :hi
end

# Compute vₖ₊₁ and uₖ₊₁.
@kcopy!(n, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(n, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ

if pᴴq zero(FC)
vₖ .= q ./ βₖ₊₁ # βₖ₊₁vₖ₊₁ = q
Expand Down
4 changes: 2 additions & 2 deletions src/cg_lanczos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ kwargs_cg_lanczos = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :timemax
@kaxpy!(n, -δ, Mv, Mv_next) # Mvₖ₊₁ ← Mvₖ₊₁ - δₖMvₖ
if iter > 0
@kaxpy!(n, -β, Mv_prev, Mv_next) # Mvₖ₊₁ ← Mvₖ₊₁ - βₖMvₖ₋₁
@kcopy!(n, Mv, Mv_prev) # Mvₖ₋₁ ← Mvₖ
@kcopy!(n, Mv_prev, Mv) # Mvₖ₋₁ ← Mvₖ
end
@kcopy!(n, Mv_next, Mv) # Mvₖ ← Mvₖ₊₁
@kcopy!(n, Mv, Mv_next) # Mvₖ ← Mvₖ₊₁
MisI || mulorldiv!(v, M, Mv, ldiv) # vₖ₊₁ = M⁻¹ * Mvₖ₊₁
β = sqrt(@kdotr(n, v, Mv)) # βₖ₊₁ = vₖ₊₁ᴴ M vₖ₊₁
@kscal!(n, one(FC) / β, v) # vₖ₊₁ ← vₖ₊₁ / βₖ₊₁
Expand Down
4 changes: 2 additions & 2 deletions src/cg_lanczos_shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ kwargs_cg_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :t
@kaxpy!(n, -δ, Mv, Mv_next) # Mvₖ₊₁ ← Mvₖ₊₁ - δₖMvₖ
if iter > 0
@kaxpy!(n, -β, Mv_prev, Mv_next) # Mvₖ₊₁ ← Mvₖ₊₁ - βₖMvₖ₋₁
@kcopy!(n, Mv, Mv_prev) # Mvₖ₋₁ ← Mvₖ
@kcopy!(n, Mv_prev, Mv) # Mvₖ₋₁ ← Mvₖ
end
@kcopy!(n, Mv_next, Mv) # Mvₖ ← Mvₖ₊₁
@kcopy!(n, Mv, Mv_next) # Mvₖ ← Mvₖ₊₁
MisI || mulorldiv!(v, M, Mv, ldiv) # vₖ₊₁ = M⁻¹ * Mvₖ₊₁
β = sqrt(@kdotr(n, v, Mv)) # βₖ₊₁ = vₖ₊₁ᴴ M vₖ₊₁
@kscal!(n, one(FC) / β, v) # vₖ₊₁ ← vₖ₊₁ / βₖ₊₁
Expand Down
4 changes: 2 additions & 2 deletions src/cgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ kwargs_cgs = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :hist
MisI || mulorldiv!(v, M, t, ldiv) # vₖ = M⁻¹tₖ
σ = @kdot(n, c, v) # σₖ = ⟨ r̅₀,M⁻¹AN⁻¹pₖ ⟩
α = ρ / σ # αₖ = ρₖ / σₖ
@kcopy!(n, u, q) # qₖ = uₖ
@kcopy!(n, q, u) # qₖ = uₖ
@kaxpy!(n, -α, v, q) # qₖ = qₖ - αₖ * M⁻¹AN⁻¹pₖ
@kaxpy!(n, one(FC), q, u) # uₖ₊½ = uₖ + qₖ
NisI || mulorldiv!(z, N, u, ldiv) # zₖ = N⁻¹uₖ₊½
Expand All @@ -218,7 +218,7 @@ kwargs_cgs = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :hist
@kaxpy!(n, -α, w, r) # rₖ₊₁ = rₖ - αₖ * M⁻¹AN⁻¹(uₖ + qₖ)
ρ_next = @kdot(n, c, r) # ρₖ₊₁ = ⟨ r̅₀,rₖ₊₁ ⟩
β = ρ_next / ρ # βₖ = ρₖ₊₁ / ρₖ
@kcopy!(n, r, u) # uₖ₊₁ = rₖ₊₁
@kcopy!(n, u, r) # uₖ₊₁ = rₖ₊₁
@kaxpy!(n, β, q, u) # uₖ₊₁ = uₖ₊₁ + βₖ * qₖ
@kaxpby!(n, one(FC), q, β, p) # pₐᵤₓ = qₖ + βₖ * pₖ
@kaxpby!(n, one(FC), u, β, p) # pₖ₊₁ = uₖ₊₁ + βₖ * pₐᵤₓ
Expand Down
6 changes: 1 addition & 5 deletions src/krylov_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,7 @@ macro kaxpby!(n, s, x, t, y)
return esc(:(Krylov.kaxpby!($n, $s, $x, 1, $t, $y, 1)))
end

macro kcopy!(n, x, y)
return esc(:(Krylov.kcopy!($n, $x, 1, $y, 1)))
end

macro kcopyto!(n, y, x)
macro kcopy!(n, y, x)
return esc(:(Krylov.kcopy!($n, $x, 1, $y, 1)))
end

Expand Down
4 changes: 2 additions & 2 deletions src/minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,
end
@kaxpy!(n, one(FC) / β, v, w)

@kcopy!(n, r2, r1) # r1 ← r2
@kcopy!(n, y, r2) # r2 ← y
@kcopy!(n, r1, r2) # r1 ← r2
@kcopy!(n, r2, y) # r2 ← y
MisI || mulorldiv!(v, M, r2, ldiv)
oldβ = β
β = @kdotr(n, r2, v)
Expand Down
2 changes: 1 addition & 1 deletion src/minres_qlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ kwargs_minres_qlp = (:M, :ldiv, :λ, :atol, :rtol, :Artol, :itmax, :timemax, :ve
# Compute directions wₖ₋₂, ẘₖ₋₁ and w̄ₖ, last columns of Wₖ = Vₖ(Pₖ)ᴴ
if iter == 1
# w̅₁ = v₁
@kcopy!(n, vₖ, wₖ)
@kcopy!(n, wₖ, vₖ)
elseif iter == 2
# [w̅ₖ₋₁ vₖ] [cpₖ spₖ] = [ẘₖ₋₁ w̅ₖ] ⟷ ẘₖ₋₁ = cpₖ * w̅ₖ₋₁ + spₖ * vₖ
# [spₖ -cpₖ] ⟷ w̅ₖ = spₖ * w̅ₖ₋₁ - cpₖ * vₖ
Expand Down
4 changes: 2 additions & 2 deletions src/qmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ kwargs_qmr = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :hist
@kaxpy!(n, ζₖ, wₖ, x)

# Compute vₖ₊₁ and uₖ₊₁.
@kcopy!(n, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(n, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ

if pᴴq zero(FC)
vₖ .= q ./ βₖ₊₁ # βₖ₊₁vₖ₊₁ = q
Expand Down
4 changes: 2 additions & 2 deletions src/symmlq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ kwargs_symmlq = (:M, :ldiv, :transfer_to_cg, :λ, :λest, :atol, :rtol, :etol, :
mul!(Mv_next, A, v)
α = @kdotr(m, v, Mv_next) + λ
@kaxpy!(m, -oldβ, Mvold, Mv_next)
@kcopy!(m, Mv, Mvold) # Mvold ← Mv
@kcopy!(m, Mvold, Mv) # Mvold ← Mv
@kaxpy!(m, -α, Mv, Mv_next)
@kcopy!(m, Mv_next, Mv) # Mv ← Mv_next
@kcopy!(m, Mv, Mv_next) # Mv ← Mv_next
MisI || mulorldiv!(v, M, Mv, ldiv)
β = @kdotr(m, v, Mv)
β < 0 && error("Preconditioner is not positive definite")
Expand Down
4 changes: 2 additions & 2 deletions src/tricg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, :τ, :ν, :atol, :rtol, :itmax
if iter == 1
# [ 1 0 ] [ gx₁ gy₁ ] = [ v₁ 0 ]
# [ δ̄₁ 1 ] [ gx₂ gy₂ ] [ 0 u₁ ]
@kcopy!(m, vₖ, gx₂ₖ₋₁) # gx₂ₖ₋₁ ← vₖ
@kcopy!(m, gx₂ₖ₋₁, vₖ) # gx₂ₖ₋₁ ← vₖ
gx₂ₖ .= -conj(δₖ) .* gx₂ₖ₋₁
@kcopy!(n, uₖ, gy₂ₖ) # gy₂ₖ ← uₖ
@kcopy!(n, gy₂ₖ, uₖ) # gy₂ₖ ← uₖ
else
# [ 0 σ̄ₖ 1 0 ] [ gx₂ₖ₋₃ gy₂ₖ₋₃ ] = [ vₖ 0 ]
# [ η̄ₖ λ̄ₖ δ̄ₖ 1 ] [ gx₂ₖ₋₂ gy₂ₖ₋₂ ] [ 0 uₖ ]
Expand Down
6 changes: 3 additions & 3 deletions src/trilqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ kwargs_trilqr = (:transfer_to_usymcg, :atol, :rtol, :itmax, :timemax, :verbose,
# Compute d̅ₖ.
if iter == 1
# d̅₁ = u₁
@kcopy!(n, uₖ, d̅) # d̅ ← uₖ
@kcopy!(n, d̅, uₖ) # d̅ ← uₖ
else
# d̅ₖ = s̄ₖ * d̅ₖ₋₁ - cₖ * uₖ
@kaxpby!(n, -cₖ, uₖ, conj(sₖ), d̅)
Expand Down Expand Up @@ -378,8 +378,8 @@ kwargs_trilqr = (:transfer_to_usymcg, :atol, :rtol, :itmax, :timemax, :verbose,
end

# Compute uₖ₊₁ and uₖ₊₁.
@kcopy!(m, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(m, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ

if βₖ₊₁ zero(T)
vₖ .= q ./ βₖ₊₁ # βₖ₊₁vₖ₊₁ = q
Expand Down
6 changes: 3 additions & 3 deletions src/usymlq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,15 @@ kwargs_usymlq = (:transfer_to_usymcg, :atol, :rtol, :itmax, :timemax, :verbose,
# Compute d̅ₖ.
if iter == 1
# d̅₁ = u₁
@kcopy!(n, uₖ, d̅) # d̅ ← vₖ
@kcopy!(n, d̅, uₖ) # d̅ ← vₖ
else
# d̅ₖ = s̄ₖ * d̅ₖ₋₁ - cₖ * uₖ
@kaxpby!(n, -cₖ, uₖ, conj(sₖ), d̅)
end

# Compute uₖ₊₁ and uₖ₊₁.
@kcopy!(m, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(m, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ

if βₖ₊₁ zero(T)
vₖ .= q ./ βₖ₊₁ # βₖ₊₁vₖ₊₁ = q
Expand Down
4 changes: 2 additions & 2 deletions src/usymqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ kwargs_usymqr = (:atol, :rtol, :itmax, :timemax, :verbose, :history, :callback,
history && push!(AᴴrNorms, AᴴrNorm)

# Compute uₖ₊₁ and uₖ₊₁.
@kcopy!(m, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(m, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ
@kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ

if βₖ₊₁ zero(T)
vₖ .= q ./ βₖ₊₁ # βₖ₊₁vₖ₊₁ = q
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/amd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ include("gpu.jl")
end

@testset "kcopy! -- $FC" begin
Krylov.@kcopy!(n, x, y)
Krylov.@kcopy!(n, y, x)
end

@testset "kswap -- $FC" begin
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/intel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ include("gpu.jl")
end

@testset "kcopy! -- $FC" begin
Krylov.@kcopy!(n, x, y)
Krylov.@kcopy!(n, y, x)
end

@testset "kswap -- $FC" begin
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ include("gpu.jl")
end

@testset "kcopy! -- $FC" begin
Krylov.@kcopy!(n, x, y)
Krylov.@kcopy!(n, y, x)
end

@testset "kswap -- $FC" begin
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/nvidia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ include("gpu.jl")
end

@testset "kcopy! -- $FC" begin
Krylov.@kcopy!(n, x, y)
Krylov.@kcopy!(n, y, x)
end

@testset "kswap -- $FC" begin
Expand Down
2 changes: 1 addition & 1 deletion test/test_aux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
Krylov.@kaxpby!(n, a, x, b2, y)
Krylov.@kaxpby!(n, a2, x, b2, y)

Krylov.@kcopy!(n, x, y)
Krylov.@kcopy!(n, y, x)

Krylov.@kfill!(x, a)

Expand Down

0 comments on commit 8621db3

Please sign in to comment.