Add COCG method for complex symmetric linear systems #289

Open · wants to merge 16 commits into master

Changes from 3 commits
52 changes: 31 additions & 21 deletions src/cg.jl
@@ -2,20 +2,29 @@ import Base: iterate
using Printf
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables

mutable struct CGIterable{matT, solT, vecT, numT <: Real}
# Conjugated dot product
_dot(x, ::Val{true}) = sum(abs2, x) # for x::Complex, returns Real
Member:

Maybe better to just use _norm / norm; that's what was used before, and norm is slightly more stable. Also I'm not sure if sum(abs2, x) works as efficiently on GPUs; norm is definitely optimized.

Author (@wsshin, Jan 31, 2021):

What do you mean by "norm is more stable"? If you are talking about the possibility of overflow, I think it doesn't matter because the calculated norm is squared again in the algorithm.

Also, for the unconjugated dot product, the norm(x) function would return sqrt(xᵀx), which is not a norm because xᵀx is complex. In COCG the quantity we use is xᵀx, not sqrt(xᵀx), so I don't feel that it is a good idea to store the square-rooted quantity and then to square it again when using it.

I verified that the GPU performance of sum does not degrade:

julia> using LinearAlgebra, CuArrays, BenchmarkTools

julia> v = rand(ComplexF64, 10^9);

julia> cv = cu(v);

julia> @btime norm($v);
  1.890 s (0 allocations: 0 bytes)

julia> @btime sum(abs2, $v);
  1.452 s (0 allocations: 0 bytes)  # sum() is significantly faster than norm() on CPU

julia> @btime norm($cv);
  18.341 ms (1 allocation: 16 bytes)

julia> @btime sum(abs2, $cv);
  18.341 ms (361 allocations: 13.94 KiB)  # sum is on par with norm() on GPU

Member (@stevengj, Mar 5, 2021):

The main advantage of falling back to norm and dot in the conjugated case is that they generalize better — if the user has defined some specialized Banach-space type (and corresponding self-adjoint linear operators A with overloaded *), they should have overloaded norm and dot, whereas sum(abs2, x) might no longer work.

(Even as simple a generalization as an array of arrays will fail with sum(abs2, x), whereas they work with dot and norm.)

The overhead of the additional sqrt should be negligible.
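For concreteness, a small illustration of the array-of-arrays point (illustrative code, not from the thread):

using LinearAlgebra

x = [[1.0, 2.0], [3.0, 4.0]]   # an array of arrays
y = [[5.0, 6.0], [7.0, 8.0]]

norm(x)        # works via the generic recursive fallback: √(1 + 4 + 9 + 16)
dot(x, y)      # works, recursing into the blocks: 5 + 12 + 21 + 32 = 70
# sum(abs2, x) # would throw a MethodError: abs2 is not defined for Vector{Float64}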

Member:

> 1.452 s (0 allocations: 0 bytes)  # sum() is significantly faster than norm() on CPU

This is because norm dispatches to BLAS, which does a stable norm computation. But generally, BLAS libraries are free to choose how they implement norm.
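For reference, the stability in question is rescaling against spurious overflow/underflow; a minimal illustration (mine, not from the thread) that also shows the author's counterpoint about squaring:

using LinearAlgebra

v = [1e300, 1e300]
sqrt(sum(abs2, v))   # Inf: 1e300^2 overflows Float64
norm(v)              # ≈ 1.414e300: norm is computed with internal rescaling
norm(v)^2            # Inf again — squaring ρ in the algorithm reintroduces the overflow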

Author (@wsshin, Mar 6, 2021):

> The main advantage of falling back to norm and dot in the conjugated case is that they generalize better — if the user has defined some specialized Banach-space type (and corresponding self-adjoint linear operators A with overloaded *), they should have overloaded norm and dot, whereas sum(abs2, x) might no longer work.
>
> (Even as simple a generalization as an array of arrays will fail with sum(abs2, x), whereas they work with dot and norm.)

This is a very convincing argument. Additionally, considering that a user-defined Banach-space type should overload norm and dot but not the unconjugated dot, it would be nice to be able to define the unconjugated dot in terms of dot, like

_dot(x, y, ::UnconjugatedDot) = dot(conj(x), y)

but of course this is inefficient, as it conjugates x twice, not to mention the extra allocation. I wish dotu were exported, as you suggested in https://github.com/JuliaLang/julia/issues/22227#issuecomment-306224429.

Users can overload _dot(x, y, ::UnconjugatedDot) for the Banach-space type they define, but what would be the best way to minimize such a requirement? I have pushed a commit implementing the unconjugated dot with sum, but if there is a better approach, please let me know.

Member:

You could implement the unconjugated dot as transpose(x) * y ... this should have no extra allocations since transpose is a "view" type by default?

Member:

Though transpose(x) * y wouldn't work if x and y are matrices (and A is some kind of multilinear operator); maybe it's best to stick with the sum(prod, zip(x,y)) solution you have now.

Author (@wsshin):

Can we assume that getindex() is always defined for x and y? Then how about using transpose(@view(x[:])) * @view(y[:])? This allocates, but it seems much faster than sum(prod, zip(x,y)):

julia> x = rand(100,100); y = rand(100,100);

julia> @btime sum(prod, zip($x,$y));
  12.564 μs (0 allocations: 0 bytes)

julia> @btime transpose(@view($x[:])) * @view($y[:]);
  1.966 μs (4 allocations: 160 bytes)

It is nearly on a par with dot:

julia> @btime dot($x,$y);
  1.839 μs (0 allocations: 0 bytes)

Also, when x and y are AbstractVectors, the proposed method does not use allocations:

julia> x = rand(1000); y = rand(1000);

julia> @btime transpose(@view($x[:])) * @view($y[:]);
  108.091 ns (0 allocations: 0 bytes)

Member:

You could do dot(transpose(x'), x), maybe?

Author (@wsshin):

I didn't know that x' was nonallocating! This looked like a great nonallocating implementation, but it turns out that this is slower than the earlier proposed solution using @view:

julia> @btime transpose(@view($x[:])) * @view($y[:]);
  267.794 ns (0 allocations: 0 bytes)

julia> @btime dot(transpose($x'), $y);
  1.547 μs (0 allocations: 0 bytes)

Also, neither x' nor transpose(x) works for x whose dimension is greater than 2.

I just pushed an implementation using @view for now.
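The pushed commit is outside this three-commit view; presumably the unconjugated dot now reads something like the variant benchmarked above (hypothetical reconstruction, not shown in the diff below):

# x[:] as a view flattens an array of any dimensionality without copying,
# so the transpose-vector product applies regardless of the array's shape:
_dot(x, y, ::Val{false}) = transpose(@view(x[:])) * @view(y[:])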

_dot(x, y, ::Val{true}) = dot(x, y)

# Unconjugated dot product
_dot(x, ::Val{false}) = sum(xₖ^2 for xₖ in x)
_dot(x, y, ::Val{false}) = sum(prod, zip(x,y))

mutable struct CGIterable{matT, solT, vecT, numT <: Real, paramT <: Number, boolT <: Union{Val{true},Val{false}}}
A::matT
x::solT
r::vecT
c::vecT
u::vecT
tol::numT
residual::numT
prev_residual::numT
ρ_prev::paramT
maxiter::Int
mv_products::Int
conjugate_dot::boolT
end

mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number}
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number, boolT <: Union{Val{true},Val{false}}}
Member:

Can you turn it into two named structs instead of Val{true/false}, it would be easier to read.
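A minimal sketch of the named-struct dispatch being requested (hypothetical names; the Val-based definitions above would become):

struct ConjugatedDot end
struct UnconjugatedDot end

_dot(x, ::ConjugatedDot) = sum(abs2, x)
_dot(x, y, ::ConjugatedDot) = dot(x, y)

_dot(x, ::UnconjugatedDot) = sum(xₖ^2 for xₖ in x)
_dot(x, y, ::UnconjugatedDot) = sum(prod, zip(x, y))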

Author (@wsshin, Jan 31, 2021):

I understand defining InnerProduct and UnconjugatedProduct was your original suggestion, but I thought passing a boolean flag in the top-level user interface like

cg!(x, A, b; conjugate_dot=true)

would be easier for the users than something like

cg!(x, A, b; innerprodtype=UnconjugatedProduct())

Introduction of a new type should be minimized, because it reduces portability. Also, I don't think using a boolean flag reduces readability in this case, because the flag is a keyword argument.

If we agree that passing a boolean flag conjugate_dot to the top interface is a slightly better choice, it would be more natural to do

CGIterable(A, x, r, ..., Val(conjugate_dot))

than

CGIterable(A, x, r, ..., conjugate_dot ? InnerProduct() : UnconjugatedProduct())

This is why I decided to include a boolean type in CGIterable rather than a completely new type as you suggested originally. Please let me know what you think.

Member:

I'm thinking the high-level interface could just be cocg! instead of cg! with a flag. Would be nice if the methods would be type stable without relying on constant propagation of conjugate_dot.
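Such a front end could be a thin wrapper (hypothetical sketch, not part of this PR):

cocg!(x, A, b; kwargs...) = cg!(x, A, b; conjugate_dot = false, kwargs...)
cocg(A, b; kwargs...) = cg(A, b; conjugate_dot = false, kwargs...)

Type stability would still hinge on how the Bool is lifted to a dispatchable type inside cg_iterator!.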

Member:

> Introduction of a new type should be minimized, because it reduces portability.

I'm not sure what you mean by "reducing portability," Wonseok?

A new type has the advantage of allowing more than two possibilities. For example, one could imagine a WeightedDot(B) that corresponds to the dot(x,B,y) dot product — this is useful in cases where A is Hermitian with respect to some nonstandard dot product (the alternative being to perform a change of variables via Cholesky factors of B, which may not always be convenient).

conjugate_dot=true has the disadvantage of potentially requiring a dynamic dispatch at the top level. (Unless the compiler does constant propagation, but I think it only does that if the function is short enough to be inlined, which I don't think cg! is?) That's not such a big deal here, however, because overhead of a single dynamic dispatch should be negligible for large problems where iterative solvers are used.
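For illustration, such a weighted dot could slot into the same _dot interface (hypothetical sketch, using LinearAlgebra's three-argument dot):

struct WeightedDot{matT}
    B::matT
end

_dot(x, y, w::WeightedDot) = dot(x, w.B, y)   # x' * B * y without materializing B * y
_dot(x, w::WeightedDot) = dot(x, w.B, x)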

Author (@wsshin):

> I'm not sure what you mean by "reducing portability," Wonseok?

Please ignore my comment on portability. I think it was irrelevant to the current situation. (For clarification and my future reference, I meant it is not desirable to define a package-specific collection type that aggregates various fields and to pass such a collection to a function in order to reduce the number of function arguments.)

> A new type has the advantage of allowing more than two possibilities. For example, one could imagine a WeightedDot(B) that corresponds to the dot(x,B,y) dot product — this is useful in cases where A is Hermitian with respect to some nonstandard dot product (the alternative being to perform a change of variables via Cholesky factors of B, which may not always be convenient).

I really like this flexibility!

> conjugate_dot=true has the disadvantage of potentially requiring a dynamic dispatch at the top level. (Unless the compiler does constant propagation, but I think it only does that if the function is short enough to be inlined, which I don't think cg! is?) That's not such a big deal here, however, because overhead of a single dynamic dispatch should be negligible for large problems where iterative solvers are used.

I didn't know that constant propagation can be a means to avoid dynamic dispatch. Glad to learn a new thing. Thanks for the information! (I find this useful on this topic.)

But even before learning about constant propagation, I thought the overhead of a dynamic dispatch should be negligible for the large problems where iterative solvers are useful, as you pointed out.

I have pushed a commit complying with this request.

Pl::precT
A::matT
x::solT
@@ -24,9 +33,10 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb
u::vecT
tol::numT
residual::numT
ρ::paramT
ρ_prev::paramT
maxiter::Int
mv_products::Int
conjugate_dot::boolT
end

@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol
@@ -47,18 +57,19 @@ function iterate(it::CGIterable, iteration::Int=start(it))
end

# u := r + βu (almost an axpy)
β = it.residual^2 / it.prev_residual^2
ρ = _dot(it.r, it.conjugate_dot)
Member (@haampie, Jan 31, 2021):

CG and preconditioned CG are split into separate methods because CG requires one fewer dot product. Can you check if you can drop this dot in your version somehow?

Author (@wsshin):

Good point, but I'm afraid we can't. It turns out that the saving of one dot product in the non-preconditioned CG comes from the fact that we calculate ρ = r⋅r and the norm of the residual vector can be calculated by sqrt(ρ) instead of performing the dot product again.

In COCG, we use the unconjugated dot product, so ρ = transpose(r)*r. Therefore, the norm of the residual vector cannot be calculated by sqrt(ρ), but needs to be calculated by sqrt(r⋅r).
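A small demonstration of the asymmetry (illustrative code, not from the PR):

using LinearAlgebra

r = rand(ComplexF64, 4)
ρ_conj   = dot(r, r)             # imaginary part is zero, so norm(r) ≈ sqrt(real(ρ_conj))
ρ_unconj = sum(prod, zip(r, r))  # genuinely complex: sqrt(ρ_unconj) is not norm(r)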

Member:

In that case you should probably have

β = it.conjugate_dot isa Val{true} ? it.residual^2 / it.prev_residual^2 : _dot(it.r, it.conjugate_dot)

so that it only does the extra dot product in the unconjugated case.

Member:

And if you do that please change Val{true} into a named type to improve the readability a bit

Author (@wsshin):

Done!

β = ρ / it.ρ_prev
it.u .= it.r .+ β .* it.u

# c = A * u
mul!(it.c, it.A, it.u)
α = it.residual^2 / dot(it.u, it.c)
α = ρ / _dot(it.u, it.c, it.conjugate_dot)

# Improve solution and residual
it.ρ_prev = ρ
it.x .+= α .* it.u
it.r .-= α .* it.c

it.prev_residual = it.residual
it.residual = norm(it.r)

# Return the residual at item and iteration number as state
@@ -78,18 +89,17 @@ function iterate(it::PCGIterable, iteration::Int=start(it))
# Apply left preconditioner
ldiv!(it.c, it.Pl, it.r)

ρ_prev = it.ρ
it.ρ = dot(it.c, it.r)

# u := c + βu (almost an axpy)
β = it.ρ / ρ_prev
ρ = _dot(it.r, it.c, it.conjugate_dot)
β = ρ / it.ρ_prev
it.u .= it.c .+ β .* it.u

# c = A * u
mul!(it.c, it.A, it.u)
α = it.ρ / dot(it.u, it.c)
α = ρ / _dot(it.u, it.c, it.conjugate_dot)

# Improve solution and residual
it.ρ_prev = ρ
it.x .+= α .* it.u
it.r .-= α .* it.c

@@ -122,7 +132,8 @@ function cg_iterator!(x, A, b, Pl = Identity();
reltol::Real = sqrt(eps(real(eltype(b)))),
maxiter::Int = size(A, 2),
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
initially_zero::Bool = false)
initially_zero::Bool = false,
conjugate_dot::Bool = true)
u = statevars.u
r = statevars.r
c = statevars.c
@@ -142,15 +153,13 @@

# Return the iterable
if isa(Pl, Identity)
return CGIterable(A, x, r, c, u,
tolerance, residual, one(residual),
maxiter, mv_products
)
return CGIterable(A, x, r, c, u, tolerance, residual,
conjugate_dot ? one(real(eltype(r))) : one(eltype(r)), # for conjugated dot, ρ_prev remains real
maxiter, mv_products, Val(conjugate_dot))
else
return PCGIterable(Pl, A, x, r, c, u,
tolerance, residual, one(eltype(x)),
maxiter, mv_products
)
tolerance, residual, one(eltype(r)),
maxiter, mv_products, Val(conjugate_dot))
end
end

@@ -211,6 +220,7 @@ function cg!(x, A, b;
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
verbose::Bool = false,
Pl = Identity(),
conjugate_dot::Bool = true,
kwargs...)
history = ConvergenceHistory(partial = !log)
history[:abstol] = abstol
@@ -219,7 +229,7 @@

# Actually perform CG
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter,
statevars = statevars, kwargs...)
statevars = statevars, conjugate_dot = conjugate_dot, kwargs...)
if log
history.mvps = iterable.mv_products
end
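Taken together, the interface at this stage of the PR would be exercised like this (a sketch assuming the conjugate_dot keyword added in the diff above):

using LinearAlgebra, IterativeSolvers

n = 100
A = rand(ComplexF64, n, n)
A = A + transpose(A) + 15I    # complex symmetric (Aᵀ == A), but not Hermitian
b = rand(ComplexF64, n)

x, history = cg(A, b; reltol = sqrt(eps(Float64)), log = true, conjugate_dot = false)
norm(A * x - b) / norm(b)     # relative residual, expected at or below reltol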
49 changes: 48 additions & 1 deletion test/cg.jl
@@ -21,10 +21,26 @@ ldiv!(y, P::JacobiPrec, x) = y .= x ./ P.diagonal

Random.seed!(1234321)

@testset "Vector{$T}, conjugated and unconjugated dot products" for T in (ComplexF32, ComplexF64)
n = 100
x = rand(T, n)
y = rand(T, n)

# Conjugated dot product
@test IterativeSolvers._dot(x, Val(true)) ≈ x'x
@test IterativeSolvers._dot(x, y, Val(true)) ≈ x'y
@test IterativeSolvers._dot(x, Val(true)) ≈ IterativeSolvers._dot(x, x, Val(true))

# Unconjugated dot product
@test IterativeSolvers._dot(x, Val(false)) ≈ transpose(x) * x
@test IterativeSolvers._dot(x, y, Val(false)) ≈ transpose(x) * y
@test IterativeSolvers._dot(x, Val(false)) ≈ IterativeSolvers._dot(x, x, Val(false))
end

@testset "Small full system" begin
n = 10

@testset "Matrix{$T}" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "Matrix{$T}, conjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = rand(T, n, n)
A = A' * A + I
b = rand(T, n)
@@ -50,6 +66,37 @@ Random.seed!(1234321)
x0 = cg(A, zeros(T, n))
@test x0 == zeros(T, n)
end

@testset "Matrix{$T}, unconjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = rand(T, n, n)
A = A + transpose(A) + 15I
x = ones(T, n)
b = A * x

reltol = √eps(real(T))

# Solve without preconditioner
x1, his1 = cg(A, b, reltol = reltol, maxiter = 100, log = true, conjugate_dot = false)
@test isa(his1, ConvergenceHistory)
@test norm(A * x1 - b) / norm(b) ≤ reltol

# With an initial guess
x_guess = rand(T, n)
x2, his2 = cg!(x_guess, A, b, reltol = reltol, maxiter = 100, log = true, conjugate_dot = false)
@test isa(his2, ConvergenceHistory)
@test x2 == x_guess
@test norm(A * x2 - b) / norm(b) ≤ reltol

# The following test fails CI on Windows and Ubuntu due to a
# `SingularException(4)`
Member:

What's going on with this failure?

Author (@wsshin):

Actually the enclosing @testset of this @test is copied from L13-L43 of test/bicgstabl.jl. I didn't pay too much attention to the copied tests. Let me try to run the tests on a Windows box and get back to you.

Author (@wsshin):

I did pkg> test IterativeSolvers in a Windows box with the continue block (L35-L37 of test/bicgstabl.jl) commented out, and all the tests run fine. I will push the commit without the continue block in test/cg.jl. Let's see if the tests succeed...

Author (@wsshin):

No failure during CI (see below). I will probably open another PR to remove the continue block from test/bicgstabl.jl.

if T == Float32 && (Sys.iswindows() || Sys.islinux())
continue
end
# Do an exact LU decomp of a nearby matrix
F = lu(A + rand(T, n, n))
x3, his3 = cg(A, b, Pl = F, maxiter = 100, reltol = reltol, log = true, conjugate_dot = false)
@test norm(A * x3 - b) / norm(b) ≤ reltol
Member (@stevengj, Mar 16, 2021):

For tests like this you can use ≈. It should be equivalent to:

@test A*x3 ≈ b rtol=reltol

(see isapprox).

Author (@wsshin):

Hmmm. I cannot seem to specify the keyword arguments of isapprox in its infix form:

julia> VERSION
v"1.5.4-pre.0"

julia> 0 ≈ 0
true

julia> 0 ≈ 0, rtol=1e-8
ERROR: syntax: "0" is not a valid function argument name around REPL[23]:1

Is this a new capability introduced in version > 1.5.4?

Member:

Don't use the comma, and have it in a @test:

julia> @test 0 ≈ 0 rtol=1e-8
Test Passed

end
end

@testset "Sparse Laplacian" begin