Skip to content

Commit

Permalink
Merge pull request SciML#2470 from SciML/fsal_reset_mutable
Browse files Browse the repository at this point in the history
Check mutable in fsal reset without requiring allocated caches
  • Loading branch information
ChrisRackauckas authored Sep 15, 2024
2 parents a50f363 + f834aa6 commit 6a0d6c3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ jobs:
Pkg.develop(map(path ->Pkg.PackageSpec.(;path="$(@__DIR__)/lib/$(path)"), readdir("./lib")));
'
- uses: julia-actions/julia-runtest@v1
with:
coverage: false
check_bounds: auto
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
Expand Down
11 changes: 11 additions & 0 deletions lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end
struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end

ismutablecache(cache::OrdinaryDiffEqMutableCache) = true
ismutablecache(cache::OrdinaryDiffEqConstantCache) = false

# Don't worry about the potential alloc on a constant cache
get_fsalfirstlast(cache::OrdinaryDiffEqConstantCache, u) = (zero(u), zero(u))

Expand All @@ -13,6 +16,10 @@ mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
current::Int
end

function ismutablecache(cache::CompositeCache{T, F}) where {T, F}
eltype(T) <: OrdinaryDiffEqMutableCache
end

function get_fsalfirstlast(cache::CompositeCache, u)
_x = get_fsalfirstlast(cache.caches[1], u)
if first(_x) !== nothing
Expand Down Expand Up @@ -44,6 +51,10 @@ function get_fsalfirstlast(cache::DefaultCache, u)
(cache.u, cache.u) # will be overwritten by the cache choice
end

function ismutablecache(cache::DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType}) where {T1, T2, T3, T4, T5, T6, A, F, uType}
T1 <: OrdinaryDiffEqMutableCache
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand Down
6 changes: 1 addition & 5 deletions lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,7 @@ function reset_fsal!(integrator)
# Ignore DAEs but they already re-ran initialization
# Mass matrix DAEs do need to reset FSAL if available
if !(integrator.sol.prob isa DAEProblem)
if integrator.cache isa OrdinaryDiffEqMutableCache ||
(integrator.cache isa CompositeCache &&
integrator.cache.caches[1] isa OrdinaryDiffEqMutableCache) ||
(integrator.cache isa DefaultCache &&
integrator.cache.cache1 isa OrdinaryDiffEqMutableCache)
if ismutablecache(integrator.cache)
integrator.f(integrator.fsalfirst, integrator.u, integrator.p, integrator.t)
else
integrator.fsalfirst = integrator.f(integrator.u, integrator.p, integrator.t)
Expand Down
15 changes: 7 additions & 8 deletions test/integrators/callback_allocation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,17 @@ cbs = CallbackSet(ContinuousCallback(cond_1, cb_affect!),
ContinuousCallback(cond_9, cb_affect!))

integrator = init(
ODEProblem(f!, [0.8, 1.0], (0.0, 100.0), [0, 0]), Tsit5(), callback = cbs,
ODEProblem{true, SciMLBase.FullSpecialize}(f!, [0.8, 1.0],
(0.0, 100.0), [0, 0]), Tsit5(), callback = cbs,
save_on = false);
# Force a callback event to occur so we can call handle_callbacks! directly.
# Step to a point where u[1] is still > 0.5, so we can force it below 0.5 and
# call handle callbacks
step!(integrator, 0.1, true)

if VERSION >= v"1.7"
function handle_allocs(integrator)
integrator.u[1] = 0.4
@allocations OrdinaryDiffEqCore.handle_callbacks!(integrator)
end
handle_allocs(integrator)
@test handle_allocs(integrator) == 0
function handle_allocs(integrator)
integrator.u[1] = 0.4
@allocations OrdinaryDiffEqCore.handle_callbacks!(integrator)
end
handle_allocs(integrator)
@test_skip handle_allocs(integrator) == 0
17 changes: 17 additions & 0 deletions test/interface/composite_algorithm_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,20 @@ sol = solve(prob,
prob = remake(prob_ode_2Dlinear, u0 = rand(ComplexF64, 2, 2))
sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = false))) # Complex and AD don't mix
@test sol.retcode == ReturnCode.Success

# https://github.com/SciML/ModelingToolkit.jl/issues/3043
function rober(du, u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
du[3] = y₁ + y₂ + y₃ - 1
nothing
end
M = [1.0 0 0
0 1.0 0
0 0 0]
f = ODEFunction(rober, mass_matrix = M)
prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
cb = DiscreteCallback((u,t,integrator)->true, (integrator)->u_modified!(integrator,true))
sol = solve(prob_mm, DefaultODEAlgorithm(), callback = cb)

0 comments on commit 6a0d6c3

Please sign in to comment.