Skip to content

Commit

Permalink
Use metaprogramming to generate more out-of-place methods
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 14, 2024
1 parent a4cc680 commit dfcd1cd
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 127 deletions.
21 changes: 0 additions & 21 deletions src/diom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,6 @@ optargs_diom = (:x0,)
kwargs_diom = (:M, :N, :ldiv, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function diom($(def_args_diom...), $(def_optargs_diom...); memory :: Int=20, $(def_kwargs_diom...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = DiomSolver(A, b, memory)
warm_start!(solver, $(optargs_diom...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
diom!(solver, $(args_diom...); $(kwargs_diom...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function diom($(def_args_diom...); memory :: Int=20, $(def_kwargs_diom...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = DiomSolver(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
diom!(solver, $(args_diom...); $(kwargs_diom...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function diom!(solver :: DiomSolver{T,FC,S}, $(def_args_diom...); $(def_kwargs_diom...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/dqgmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,6 @@ optargs_dqgmres = (:x0,)
kwargs_dqgmres = (:M, :N, :ldiv, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function dqgmres($(def_args_dqgmres...), $(def_optargs_dqgmres...); memory :: Int=20, $(def_kwargs_dqgmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = DqgmresSolver(A, b, memory)
warm_start!(solver, $(optargs_dqgmres...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
dqgmres!(solver, $(args_dqgmres...); $(kwargs_dqgmres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function dqgmres($(def_args_dqgmres...); memory :: Int=20, $(def_kwargs_dqgmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = DqgmresSolver(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
dqgmres!(solver, $(args_dqgmres...); $(kwargs_dqgmres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function dqgmres!(solver :: DqgmresSolver{T,FC,S}, $(def_args_dqgmres...); $(def_kwargs_dqgmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/fgmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,6 @@ optargs_fgmres = (:x0,)
kwargs_fgmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function fgmres($(def_args_fgmres...), $(def_optargs_fgmres...); memory :: Int=20, $(def_kwargs_fgmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = FgmresSolver(A, b, memory)
warm_start!(solver, $(optargs_fgmres...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
fgmres!(solver, $(args_fgmres...); $(kwargs_fgmres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function fgmres($(def_args_fgmres...); memory :: Int=20, $(def_kwargs_fgmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = FgmresSolver(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
fgmres!(solver, $(args_fgmres...); $(kwargs_fgmres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function fgmres!(solver :: FgmresSolver{T,FC,S}, $(def_args_fgmres...); $(def_kwargs_fgmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/fom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,6 @@ optargs_fom = (:x0,)
kwargs_fom = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function fom($(def_args_fom...), $(def_optargs_fom...); memory :: Int=20, $(def_kwargs_fom...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = FomSolver(A, b, memory)
warm_start!(solver, $(optargs_fom...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
fom!(solver, $(args_fom...); $(kwargs_fom...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function fom($(def_args_fom...); memory :: Int=20, $(def_kwargs_fom...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = FomSolver(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
fom!(solver, $(args_fom...); $(kwargs_fom...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function fom!(solver :: FomSolver{T,FC,S}, $(def_args_fom...); $(def_kwargs_fom...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/gmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,6 @@ optargs_gmres = (:x0,)
kwargs_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function gmres($(def_args_gmres...), $(def_optargs_gmres...); memory :: Int=20, $(def_kwargs_gmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = GmresSolver(A, b, memory)
warm_start!(solver, $(optargs_gmres...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
gmres!(solver, $(args_gmres...); $(kwargs_gmres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function gmres($(def_args_gmres...); memory :: Int=20, $(def_kwargs_gmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = GmresSolver(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
gmres!(solver, $(args_gmres...); $(kwargs_gmres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function gmres!(solver :: GmresSolver{T,FC,S}, $(def_args_gmres...); $(def_kwargs_gmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/gpmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,27 +148,6 @@ optargs_gpmr = (:x0, :y0)
kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, , , :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function gpmr($(def_args_gpmr...), $(def_optargs_gpmr...); memory :: Int=20, $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = GpmrSolver(A, b, memory)
warm_start!(solver, $(optargs_gpmr...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
gpmr!(solver, $(args_gpmr...); $(kwargs_gpmr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.y, solver.stats)
end

function gpmr($(def_args_gpmr...); memory :: Int=20, $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = GpmrSolver(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
gpmr!(solver, $(args_gpmr...); $(kwargs_gpmr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.y, solver.stats)
end

function gpmr!(solver :: GpmrSolver{T,FC,S}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
27 changes: 26 additions & 1 deletion src/krylov_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in
(:CglsLanczosShiftSolver, :cgls_lanczos_shift!, :cgls_lanczos_shift, args_cgls_lanczos_shift, def_args_cgls_lanczos_shift, (), (), kwargs_cgls_lanczos_shift, def_kwargs_cgls_lanczos_shift)
]
# window :: 5 -> lslq, lsmr, lsqr, minres, symmlq
# memory :: 20 -> diom, dqgmres, fom, gmres, fgmres, gpmr
if fun2 in (:cg_lanczos_shift, :cgls_lanczos_shift)
@eval begin
## Out-of-place
Expand All @@ -60,6 +59,32 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in
return results(solver)
end
end
elseif fun2 in (:diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr)
@eval begin
## Out-of-place
function $(fun2)($(def_args...); memory::Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $KS(A, b, memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
$(fun)(solver, $(args...); $(kwargs...))
solver.stats.timer += elapsed_time
return results(solver)
end

if !isempty($optargs)
function $(fun2)($(def_args...), $(def_optargs...); memory::Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $KS(A, b, memory)
warm_start!(solver, $(optargs...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
$(fun)(solver, $(args...); $(kwargs...))
solver.stats.timer += elapsed_time
return results(solver)
end
end
end
elseif fun2 (:cg_lanczos_shift, :cgls_lanczos_shift, :lslq, :lsmr, :lsqr, :minres, :symmlq, :diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr)
@eval begin
## Out-of-place
Expand Down

0 comments on commit dfcd1cd

Please sign in to comment.