diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 2bd131675f..33f7f0b02b 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -234,33 +234,6 @@ function DiffEqBase.prepare_alg(alg::Union{ OrdinaryDiffEqExponentialAlgorithm{0, AD, FDT}}, u0::AbstractArray{T}, p, prob) where {AD, FDT, T} - if alg isa OrdinaryDiffEqExponentialAlgorithm - linsolve = nothing - elseif alg.linsolve === nothing - if (prob.f isa ODEFunction && prob.f.f isa AbstractSciMLOperator) - linsolve = LinearSolve.defaultalg(prob.f.f, u0) - elseif (prob.f isa SplitFunction && - prob.f.f1.f isa AbstractSciMLOperator) - linsolve = LinearSolve.defaultalg(prob.f.f1.f, u0) - if (linsolve === nothing) || (linsolve isa LinearSolve.DefaultLinearSolver && - linsolve.alg !== LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) - msg = "Split ODE problem do not work with factorization linear solvers. Bug detailed in https://github.com/SciML/OrdinaryDiffEq.jl/pull/1643. Defaulting to linsolve=KrylovJL()" - @warn msg - linsolve = KrylovJL_GMRES() - end - elseif (prob isa ODEProblem || prob isa DDEProblem) && - (prob.f.mass_matrix === nothing || - (prob.f.mass_matrix !== nothing && - !(prob.f.jac_prototype isa AbstractSciMLOperator))) - linsolve = LinearSolve.defaultalg(prob.f.jac_prototype, u0) - else - # If mm is a sparse matrix and A is a MatrixOperator, then let linear - # solver choose things later - linsolve = nothing - end - else - linsolve = alg.linsolve - end # If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already) # don't use a large chunksize as it will either error or not be beneficial @@ -268,16 +241,11 @@ function DiffEqBase.prepare_alg(alg::Union{ (isbitstype(T) && sizeof(T) > 24) || (prob.f isa ODEFunction && prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) - if alg isa OrdinaryDiffEqExponentialAlgorithm - return remake(alg, chunk_size = Val{1}()) - else - return remake(alg, chunk_size = Val{1}(), linsolve = linsolve) - end + return remake(alg, chunk_size = Val{1}()) end L = StaticArrayInterface.known_length(typeof(u0)) if L === nothing # dynamic sized - # If chunksize is zero, pick chunksize right at the start of solve and # then do function barrier to infer the full solve x = if prob.f.colorvec === nothing @@ -287,19 +255,10 @@ function DiffEqBase.prepare_alg(alg::Union{ end cs = ForwardDiff.pickchunksize(x) - - if alg isa OrdinaryDiffEqExponentialAlgorithm - return remake(alg, chunk_size = Val{cs}()) - else - return remake(alg, chunk_size = Val{cs}(), linsolve = linsolve) - end + return remake(alg, chunk_size = Val{cs}()) else # statically sized cs = pick_static_chunksize(Val{L}()) - if alg isa OrdinaryDiffEqExponentialAlgorithm - return remake(alg, chunk_size = cs) - else - return remake(alg, chunk_size = cs, linsolve = linsolve) - end + return remake(alg, chunk_size = cs) end end diff --git a/src/algorithms.jl b/src/algorithms.jl index a8959e0606..4d8d0a7009 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -63,13 +63,11 @@ function DiffEqBase.remake(thing::Union{ ST, CJ}, OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ }, - DAEAlgorithm{CS, AD, FDT, ST, CJ}}; - linsolve, kwargs...) where {CS, AD, FDT, ST, CJ} + DAEAlgorithm{CS, AD, FDT, ST, CJ}}; kwargs...) where {CS, AD, FDT, ST, CJ} T = SciMLBase.remaker_of(thing) T(; SciMLBase.struct_as_namedtuple(thing)..., chunk_size = Val{CS}(), autodiff = Val{AD}(), standardtag = Val{ST}(), concrete_jac = CJ === nothing ? CJ : Val{CJ}(), - linsolve = linsolve, kwargs...) end diff --git a/test/interface/linear_solver_split_ode_test.jl b/test/interface/linear_solver_split_ode_test.jl index a920ba3341..f4e7403446 100644 --- a/test/interface/linear_solver_split_ode_test.jl +++ b/test/interface/linear_solver_split_ode_test.jl @@ -5,15 +5,15 @@ using LinearAlgebra, LinearSolve import OrdinaryDiffEq.dolinsolve n = 8 -dt = 1 / 16 +dt = 1 / 1000 u0 = ones(n) tspan = (0.0, 1.0) M1 = 2ones(n) |> Diagonal #|> Array M2 = 2ones(n) |> Diagonal #|> Array -f1 = M1 |> MatrixOperator -f2 = M2 |> MatrixOperator +f1 = (du,u,p,t) -> du .= M1 * u +f2 = (du,u,p,t) -> du .= M2 * u prob = SplitODEProblem(f1, f2, u0, tspan) for algname in (:SBDF2, @@ -21,72 +21,32 @@ for algname in (:SBDF2, :KenCarp47) @testset "$algname" begin alg0 = @eval $algname() - alg1 = @eval $algname(linsolve = GenericFactorization()) + alg1 = @eval $algname(linsolve = LUFactorization()) kwargs = (dt = dt,) - # expected error message - msg = "Split ODE problem do not work with factorization linear solvers. Bug detailed in https://github.com/SciML/OrdinaryDiffEq.jl/pull/1643. Defaulting to linsolve=KrylovJL()" - @test_logs (:warn, msg) solve(prob, alg0; kwargs...) + solve(prob, alg0; kwargs...) @test DiffEqBase.__solve(prob, alg0; kwargs...).retcode == ReturnCode.Success - @test_broken DiffEqBase.__solve(prob, alg1; kwargs...).retcode == ReturnCode.Success + @test DiffEqBase.__solve(prob, alg1; kwargs...).retcode == ReturnCode.Success end end -##### -# deep dive -##### - -alg0 = KenCarp47() # passing case -alg1 = KenCarp47(linsolve = GenericFactorization()) # failing case - -## objects -ig0 = SciMLBase.init(prob, alg0; dt = dt) -ig1 = SciMLBase.init(prob, alg1; dt = dt) - -nl0 = ig0.cache.nlsolver -nl1 = ig1.cache.nlsolver - -lc0 = nl0.cache.linsolve -lc1 = nl1.cache.linsolve - -W0 = lc0.A -W1 = lc1.A - -# perform first step -OrdinaryDiffEq.loopheader!(ig0) -OrdinaryDiffEq.loopheader!(ig1) - -OrdinaryDiffEq.perform_step!(ig0, ig0.cache) -OrdinaryDiffEq.perform_step!(ig1, ig1.cache) - -@test !OrdinaryDiffEq.nlsolvefail(nl0) -@test OrdinaryDiffEq.nlsolvefail(nl1) - -# check operators -@test W0._concrete_form != W1._concrete_form -@test_broken W0._func_cache == W1._func_cache - -# check operator application -b = ones(n) -@test W0 * b == W1 * b -@test mul!(rand(n), W0, b) == mul!(rand(n), W1, b) -#@test W0 \ b == W1 \ b - -# check linear solve -lc0.b .= 1.0 -lc1.b .= 1.0 - -solve(lc0) -solve(lc1) +f1 = M1 |> MatrixOperator +f2 = M2 |> MatrixOperator +prob = SplitODEProblem(f1, f2, u0, tspan) -@test_broken lc0.u == lc1.u +for algname in (:SBDF2, + :SBDF3, + :KenCarp47) + @testset "$algname" begin + alg0 = @eval $algname() -# solve contried problem using OrdinaryDiffEq machinery -linres0 = dolinsolve(ig0, lc0; A = W0, b = b, linu = ones(n), reltol = 1e-8) -linres1 = dolinsolve(ig1, lc1; A = W1, b = b, linu = ones(n), reltol = 1e-8) + kwargs = (dt = dt,) -@test_broken linres0 == linres1 + solve(prob, alg0; kwargs...) + @test DiffEqBase.__solve(prob, alg0; kwargs...).retcode == ReturnCode.Success + end +end ### # custom linsolve function @@ -101,6 +61,4 @@ end alg = KenCarp47(linsolve = LinearSolveFunction(linsolve)) -@test solve(prob, alg).retcode == ReturnCode.Success - -nothing +@test solve(prob, alg).retcode == ReturnCode.Success \ No newline at end of file