Skip to content

Commit

Permalink
Remove early defaulting and fix factorization algs
Browse files Browse the repository at this point in the history
No longer default to GMRES on SplitODEProblem when it's not an operator equation. LUFactorization works fine. GenericFactorization has stronger assumptions than is required.

This removes the early defaulting since `linsolve=nothing` is type-inferrable for defaultalg since it's a single algorithm, and thus doing it early has no benefit but significant drawbacks. One drawback case is Radau since it picks the default based on real numbers but requires complex numbers. This is also the reason for the aforementioned factorization issue on SplitODEProblem, it was simply choosing wrong.
  • Loading branch information
ChrisRackauckas committed Nov 5, 2023
1 parent 8f5474a commit a8c5f45
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 109 deletions.
47 changes: 3 additions & 44 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,50 +234,18 @@ function DiffEqBase.prepare_alg(alg::Union{
OrdinaryDiffEqExponentialAlgorithm{0, AD, FDT}},
u0::AbstractArray{T},
p, prob) where {AD, FDT, T}
if alg isa OrdinaryDiffEqExponentialAlgorithm
linsolve = nothing
elseif alg.linsolve === nothing
if (prob.f isa ODEFunction && prob.f.f isa AbstractSciMLOperator)
linsolve = LinearSolve.defaultalg(prob.f.f, u0)
elseif (prob.f isa SplitFunction &&
prob.f.f1.f isa AbstractSciMLOperator)
linsolve = LinearSolve.defaultalg(prob.f.f1.f, u0)
if (linsolve === nothing) || (linsolve isa LinearSolve.DefaultLinearSolver &&
linsolve.alg !== LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
msg = "Split ODE problem do not work with factorization linear solvers. Bug detailed in https://github.com/SciML/OrdinaryDiffEq.jl/pull/1643. Defaulting to linsolve=KrylovJL()"
@warn msg
linsolve = KrylovJL_GMRES()
end
elseif (prob isa ODEProblem || prob isa DDEProblem) &&
(prob.f.mass_matrix === nothing ||
(prob.f.mass_matrix !== nothing &&
!(prob.f.jac_prototype isa AbstractSciMLOperator)))
linsolve = LinearSolve.defaultalg(prob.f.jac_prototype, u0)
else
# If mm is a sparse matrix and A is a MatrixOperator, then let linear
# solver choose things later
linsolve = nothing
end
else
linsolve = alg.linsolve
end

# If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already)
# don't use a large chunksize as it will either error or not be beneficial
if !(alg_autodiff(alg) isa AutoForwardDiff) ||
(isbitstype(T) && sizeof(T) > 24) ||
(prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if alg isa OrdinaryDiffEqExponentialAlgorithm
return remake(alg, chunk_size = Val{1}())
else
return remake(alg, chunk_size = Val{1}(), linsolve = linsolve)
end
return remake(alg, chunk_size = Val{1}())
end

L = StaticArrayInterface.known_length(typeof(u0))
if L === nothing # dynamic sized

# If chunksize is zero, pick chunksize right at the start of solve and
# then do function barrier to infer the full solve
x = if prob.f.colorvec === nothing
Expand All @@ -287,19 +255,10 @@ function DiffEqBase.prepare_alg(alg::Union{
end

cs = ForwardDiff.pickchunksize(x)

if alg isa OrdinaryDiffEqExponentialAlgorithm
return remake(alg, chunk_size = Val{cs}())
else
return remake(alg, chunk_size = Val{cs}(), linsolve = linsolve)
end
return remake(alg, chunk_size = Val{cs}())
else # statically sized
cs = pick_static_chunksize(Val{L}())
if alg isa OrdinaryDiffEqExponentialAlgorithm
return remake(alg, chunk_size = cs)
else
return remake(alg, chunk_size = cs, linsolve = linsolve)
end
return remake(alg, chunk_size = cs)
end
end

Expand Down
4 changes: 1 addition & 3 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,11 @@ function DiffEqBase.remake(thing::Union{
ST, CJ},
OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ
},
DAEAlgorithm{CS, AD, FDT, ST, CJ}};
linsolve, kwargs...) where {CS, AD, FDT, ST, CJ}
DAEAlgorithm{CS, AD, FDT, ST, CJ}}; kwargs...) where {CS, AD, FDT, ST, CJ}
T = SciMLBase.remaker_of(thing)
T(; SciMLBase.struct_as_namedtuple(thing)...,
chunk_size = Val{CS}(), autodiff = Val{AD}(), standardtag = Val{ST}(),
concrete_jac = CJ === nothing ? CJ : Val{CJ}(),
linsolve = linsolve,
kwargs...)
end

Expand Down
82 changes: 20 additions & 62 deletions test/interface/linear_solver_split_ode_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,48 @@ using LinearAlgebra, LinearSolve
import OrdinaryDiffEq.dolinsolve

n = 8
dt = 1 / 16
dt = 1 / 1000
u0 = ones(n)
tspan = (0.0, 1.0)

M1 = 2ones(n) |> Diagonal #|> Array
M2 = 2ones(n) |> Diagonal #|> Array

f1 = M1 |> MatrixOperator
f2 = M2 |> MatrixOperator
f1 = (du,u,p,t) -> du .= M1 * u
f2 = (du,u,p,t) -> du .= M2 * u
prob = SplitODEProblem(f1, f2, u0, tspan)

for algname in (:SBDF2,
:SBDF3,
:KenCarp47)
@testset "$algname" begin
alg0 = @eval $algname()
alg1 = @eval $algname(linsolve = GenericFactorization())
alg1 = @eval $algname(linsolve = LUFactorization())

kwargs = (dt = dt,)

# expected error message
msg = "Split ODE problem do not work with factorization linear solvers. Bug detailed in https://github.com/SciML/OrdinaryDiffEq.jl/pull/1643. Defaulting to linsolve=KrylovJL()"
@test_logs (:warn, msg) solve(prob, alg0; kwargs...)
solve(prob, alg0; kwargs...)
@test DiffEqBase.__solve(prob, alg0; kwargs...).retcode == ReturnCode.Success
@test_broken DiffEqBase.__solve(prob, alg1; kwargs...).retcode == ReturnCode.Success
@test DiffEqBase.__solve(prob, alg1; kwargs...).retcode == ReturnCode.Success
end
end

#####
# deep dive
#####

alg0 = KenCarp47() # passing case
alg1 = KenCarp47(linsolve = GenericFactorization()) # failing case

## objects
ig0 = SciMLBase.init(prob, alg0; dt = dt)
ig1 = SciMLBase.init(prob, alg1; dt = dt)

nl0 = ig0.cache.nlsolver
nl1 = ig1.cache.nlsolver

lc0 = nl0.cache.linsolve
lc1 = nl1.cache.linsolve

W0 = lc0.A
W1 = lc1.A

# perform first step
OrdinaryDiffEq.loopheader!(ig0)
OrdinaryDiffEq.loopheader!(ig1)

OrdinaryDiffEq.perform_step!(ig0, ig0.cache)
OrdinaryDiffEq.perform_step!(ig1, ig1.cache)

@test !OrdinaryDiffEq.nlsolvefail(nl0)
@test OrdinaryDiffEq.nlsolvefail(nl1)

# check operators
@test W0._concrete_form != W1._concrete_form
@test_broken W0._func_cache == W1._func_cache

# check operator application
b = ones(n)
@test W0 * b == W1 * b
@test mul!(rand(n), W0, b) == mul!(rand(n), W1, b)
#@test W0 \ b == W1 \ b

# check linear solve
lc0.b .= 1.0
lc1.b .= 1.0

solve(lc0)
solve(lc1)
f1 = M1 |> MatrixOperator
f2 = M2 |> MatrixOperator
prob = SplitODEProblem(f1, f2, u0, tspan)

@test_broken lc0.u == lc1.u
for algname in (:SBDF2,
:SBDF3,
:KenCarp47)
@testset "$algname" begin
alg0 = @eval $algname()

# solve contried problem using OrdinaryDiffEq machinery
linres0 = dolinsolve(ig0, lc0; A = W0, b = b, linu = ones(n), reltol = 1e-8)
linres1 = dolinsolve(ig1, lc1; A = W1, b = b, linu = ones(n), reltol = 1e-8)
kwargs = (dt = dt,)

@test_broken linres0 == linres1
solve(prob, alg0; kwargs...)
@test DiffEqBase.__solve(prob, alg0; kwargs...).retcode == ReturnCode.Success
end
end

###
# custom linsolve function
Expand All @@ -101,6 +61,4 @@ end

alg = KenCarp47(linsolve = LinearSolveFunction(linsolve))

@test solve(prob, alg).retcode == ReturnCode.Success

nothing
@test solve(prob, alg).retcode == ReturnCode.Success

0 comments on commit a8c5f45

Please sign in to comment.