Add COCG method for complex symmetric linear systems #289
base: master
```diff
@@ -2,20 +2,29 @@ import Base: iterate
 using Printf
 export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables

-mutable struct CGIterable{matT, solT, vecT, numT <: Real}
+# Conjugated dot product
+_dot(x, ::Val{true}) = sum(abs2, x)  # for x::Complex, returns Real
+_dot(x, y, ::Val{true}) = dot(x, y)
+
+# Unconjugated dot product
+_dot(x, ::Val{false}) = sum(xₖ^2 for xₖ in x)
+_dot(x, y, ::Val{false}) = sum(prod, zip(x, y))
+
+mutable struct CGIterable{matT, solT, vecT, numT <: Real, paramT <: Number, boolT <: Union{Val{true},Val{false}}}
     A::matT
     x::solT
     r::vecT
     c::vecT
     u::vecT
     tol::numT
     residual::numT
     prev_residual::numT
+    ρ_prev::paramT
     maxiter::Int
     mv_products::Int
+    conjugate_dot::boolT
 end

-mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number}
+mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number, boolT <: Union{Val{true},Val{false}}}
     Pl::precT
     A::matT
     x::solT
```

Review discussion (on the `Val{true}`/`Val{false}` type parameter):

Reviewer: Can you turn it into two named structs instead of `Val{true/false}`? It would be easier to read.

Author: I understand that defining `cg!(x, A, b; conjugate_dot=true)` would be easier for users than something like `cg!(x, A, b; innerprodtype=UnconjugatedProduct())`. Introduction of a new type should be minimized. Also, I don't think using a boolean flag reduces readability in this case, because the flag is a keyword argument. If we agree on passing a boolean flag, `CGIterable(A, x, r, ..., Val(conjugate_dot))` reads better than `CGIterable(A, x, r, ..., conjugate_dot ? InnerProduct() : UnconjugatedProduct())`. This is why I decided to include a boolean type in the struct parameters.

Reviewer: I'm thinking the high-level interface could just be …

Second reviewer: I'm not sure what you mean by "reducing portability," Wonseok? A new type has the advantage of allowing more than two possibilities. For example, one could imagine a …

Author: Please ignore my comment on portability; it was irrelevant to the current situation. (For clarification and my future reference: I meant that it is not desirable to define a package-specific collection type that aggregates various fields and to pass such a collection to a function just to reduce the number of function arguments.) I really like this flexibility! Also, I didn't know that constant propagation can be a means of avoiding dynamic dispatch; glad to learn a new thing, and thanks for the information. Before knowing about constant propagation, I had assumed the overhead of a dynamic dispatch would be negligible in the regimes where iterative solvers are useful, as you pointed out. I have pushed a commit complying with this request.
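To make the conjugated/unconjugated distinction concrete, here is a small Python sketch of the two products (the function names are hypothetical; the Julia code above expresses the same pair via `_dot` with a `Val` flag):

```python
def dot_conj(x, y):
    """Conjugated (Euclidean) inner product: sum of conj(x_k) * y_k."""
    return sum(xk.conjugate() * yk for xk, yk in zip(x, y))

def dot_unconj(x, y):
    """Unconjugated bilinear form used by COCG: sum of x_k * y_k."""
    return sum(xk * yk for xk, yk in zip(x, y))

x = [1 + 2j, 3 - 1j]
print(dot_conj(x, x))    # (15+0j): real and nonnegative
print(dot_unconj(x, x))  # (5-2j): complex in general
```

For a complex symmetric matrix `A` (equal to its transpose, not its conjugate transpose), it is the unconjugated form that keeps the CG recurrences consistent, which is what the `conjugate_dot = false` path enables.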
```diff
 @inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol

@@ -47,18 +57,19 @@ function iterate(it::CGIterable, iteration::Int=start(it))
     end

+    ρ = _dot(it.r, it.conjugate_dot)
+
     # u := r + βu (almost an axpy)
-    β = it.residual^2 / it.prev_residual^2
+    β = ρ / it.ρ_prev
     it.u .= it.r .+ β .* it.u

     # c = A * u
     mul!(it.c, it.A, it.u)
-    α = it.residual^2 / dot(it.u, it.c)
+    α = ρ / _dot(it.u, it.c, it.conjugate_dot)

     # Improve solution and residual
+    it.ρ_prev = ρ
     it.x .+= α .* it.u
     it.r .-= α .* it.c

     it.prev_residual = it.residual
     it.residual = norm(it.r)

     # Return the residual at item and iteration number as state
```

Review discussion (on computing `ρ`):

Reviewer: CG and preconditioned CG are split into separate methods because CG requires one fewer dot product. Can you check if you can drop this dot in your version somehow?

Author: Good point, but I'm afraid we can't. The saving of one dot product in non-preconditioned CG comes from the fact that the residual norm is computed anyway, and with the conjugated dot product `ρ = dot(r, r)` is just that norm squared. In COCG we use the unconjugated dot product, so `ρ = transpose(r) * r` cannot be recovered from the residual norm.

Reviewer: In that case you should probably have `β = it.conjugate_dot isa Val{true} ? it.residual^2 / it.prev_residual^2 : _dot(it.r, it.conjugate_dot)` so that it only does the extra dot product in the unconjugated case. And if you do that, please change …

Author: Done!
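The reviewer's point about the saved dot product can be checked numerically: with the conjugated product, ρ equals the squared residual norm that CG computes anyway, while the unconjugated ρ cannot be recovered from the norm. A small Python check with arbitrary made-up values:

```python
r = [1 + 2j, -3j, 0.5 + 0j]

rho_conj = sum(rk.conjugate() * rk for rk in r)  # conjugated dot(r, r)
norm_sq = sum(abs(rk) ** 2 for rk in r)          # residual^2, already available
print(abs(rho_conj - norm_sq))                   # 0.0: no extra dot product needed

rho_unconj = sum(rk * rk for rk in r)            # transpose(r) * r
print(rho_unconj)                                # complex; |r| alone cannot give it
```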
```diff
@@ -78,18 +89,17 @@ function iterate(it::PCGIterable, iteration::Int=start(it))
     # Apply left preconditioner
     ldiv!(it.c, it.Pl, it.r)

-    ρ_prev = it.ρ
-    it.ρ = dot(it.c, it.r)
-
     # u := c + βu (almost an axpy)
-    β = it.ρ / ρ_prev
+    ρ = _dot(it.r, it.c, it.conjugate_dot)
+    β = ρ / it.ρ_prev
     it.u .= it.c .+ β .* it.u

     # c = A * u
     mul!(it.c, it.A, it.u)
-    α = it.ρ / dot(it.u, it.c)
+    α = ρ / _dot(it.u, it.c, it.conjugate_dot)

     # Improve solution and residual
+    it.ρ_prev = ρ
     it.x .+= α .* it.u
     it.r .-= α .* it.c
```
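Putting the recurrence together, here is a self-contained COCG sketch in Python on a tiny made-up complex symmetric system (plain lists, no preconditioner; it mirrors the structure of the `iterate` methods above but is only an illustration, not the package's implementation):

```python
def cocg(A, b, x0, maxiter=50, tol=1e-10):
    # The unconjugated dot is the only difference from ordinary CG.
    udot = lambda u, v: sum(ui * vi for ui, vi in zip(u, v))
    matvec = lambda M, v: [udot(row, v) for row in M]
    n = len(b)
    x = list(x0)
    r = [bi - ci for bi, ci in zip(b, matvec(A, x))]
    u = [0j] * n
    rho_prev = 1.0 + 0j                      # seed, like one(...) in the constructor
    for _ in range(maxiter):
        res = sum(abs(ri) ** 2 for ri in r) ** 0.5   # true (conjugated) norm
        if res <= tol:
            break
        rho = udot(r, r)                     # transpose(r) * r, complex in general
        beta = rho / rho_prev
        u = [ri + beta * ui for ri, ui in zip(r, u)]
        c = matvec(A, u)
        alpha = rho / udot(u, c)             # transpose(u) * A * u, unconjugated
        x = [xi + alpha * ui for xi, ui in zip(x, u)]
        r = [ri - alpha * ci for ri, ci in zip(r, c)]
        rho_prev = rho
    return x

A = [[3 + 1j, 1j], [1j, 4 + 0j]]             # complex symmetric: A == transpose(A)
x_true = [1 + 1j, 2 + 0j]
b = [sum(a * t for a, t in zip(row, x_true)) for row in A]
x = cocg(A, b, [0j, 0j])
err = max(abs(xi - ti) for xi, ti in zip(x, x_true))
print(err < 1e-8)  # True
```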
```diff
@@ -122,7 +132,8 @@ function cg_iterator!(x, A, b, Pl = Identity();
     reltol::Real = sqrt(eps(real(eltype(b)))),
     maxiter::Int = size(A, 2),
     statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
-    initially_zero::Bool = false)
+    initially_zero::Bool = false,
+    conjugate_dot::Bool = true)
     u = statevars.u
     r = statevars.r
     c = statevars.c

@@ -142,15 +153,13 @@ function cg_iterator!(x, A, b, Pl = Identity();

     # Return the iterable
     if isa(Pl, Identity)
-        return CGIterable(A, x, r, c, u,
-            tolerance, residual, one(residual),
-            maxiter, mv_products
-        )
+        return CGIterable(A, x, r, c, u, tolerance, residual,
+            conjugate_dot ? one(real(eltype(r))) : one(eltype(r)),  # for conjugated dot, ρ_prev remains real
+            maxiter, mv_products, Val(conjugate_dot))
     else
-        return PCGIterable(Pl, A, x, r, c, u,
-            tolerance, residual, one(eltype(x)),
-            maxiter, mv_products
-        )
+        return PCGIterable(Pl, A, x, r, c, u,
+            tolerance, residual, one(eltype(r)),
+            maxiter, mv_products, Val(conjugate_dot))
     end
 end
```
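A detail worth noting in the constructor calls: `ρ_prev` is seeded with `one(...)`. Any nonzero seed works, because on the first iteration the search direction `u` is still the zero vector, so β only multiplies zero. A quick Python illustration of that reasoning (values are arbitrary):

```python
rho, rho_prev = 2 - 3j, 1.0      # rho_prev seeded to one, as in the constructor
u = [0j, 0j]                     # u starts as the zero vector
r = [1j, 2 + 0j]

beta = rho / rho_prev            # beta's value is irrelevant on the first pass
u = [ri + beta * ui for ri, ui in zip(r, u)]
print(u == r)  # True: the first search direction is just r, whatever beta was
```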
|
@@ -211,6 +220,7 @@ function cg!(x, A, b; | |
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), | ||
verbose::Bool = false, | ||
Pl = Identity(), | ||
conjugate_dot::Bool = true, | ||
kwargs...) | ||
history = ConvergenceHistory(partial = !log) | ||
history[:abstol] = abstol | ||
|
@@ -219,7 +229,7 @@ function cg!(x, A, b; | |
|
||
# Actually perform CG | ||
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter, | ||
statevars = statevars, kwargs...) | ||
statevars = statevars, conjugate_dot = conjugate_dot, kwargs...) | ||
if log | ||
history.mvps = iterable.mv_products | ||
end | ||
|
```diff
@@ -21,10 +21,26 @@ ldiv!(y, P::JacobiPrec, x) = y .= x ./ P.diagonal

 Random.seed!(1234321)

+@testset "Vector{$T}, conjugated and unconjugated dot products" for T in (ComplexF32, ComplexF64)
+    n = 100
+    x = rand(T, n)
+    y = rand(T, n)
+
+    # Conjugated dot product
+    @test IterativeSolvers._dot(x, Val(true)) ≈ x'x
+    @test IterativeSolvers._dot(x, y, Val(true)) ≈ x'y
+    @test IterativeSolvers._dot(x, Val(true)) ≈ IterativeSolvers._dot(x, x, Val(true))
+
+    # Unconjugated dot product
+    @test IterativeSolvers._dot(x, Val(false)) ≈ transpose(x) * x
+    @test IterativeSolvers._dot(x, y, Val(false)) ≈ transpose(x) * y
+    @test IterativeSolvers._dot(x, Val(false)) ≈ IterativeSolvers._dot(x, x, Val(false))
+end
+
 @testset "Small full system" begin
     n = 10

-    @testset "Matrix{$T}" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    @testset "Matrix{$T}, conjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
         A = rand(T, n, n)
         A = A' * A + I
         b = rand(T, n)
```
```diff
@@ -50,6 +66,37 @@ Random.seed!(1234321)
         x0 = cg(A, zeros(T, n))
         @test x0 == zeros(T, n)
     end
+
+    @testset "Matrix{$T}, unconjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
+        A = rand(T, n, n)
+        A = A + transpose(A) + 15I
+        x = ones(T, n)
+        b = A * x
+
+        reltol = √eps(real(T))
+
+        # Solve without preconditioner
+        x1, his1 = cg(A, b, reltol = reltol, maxiter = 100, log = true, conjugate_dot = false)
+        @test isa(his1, ConvergenceHistory)
+        @test norm(A * x1 - b) / norm(b) ≤ reltol
+
+        # With an initial guess
+        x_guess = rand(T, n)
+        x2, his2 = cg!(x_guess, A, b, reltol = reltol, maxiter = 100, log = true, conjugate_dot = false)
+        @test isa(his2, ConvergenceHistory)
+        @test x2 == x_guess
+        @test norm(A * x2 - b) / norm(b) ≤ reltol
+
+        # The following tests fail CI on Windows and Ubuntu due to a
+        # `SingularException(4)`
+        if T == Float32 && (Sys.iswindows() || Sys.islinux())
+            continue
+        end
+        # Do an exact LU decomp of a nearby matrix
+        F = lu(A + rand(T, n, n))
+        x3, his3 = cg(A, b, Pl = F, maxiter = 100, reltol = reltol, log = true, conjugate_dot = false)
+        @test norm(A * x3 - b) / norm(b) ≤ reltol
+    end
 end

 @testset "Sparse Laplacian" begin
```

Review discussion (on the skipped preconditioned test and `@test` style):

Reviewer: What's going on with this failure?

Author: No failure during CI (see below). I will probably open another PR to remove the …

Reviewer: For tests like this you can use `@test A*x3 ≈ b rtol=reltol`.

Author: Hmm, I cannot seem to specify the keyword arguments of `≈` that way. Is this a new capability introduced in a version newer than 1.5.4?

Reviewer: Don't use the comma, and put it all in a single `@test` expression:

```julia-repl
julia> @test 0 ≈ 0 rtol=1e-8
Test Passed
```
Review discussion (on implementing `_dot` with `sum` instead of `norm`/`dot`):

Reviewer: Maybe better to just use `_norm`/`norm`; that's what was used before, and `norm` is slightly more stable. Also, I'm not sure if `sum(abs2, x)` works as efficiently on GPUs; `norm` is definitely optimized.

Author: What do you mean by "`norm` is more stable"? If you are talking about the possibility of overflow, I think it doesn't matter, because the calculated norm is squared again in the algorithm. Also, for the unconjugated dot product, `norm(x)` would return `sqrt(xᵀx)`, which is not a norm because `xᵀx` is complex. In COCG the quantity we use is `xᵀx`, not `sqrt(xᵀx)`, so I don't feel it is a good idea to store the square-rooted quantity and then square it again when using it. I verified that the GPU performance of `sum` does not degrade.

Second reviewer: The main advantage of falling back to `norm` and `dot` in the conjugated case is that they generalize better: if the user has defined some specialized Banach-space type (and corresponding self-adjoint linear operators `A` with overloaded `*`), they should have overloaded `norm` and `dot`, whereas `sum(abs2, x)` might no longer work. (Even as simple a generalization as an array of arrays will fail with `sum(abs2, x)`, whereas `dot` and `norm` work.) The overhead of the additional `sqrt` should be negligible.

Reviewer: This is because it dispatches to BLAS, which does a stable norm computation. But in general, BLAS libraries are free to choose how they implement `norm`.

Author: This is a very convincing argument. Additionally, considering that a user-defined Banach-space type should overload `norm` and `dot` but not the unconjugated dot, it would be nice to define the unconjugated dot in terms of `dot`, but of course that is inefficient, as it conjugates `x` twice, not to mention the extra allocations. I wish `dotu` were exported, as you suggested in JuliaLang/julia#22227. Users can overload `_dot(x, y, ::UnconjugatedDot)` for the Banach-space types they define, but what would be the best way to minimize such a requirement? I have pushed a commit implementing the unconjugated dot with `sum`; if there is a better approach, please let me know.

Second reviewer: You could implement the unconjugated dot as `transpose(x) * y`; this should have no extra allocations, since `transpose` is a "view" type by default. Though `transpose(x) * y` wouldn't work if `x` and `y` are matrices (and `A` is some kind of multilinear operator); maybe it's best to stick with the `sum(prod, zip(x, y))` solution you have now.

Author: Can we assume that `getindex` is always defined for `x` and `y`? Then how about using `transpose(@view(x[:])) * @view(y[:])`? This uses allocations, but it seems much faster than `sum(prod, zip(x, y))` and is nearly on a par with `dot`. Also, when `x` and `y` are `AbstractVector`s, the proposed method does not use allocations.

Second reviewer: You could do `dot(transpose(x'), x)`, maybe?

Author: I didn't know that `x'` was nonallocating! This looked like a great nonallocating implementation, but it turns out to be slower than the earlier proposed solution using `@view`. Also, neither `x'` nor `transpose(x)` works for `x` whose dimension is greater than 2. I just pushed an implementation using `@view` for now.
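The "array of arrays" generalization raised above can be illustrated outside Julia as well: an elementwise-recursive unconjugated dot handles nested block vectors, where a flat "square each entry" shortcut or a matrix transpose would not apply. A hypothetical Python sketch:

```python
def udot(x, y):
    """Unconjugated dot that recurses into nested (block) vectors."""
    if isinstance(x, (list, tuple)):
        return sum(udot(xi, yi) for xi, yi in zip(x, y))
    return x * y  # scalar case: no conjugation

flat = udot([1 + 1j, 2j], [1 + 1j, 2j])             # (1+1j)^2 + (2j)^2
blocked = udot([[1 + 1j], [2j]], [[1 + 1j], [2j]])  # same data, one nesting level
print(flat == blocked == -4 + 2j)  # True
```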