Skip to content

Commit bf9eab5

Browse files
Merge branch 'master' into uprev
2 parents 3d9cb3e + 4cc0380 commit bf9eab5

File tree

6 files changed

+39
-10
lines changed

6 files changed

+39
-10
lines changed

src/alg_utils.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,11 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
285285
end
286286

287287
has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false
288-
has_autodiff(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm, OrdinaryDiffEqImplicitAlgorithm,
289-
CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm, DAEAlgorithm}) = true
288+
function has_autodiff(alg::Union{
289+
OrdinaryDiffEqAdaptiveImplicitAlgorithm, OrdinaryDiffEqImplicitAlgorithm,
290+
CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm, DAEAlgorithm})
291+
true
292+
end
290293

291294
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
292295
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)

src/caches/bdf_caches.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ function alg_cache(alg::ABDF2, u, rate_prototype, ::Type{uEltypeNoUnits},
4747
fsalfirstprev = zero(rate_prototype)
4848
atmp = similar(u, uEltypeNoUnits)
4949
recursivefill!(atmp, false)
50+
algebraic_vars = f.mass_matrix === I ? nothing :
51+
[all(iszero, x) for x in eachcol(f.mass_matrix)]
5052

51-
eulercache = ImplicitEulerCache(u, uprev, uprev2, fsalfirst, atmp, nlsolver)
53+
eulercache = ImplicitEulerCache(
54+
u, uprev, uprev2, fsalfirst, atmp, nlsolver, algebraic_vars)
5255

5356
dtₙ₋₁ = one(dt)
5457
zₙ₋₁ = zero(u)

src/caches/sdirk_caches.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
abstract type SDIRKMutableCache <: OrdinaryDiffEqMutableCache end
22

3-
@cache mutable struct ImplicitEulerCache{uType, rateType, uNoUnitsType, N} <:
3+
@cache mutable struct ImplicitEulerCache{uType, rateType, uNoUnitsType, N, AV} <:
44
SDIRKMutableCache
55
u::uType
66
uprev::uType
77
uprev2::uType
88
fsalfirst::rateType
99
atmp::uNoUnitsType
1010
nlsolver::N
11+
algebraic_vars::AV
1112
end
1213

1314
function alg_cache(alg::ImplicitEuler, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -22,7 +23,10 @@ function alg_cache(alg::ImplicitEuler, u, rate_prototype, ::Type{uEltypeNoUnits}
2223
atmp = similar(u, uEltypeNoUnits)
2324
recursivefill!(atmp, false)
2425

25-
ImplicitEulerCache(u, uprev, uprev2, fsalfirst, atmp, nlsolver)
26+
algebraic_vars = f.mass_matrix === I ? nothing :
27+
[all(iszero, x) for x in eachcol(f.mass_matrix)]
28+
29+
ImplicitEulerCache(u, uprev, uprev2, fsalfirst, atmp, nlsolver, algebraic_vars)
2630
end
2731

2832
mutable struct ImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache

src/initialize_dae.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,12 @@ end
134134
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
135135
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
136136
initializeprob = prob.f.initializeprob
137-
137+
138138
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
139139
# Since then it's the case of not a DAE but has initializeprob
140140
# In which case, it should be differentiable
141-
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff : true
141+
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff :
142+
true
142143

143144
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
144145
nlsol = solve(initializeprob, alg)

src/perform_step/sdirk_perform_step.jl

+13
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ end
102102
end
103103

104104
integrator.fsallast = f(u, p, t + dt)
105+
106+
if integrator.opts.adaptive && integrator.f.mass_matrix !== I
107+
atmp = @. ifelse(!integrator.differential_vars, integrator.fsallast, false) ./
108+
integrator.opts.abstol
109+
integrator.EEst += integrator.opts.internalnorm(atmp, t)
110+
end
111+
105112
integrator.stats.nf += 1
106113
integrator.k[1] = integrator.fsalfirst
107114
integrator.k[2] = integrator.fsallast
@@ -152,6 +159,12 @@ end
152159
end
153160
integrator.stats.nf += 1
154161
f(integrator.fsallast, u, p, t + dt)
162+
163+
if integrator.opts.adaptive && integrator.f.mass_matrix !== I
164+
@.. broadcast=false atmp=ifelse(cache.algebraic_vars, integrator.fsallast, false) /
165+
integrator.opts.abstol
166+
integrator.EEst += integrator.opts.internalnorm(atmp, t)
167+
end
155168
end
156169

157170
@muladd function perform_step!(integrator, cache::ImplicitMidpointConstantCache,

test/interface/dae_initialize_integration.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,22 @@ sol = solve(prob, Rodas5(), initializealg = ShampineCollocationInit())
4444
# Initialize on ODEs
4545
# https://github.com/SciML/ModelingToolkit.jl/issues/2508
4646

47-
function testsys(du,u,p,t)
47+
function testsys(du, u, p, t)
4848
du[1] = -2
4949
end
50-
function initsys(du,u,p)
50+
function initsys(du, u, p)
5151
du[1] = -1 + u[1]
5252
end
5353
nlprob = NonlinearProblem(initsys, [0.0])
5454
initprobmap(nlprob) = nlprob.u
5555
sol = solve(nlprob)
5656

5757
_f = ODEFunction(testsys; initializeprob = nlprob, initializeprobmap = initprobmap)
58+
prob = ODEProblem(_f, [0.0], (0.0, 1.0))
59+
sol = solve(prob, Tsit5())
60+
@test SciMLBase.successful_retcode(sol)
61+
@test sol[1] == [1.0]
62+
5863
prob = ODEProblem(_f, [0.0], (0.0,1.0))
5964
sol = solve(prob, Tsit5(), dt = 1e-10)
6065
@test SciMLBase.successful_retcode(sol)
@@ -66,4 +71,4 @@ sol = solve(prob, Rodas5P(), dt = 1e-10)
6671
@test SciMLBase.successful_retcode(sol)
6772
@test sol[1] == [1.0]
6873
@test sol[2] [0.9999999998]
69-
@test sol[end] [-1.0]
74+
@test sol[end] [-1.0]

0 commit comments

Comments
 (0)