diff --git a/src/cg_lanczos_shift.jl b/src/cg_lanczos_shift.jl index 43cebeeb8..f857bd2c1 100644 --- a/src/cg_lanczos_shift.jl +++ b/src/cg_lanczos_shift.jl @@ -94,17 +94,6 @@ args_cg_lanczos_shift = (:A, :b, :shifts) kwargs_cg_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) @eval begin - function cg_lanczos_shift($(def_args_cg_lanczos_shift...); $(def_kwargs_cg_lanczos_shift...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} - start_time = time_ns() - nshifts = length(shifts) - solver = CgLanczosShiftSolver(A, b, nshifts) - elapsed_time = ktimer(start_time) - timemax -= elapsed_time - cg_lanczos_shift!(solver, $(args_cg_lanczos_shift...); $(kwargs_cg_lanczos_shift...)) - solver.stats.timer += elapsed_time - return (solver.x, solver.stats) - end - function cg_lanczos_shift!(solver :: CgLanczosShiftSolver{T,FC,S}, $(def_args_cg_lanczos_shift...); $(def_kwargs_cg_lanczos_shift...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} # Timer diff --git a/src/krylov_solve.jl b/src/krylov_solve.jl index f3fed1c41..1f6d54830 100644 --- a/src/krylov_solve.jl +++ b/src/krylov_solve.jl @@ -44,9 +44,22 @@ for (KS, fun, fun2, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in (:GpmrSolver , :gpmr! , :gpmr , args_gpmr , def_args_gpmr , optargs_gpmr , def_optargs_gpmr , kwargs_gpmr , def_kwargs_gpmr ) ] # window :: 5 -> lslq, lsmr, lsqr, minres, symmlq - # shifts -> CgLanczosShiftSolver(A, b, nshifts) # memory :: 20 -> diom, dqgmres, fom, gmres, fgmres, gpmr - if fun2 ∉ (:lslq, :lsmr, :lsqr, :minres, :symmlq, :cg_lanczos_shift, :diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr) + if fun2 == :cg_lanczos_shift + @eval begin + ## Out-of-place + function $(fun2)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + nshifts = length(shifts) + solver = $KS(A, b, nshifts) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + $(fun)(solver, $(args...); $(kwargs...)) + solver.stats.timer += elapsed_time + return results(solver) + end + end + elseif fun ∉ (:lslq, :lsmr, :lsqr, :minres, :symmlq, :diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr) @eval begin ## Out-of-place function $(fun2)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}