Skip to content

Commit

Permalink
Generate all out-of-place methods in krylov_solve.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 14, 2024
1 parent 3a1deba commit dca3d1f
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 78 deletions.
34 changes: 28 additions & 6 deletions src/krylov_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in
(:CgLanczosShiftSolver , :cg_lanczos_shift! , :cg_lanczos_shift , args_cg_lanczos_shift , def_args_cg_lanczos_shift , (), (), kwargs_cg_lanczos_shift , def_kwargs_cg_lanczos_shift )
(:CglsLanczosShiftSolver, :cgls_lanczos_shift!, :cgls_lanczos_shift, args_cgls_lanczos_shift, def_args_cgls_lanczos_shift, (), (), kwargs_cgls_lanczos_shift, def_kwargs_cgls_lanczos_shift)
]
# window :: 5 -> lslq, lsmr, lsqr, minres, symmlq
## Out-of-place
if fun2 in (:cg_lanczos_shift, :cgls_lanczos_shift)
@eval begin
## Out-of-place
function $(fun2)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
nshifts = length(shifts)
Expand All @@ -61,7 +60,6 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in
end
elseif fun2 in (:diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr)
@eval begin
## Out-of-place
function $(fun2)($(def_args...); memory::Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $KS(A, b, memory)
Expand All @@ -85,9 +83,33 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in
end
end
end
elseif fun2 (:cg_lanczos_shift, :cgls_lanczos_shift, :lslq, :lsmr, :lsqr, :minres, :symmlq, :diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr)
elseif fun2 in (:lslq, :lsmr, :lsqr, :minres, :symmlq)
@eval begin
function $(fun2)($(def_args...); window::Int=5, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $KS(A, b; window)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
$(fun)(solver, $(args...); $(kwargs...))
solver.stats.timer += elapsed_time
return results(solver)
end

if !isempty($optargs)
function $(fun2)($(def_args...), $(def_optargs...); window::Int=5, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $KS(A, b; window)
warm_start!(solver, $(optargs...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
$(fun)(solver, $(args...); $(kwargs...))
solver.stats.timer += elapsed_time
return results(solver)
end
end
end
else
@eval begin
## Out-of-place
function $(fun2)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = $KS(A, b)
Expand All @@ -113,8 +135,8 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in
end
end

## In-place
@eval begin
## In-place
solve!(solver :: $KS{T,FC,S}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} = $(fun)(solver, $(args...); $(kwargs...))

if !isempty($optargs)
Expand Down
10 changes: 0 additions & 10 deletions src/lslq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,6 @@ args_lslq = (:A, :b)
kwargs_lslq = (:M, :N, :ldiv, :transfer_to_lsqr, :sqd, , , :etol, :utol, :btol, :conlim, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function lslq($(def_args_lslq...); window :: Int=5, $(def_kwargs_lslq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = LslqSolver(A, b; window)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
lslq!(solver, $(args_lslq...); $(kwargs_lslq...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function lslq!(solver :: LslqSolver{T,FC,S}, $(def_args_lslq...); $(def_kwargs_lslq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/lsmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,6 @@ args_lsmr = (:A, :b)
kwargs_lsmr = (:M, :N, :ldiv, :sqd, , :radius, :etol, :axtol, :btol, :conlim, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function lsmr($(def_args_lsmr...); window :: Int=5, $(def_kwargs_lsmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = LsmrSolver(A, b; window)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
lsmr!(solver, $(args_lsmr...); $(kwargs_lsmr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function lsmr!(solver :: LsmrSolver{T,FC,S}, $(def_args_lsmr...); $(def_kwargs_lsmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/lsqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,6 @@ args_lsqr = (:A, :b)
kwargs_lsqr = (:M, :N, :ldiv, :sqd, , :radius, :etol, :axtol, :btol, :conlim, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function lsqr($(def_args_lsqr...); window :: Int=5, $(def_kwargs_lsqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = LsqrSolver(A, b; window)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
lsqr!(solver, $(args_lsqr...); $(kwargs_lsqr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function lsqr!(solver :: LsqrSolver{T,FC,S}, $(def_args_lsqr...); $(def_kwargs_lsqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,6 @@ optargs_minres = (:x0,)
kwargs_minres = (:M, :ldiv, , :atol, :rtol, :etol, :conlim, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function minres($(def_args_minres...), $(def_optargs_minres...); window :: Int=5, $(def_kwargs_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = MinresSolver(A, b; window)
warm_start!(solver, $(optargs_minres...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
minres!(solver, $(args_minres...); $(kwargs_minres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function minres($(def_args_minres...); window :: Int=5, $(def_kwargs_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = MinresSolver(A, b; window)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
minres!(solver, $(args_minres...); $(kwargs_minres...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function minres!(solver :: MinresSolver{T,FC,S}, $(def_args_minres...); $(def_kwargs_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/symmlq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,6 @@ optargs_symmlq = (:x0,)
kwargs_symmlq = (:M, :ldiv, :transfer_to_cg, , :λest, :atol, :rtol, :etol, :conlim, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function symmlq($(def_args_symmlq...), $(def_optargs_symmlq...); window :: Int=5, $(def_kwargs_symmlq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = SymmlqSolver(A, b; window)
warm_start!(solver, $(optargs_symmlq...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
symmlq!(solver, $(args_symmlq...); $(kwargs_symmlq...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function symmlq($(def_args_symmlq...); window :: Int=5, $(def_kwargs_symmlq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = SymmlqSolver(A, b; window)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
symmlq!(solver, $(args_symmlq...); $(kwargs_symmlq...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function symmlq!(solver :: SymmlqSolver{T,FC,S}, $(def_args_symmlq...); $(def_kwargs_symmlq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down

0 comments on commit dca3d1f

Please sign in to comment.