Skip to content

Commit

Permalink
Initialize and finalize for discrete callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
BenChung committed Oct 29, 2024
1 parent e9ec43c commit eb14e35
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 25 deletions.
88 changes: 63 additions & 25 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,16 @@ struct SymbolicDiscreteCallback
# TODO: Iterative
condition::Any
affects::Any
initialize::Any
finalize::Any
reinitializealg::SciMLBase.DAEInitializationAlgorithm

function SymbolicDiscreteCallback(
condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit())
condition, affects = NULL_AFFECT; reinitializealg = SciMLBase.CheckInit(),
initialize=NULL_AFFECT, finalize=NULL_AFFECT)
c = scalarize_condition(condition)
a = scalarize_affects(affects)
new(c, a, reinitializealg)
new(c, a, scalarize_affects(initialize), scalarize_affects(finalize), reinitializealg)
end # Default affect to nothing
end

Expand Down Expand Up @@ -373,11 +376,16 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback)
end

function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback)
isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects)
isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) &&
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)
end
function Base.hash(cb::SymbolicDiscreteCallback, s::UInt)
s = hash(cb.condition, s)
cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s)
s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s)
s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) : hash(cb.initialize, s)
s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) : hash(cb.finalize, s)
s = hash(cb.reinitializealg, s)
return s
end

condition(cb::SymbolicDiscreteCallback) = cb.condition
Expand All @@ -397,10 +405,23 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
end


initialize_affects(cb::SymbolicDiscreteCallback) = cb.initialize
function initialize_affects(cbs::Vector{SymbolicDiscreteCallback})
mapreduce(initialize_affects, vcat, cbs, init = Equation[])
end

finalize_affects(cb::SymbolicDiscreteCallback) = cb.finalize
function finalize_affects(cbs::Vector{SymbolicDiscreteCallback})
mapreduce(finalize_affects, vcat, cbs, init = Equation[])
end

function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
af = affects(cb)
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
SymbolicDiscreteCallback(namespace_condition(condition(cb), s), af)
function namespace_affects(af)
return af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
end
SymbolicDiscreteCallback(namespace_condition(condition(cb), s), namespace_affects(affects(cb)),
reinitializealg=cb.reinitializealg, initialize=namespace_affects(initialize_affects(cb)), finalize=namespace_affects(finalize_affects(cb)))
end

SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)]
Expand Down Expand Up @@ -773,10 +794,10 @@ function generate_vector_rootfinding_callback(
let save_idxs = save_idxs
if !isnothing(fn.initialize)
(i) -> begin
fn.initialize(i)
for idx in save_idxs
SciMLBase.save_discretes!(i, idx)
end
fn.initialize(i)
end
else
(i) -> begin
Expand Down Expand Up @@ -809,20 +830,13 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
eq_aff = affects(cb)
eq_neg_aff = affect_negs(cb)
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
function compile_optional_affect(aff, default = nothing)
if isnothing(aff) || aff == default
return nothing
else
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
end
end
if eq_neg_aff === eq_aff
affect_neg = affect
else
affect_neg = compile_optional_affect(eq_neg_aff)
affect_neg = _compile_optional_affect(NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...)
end
initialize = compile_optional_affect(initialize_affects(cb), NULL_AFFECT)
finalize = compile_optional_affect(finalize_affects(cb), NULL_AFFECT)
initialize = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
finalize = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
end

Expand Down Expand Up @@ -914,31 +928,48 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
end


function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...)
if isnothing(aff) || aff == default
return nothing
else
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
end
end
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
kwargs...)
cond = condition(cb)
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
postprocess_affect_expr!, kwargs...)

user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
initfn = let save_idxs = save_idxs
initfn = let
save_idxs = save_idxs
initfun=user_initfun
function (cb, u, t, integrator)
if !isnothing(initfun)
initfun(integrator)
end
for idx in save_idxs
SciMLBase.save_discretes!(integrator, idx)
end
end
end
else
initfn = SciMLBase.INITIALIZE_DEFAULT
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i)
end
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i)
if cond isa AbstractVector
# Preset Time
return PresetTimeCallback(
cond, as; initialize = initfn, initializealg = reinitialization_alg(cb))
cond, as; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb))
else
# Periodic
return PeriodicCallback(
as, cond; initialize = initfn, initializealg = reinitialization_alg(cb))
as, cond; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb))
end
end

Expand All @@ -951,20 +982,27 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
postprocess_affect_expr!, kwargs...)

user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
initfn = let save_idxs = save_idxs
initfn = let save_idxs = save_idxs, initfun = user_initfun
function (cb, u, t, integrator)
if !isnothing(initfun)
initfun(integrator)
end
for idx in save_idxs
SciMLBase.save_discretes!(integrator, idx)
end
end
end
else
initfn = SciMLBase.INITIALIZE_DEFAULT
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i)
end
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i)
return DiscreteCallback(
c, as; initialize = initfn, initializealg = reinitialization_alg(cb))
c, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg(cb))
end
end

Expand Down
48 changes: 48 additions & 0 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,54 @@ end
@test seen == true
@test inited == true
@test finaled == true

#periodic
inited = false
finaled = false
cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, [x ~ 2], initialize=a, finalize=b)
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
sol = solve(prob, Tsit5())
@test inited == true
@test finaled == true
@test isapprox(sol[x][3], 0.0, atol=1e-9)
@test sol[x][4] 2.0
@test sol[x][5] 1.0


seen = false
inited = false
finaled = false
cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, f, initialize=a, finalize=b)
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
sol = solve(prob, Tsit5())
@test seen == true
@test inited == true

#preset
seen = false
inited = false
finaled = false
cb3 = ModelingToolkit.SymbolicDiscreteCallback([1.0], f, initialize=a, finalize=b)
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
sol = solve(prob, Tsit5())
@test seen == true
@test inited == true
@test finaled == true

#equational
seen = false
inited = false
finaled = false
cb3 = ModelingToolkit.SymbolicDiscreteCallback(t == 1.0, f, initialize=a, finalize=b)
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
sol = solve(prob, Tsit5(); tstops=1.0)
@test seen == true
@test inited == true
@test finaled == true
end

@testset "Bump" begin
Expand Down

0 comments on commit eb14e35

Please sign in to comment.