Skip to content

Commit

Permalink
Fix element types (#163)
Browse files Browse the repository at this point in the history
* Remove unused methods

* Clean up common.jl

* Ensure r has the type of x, not b. Remove b from the iterable since it is not used anyway

* Make residual type in BiCGStab(l) equal to solution type and remove rhs from iterable

* Chebyshev: residual type should equal x's type and remove b from iterable

* Similar story for GMRES

* Use x element type everywhere

* Use different types for solution, temporary and rhs in iterables of stationary methods

* zerox over zeros
  • Loading branch information
haampie authored and andreasnoack committed Aug 28, 2017
1 parent 1d036a3 commit 0aa0760
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 118 deletions.
11 changes: 5 additions & 6 deletions src/bicgstabl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ export bicgstabl, bicgstabl!, bicgstabl_iterator, bicgstabl_iterator!, BiCGStabI

import Base: start, next, done

mutable struct BiCGStabIterable{precT, matT, vecT <: AbstractVector, smallMatT <: AbstractMatrix, realT <: Real, scalarT <: Number}
mutable struct BiCGStabIterable{precT, matT, solT, vecT <: AbstractVector, smallMatT <: AbstractMatrix, realT <: Real, scalarT <: Number}
A::matT
b::vecT
l::Int

x::vecT
x::solT
r_shadow::vecT
rs::smallMatT
us::smallMatT
Expand All @@ -33,7 +32,7 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
initial_zero = false,
tol = sqrt(eps(real(eltype(b))))
)
T = eltype(b)
T = eltype(x)
n = size(A, 1)
mv_products = 0

Expand Down Expand Up @@ -69,7 +68,7 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
# Stopping condition based on relative tolerance.
reltol = nrm * tol

BiCGStabIterable(A, b, l, x, r_shadow, rs, us,
BiCGStabIterable(A, l, x, r_shadow, rs, us,
max_mv_products, mv_products, reltol, nrm,
Pl,
γ, ω, σ, M
Expand All @@ -81,7 +80,7 @@ end
@inline done(it::BiCGStabIterable, iteration::Int) = it.mv_products it.max_mv_products || converged(it)

function next(it::BiCGStabIterable, iteration::Int)
T = eltype(it.b)
T = eltype(it.x)
L = 2 : it.l + 1

it.σ = -it.ω * it.σ
Expand Down
26 changes: 11 additions & 15 deletions src/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ import Base: start, next, done

export cg, cg!, CGIterable, PCGIterable, cg_iterator, cg_iterator!

mutable struct CGIterable{matT, vecT <: AbstractVector, numT <: Real}
mutable struct CGIterable{matT, solT, vecT, numT <: Real}
A::matT
x::vecT
b::vecT
x::solT
r::vecT
c::vecT
u::vecT
Expand All @@ -16,11 +15,10 @@ mutable struct CGIterable{matT, vecT <: AbstractVector, numT <: Real}
mv_products::Int
end

mutable struct PCGIterable{precT, matT, vecT <: AbstractVector, numT <: Real, paramT <: Number}
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number}
Pl::precT
A::matT
x::vecT
b::vecT
x::solT
r::vecT
c::vecT
u::vecT
Expand All @@ -46,7 +44,7 @@ function next(it::CGIterable, iteration::Int)
# u := r + βu (almost an axpy)
β = it.residual^2 / it.prev_residual^2
@blas! it.u *= β
@blas! it.u += one(eltype(it.b)) * it.r
@blas! it.u += one(eltype(it.u)) * it.r

# c = A * u
A_mul_B!(it.c, it.A, it.u)
Expand Down Expand Up @@ -76,7 +74,7 @@ function next(it::PCGIterable, iteration::Int)
# u := c + βu (almost an axpy)
β = it.ρ / ρ_prev
@blas! it.u *= β
@blas! it.u += one(eltype(it.b)) * it.c
@blas! it.u += one(eltype(it.u)) * it.c

# c = A * u
A_mul_B!(it.c, it.A, it.u)
Expand All @@ -102,7 +100,8 @@ function cg_iterator!(x, A, b, Pl = Identity();
initially_zero::Bool = false
)
u = zeros(x)
r = copy(b)
r = similar(x)
copy!(r, b)

# Compute r with an MV-product or not.
if initially_zero
Expand All @@ -120,16 +119,13 @@ function cg_iterator!(x, A, b, Pl = Identity();

# Return the iterable
if isa(Pl, Identity)
return CGIterable(A, x, b,
r, c, u,
return CGIterable(A, x, r, c, u,
reltol, residual, one(residual),
maxiter, mv_products
)
else
ρ = one(eltype(r))
return PCGIterable(Pl, A, x, b,
r, c, u,
reltol, residual, ρ,
return PCGIterable(Pl, A, x, r, c, u,
reltol, residual, one(eltype(x)),
maxiter, mv_products
)
end
Expand Down
15 changes: 7 additions & 8 deletions src/chebyshev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ import Base: next, start, done

export chebyshev, chebyshev!

mutable struct ChebyshevIterable{precT, matT, vecT, realT <: Real}
mutable struct ChebyshevIterable{precT, matT, solT, vecT, realT <: Real}
Pl::precT
A::matT
b::vecT

x::vecT
x::solT
r::vecT
u::vecT
c::vecT
Expand All @@ -28,7 +27,7 @@ start(::ChebyshevIterable) = 0
done(c::ChebyshevIterable, iteration::Int) = iteration c.maxiter || converged(c)

function next(cheb::ChebyshevIterable, iteration::Int)
T = eltype(cheb.u)
T = eltype(cheb.x)

solve!(cheb.c, cheb.Pl, cheb.r)

Expand Down Expand Up @@ -64,8 +63,9 @@ function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real;
λ_avg = (λmax + λmin) / 2
λ_diff = (λmax - λmin) / 2

T = eltype(b)
r = copy(b)
T = eltype(x)
r = similar(x)
copy!(r, b)
u = zeros(x)
c = similar(x)

Expand All @@ -82,8 +82,7 @@ function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real;
mv_products = 1
end

ChebyshevIterable(Pl, A, b,
x, r, u, c,
ChebyshevIterable(Pl, A, x, r, u, c,
zero(real(T)),
λ_avg, λ_diff,
resnorm, reltol, maxiter, mv_products
Expand Down
41 changes: 2 additions & 39 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,60 +12,23 @@ Determine type of the division of an element of `b` against an element of `A`:
"""
Adivtype(A, b) = typeof(one(eltype(b))/one(eltype(A)))

"""
Amultype(A, x)
Determine type of the multiplication of an element of `b` with an element of `A`:
`typeof(one(eltype(A))*one(eltype(x)))`
"""
Amultype(A, x) = typeof(one(eltype(A))*one(eltype(x)))

"""
randx(A, b)
Build a random unitary vector `Vector{T}`, where `T` is `Adivtype(A,b)`.
"""
function randx(A, b)
T = Adivtype(A, b)
x = initrand!(Array(T, size(A, 2)))
end

"""
zerox(A, b)
Build a zeros vector `Vector{T}`, where `T` is `Adivtype(A,b)`.
"""
function zerox(A, b)
T = Adivtype(A, b)
x = zeros(T, size(A, 2))
end
zerox(A, b) = zeros(Adivtype(A, b), size(A, 2))

#### Numerics
"""
solve(A,b)
Solve `A\b` with a direct solver. When `A` is a function `A(b)` is dispatched instead.
Solve `A\\b` with a direct solver. When `A` is a function `A(b)` is dispatched instead.
"""
solve(A::Function,b) = A(b)

solve(A,b) = A\b

solve!(out::AbstractArray{T},A::Int,b::AbstractArray{T}) where {T} = scale!(out,b, 1/A)

solve!(out::AbstractArray{T},A,b::AbstractArray{T}) where {T} = A_ldiv_B!(out,A,b)
solve!(out::AbstractArray{T},A::Function,b::AbstractArray{T}) where {T} = copy!(out,A(b))

"""
initrand!(v)
Overwrite `v` with a random unitary vector of the same length.
"""
function initrand!(v::Vector)
_randn!(v)
nv = norm(v)
for i = 1:length(v)
v[i] /= nv
end
v
end
_randn!(v::Array{Float64}) = randn!(v)
_randn!(v) = copy!(v, randn(length(v)))

# Identity preconditioner
struct Identity end

Expand Down
18 changes: 9 additions & 9 deletions src/gmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ Residual(order, T::Type) = Residual{T, real(T)}(
one(real(T))
)

mutable struct GMRESIterable{preclT, precrT, vecT <: AbstractVector, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
Pl::preclT
Pr::precrT
x::vecT
b::vecT
x::solT
b::rhsT
Ax::vecT # Some room to work in.

arnoldi::arnoldiT
Expand Down Expand Up @@ -98,25 +98,25 @@ function next(g::GMRESIterable, iteration::Int)
g.residual.current, iteration + 1
end

gmres_iterable(A, b; kwargs...) = gmres_iterable!(zeros(b), A, b; initially_zero = true, kwargs...)
gmres_iterable(A, b; kwargs...) = gmres_iterable!(zerox(A, b), A, b; initially_zero = true, kwargs...)

function gmres_iterable!(x, A, b;
Pl = Identity(),
Pr = Identity(),
tol = sqrt(eps(real(eltype(b)))),
restart::Int = min(20, length(b)),
maxiter::Int = restart,
initially_zero = false
initially_zero::Bool = false
)
T = eltype(b)
T = eltype(x)

# Approximate solution
arnoldi = ArnoldiDecomp(A, restart, T)
residual = Residual(restart, T)
mv_products = initially_zero == true ? 1 : 0
mv_products = initially_zero ? 1 : 0

# Workspace vector to reduce the # allocs.
Ax = similar(b)
Ax = similar(x)
residual.current = init!(arnoldi, x, b, Pl, Ax, initially_zero = initially_zero)
init_residual!(residual, residual.current)

Expand All @@ -133,7 +133,7 @@ end
Same as [`gmres!`](@ref), but allocates a solution vector `x` initialized with zeros.
"""
gmres(A, b; kwargs...) = gmres!(zeros(b), A, b; initially_zero = true, kwargs...)
gmres(A, b; kwargs...) = gmres!(zerox(A, b), A, b; initially_zero = true, kwargs...)

"""
gmres!(x, A, b; kwargs...) -> x, [history]
Expand Down
19 changes: 10 additions & 9 deletions src/minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ export minres_iterable, minres, minres!
import Base.LinAlg: BLAS.axpy!, givensAlgorithm
import Base: start, next, done

mutable struct MINRESIterable{matT, vecT <: DenseVector, smallVecT <: DenseVector, rotT <: Number, realT <: Real}
mutable struct MINRESIterable{matT, solT, vecT <: DenseVector, smallVecT <: DenseVector, rotT <: Number, realT <: Real}
A::matT
skew_hermitian::Bool
x::vecT
x::solT

# Krylov basis vectors
v_prev::vecT
Expand Down Expand Up @@ -44,15 +44,16 @@ function minres_iterable!(x, A, b;
tol = sqrt(eps(real(eltype(b)))),
maxiter = size(A, 1)
)
T = eltype(b)
T = eltype(x)
HessenbergT = skew_hermitian ? T : real(T)

v_prev = similar(b)
v_curr = copy(b)
v_next = similar(b)
w_prev = similar(b)
w_curr = similar(b)
w_next = similar(b)
v_prev = similar(x)
v_curr = similar(x)
copy!(v_curr, b)
v_next = similar(x)
w_prev = similar(x)
w_curr = similar(x)
w_next = similar(x)

mv_products = 0

Expand Down
24 changes: 12 additions & 12 deletions src/stationary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ function jacobi!(x, A::AbstractMatrix, b; maxiter::Int=10)
x
end

mutable struct DenseJacobiIterable{matT,vecT}
mutable struct DenseJacobiIterable{matT,vecT,solT,rhsT}
A::matT
x::vecT
x::solT
next::vecT
b::vecT
b::rhsT
maxiter::Int
end

Expand Down Expand Up @@ -93,10 +93,10 @@ function gauss_seidel!(x, A::AbstractMatrix, b; maxiter::Int=10)
x
end

mutable struct DenseGaussSeidelIterable{matT,vecT}
mutable struct DenseGaussSeidelIterable{matT,solT,rhsT}
A::matT
x::vecT
b::vecT
x::solT
b::rhsT
maxiter::Int
end

Expand Down Expand Up @@ -149,11 +149,11 @@ function sor!(x, A::AbstractMatrix, b, ω::Real; maxiter::Int=10)
x
end

mutable struct DenseSORIterable{matT,vecT,numT}
mutable struct DenseSORIterable{matT,solT,vecT,rhsT,numT}
A::matT
x::vecT
x::solT
tmp::vecT
b::vecT
b::rhsT
ω::numT
maxiter::Int
end
Expand Down Expand Up @@ -207,11 +207,11 @@ function ssor!(x, A::AbstractMatrix, b, ω::Real; maxiter::Int=10)
x
end

mutable struct DenseSSORIterable{matT,vecT,numT}
mutable struct DenseSSORIterable{matT,solT,vecT,rhsT,numT}
A::matT
x::vecT
x::solT
tmp::vecT
b::vecT
b::rhsT
ω::numT
maxiter::Int
end
Expand Down
Loading

0 comments on commit 0aa0760

Please sign in to comment.