diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c427ae02a1..0b758a6e86 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -111,9 +111,14 @@ DAEs will be reinitialized using `reinitializealg` (which defaults to `SciMLBase This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of initialization. + +Initial and final affects can also be specified with SCC, which are specified identically to positive and negative edge affects. Initialization affects +will run as soon as the solver starts, while finalization affects will be executed after termination. """ struct SymbolicContinuousCallback eqs::Vector{Equation} + initialize::Union{Vector{Equation}, FunctionalAffect} + finalize::Union{Vector{Equation}, FunctionalAffect} affect::Union{Vector{Equation}, FunctionalAffect} affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing} rootfind::SciMLBase.RootfindOpt @@ -122,9 +127,12 @@ struct SymbolicContinuousCallback eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, + initialize = NULL_AFFECT, + finalize = NULL_AFFECT, rootfind = SciMLBase.LeftRootFind, reinitializealg = SciMLBase.CheckInit()) - new(eqs, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg) + new(eqs, initialize, finalize, make_affect(affect), + make_affect(affect_neg), rootfind, reinitializealg) end # Default affect to nothing end make_affect(affect) = affect @@ -133,17 +141,80 @@ make_affect(affect::NamedTuple) = FunctionalAffect(; affect...) function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && + isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind) end Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs) function Base.hash(cb::SymbolicContinuousCallback, s::UInt) + hash_affect(affect::AbstractVector, s) = foldr(hash, affect, init = s) + hash_affect(affect, s) = hash(affect, s) s = foldr(hash, cb.eqs, init = s) - s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s) - s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) : - hash(cb.affect_neg, s) + s = hash_affect(cb.affect, s) + s = hash_affect(cb.affect_neg, s) + s = hash_affect(cb.initialize, s) + s = hash_affect(cb.finalize, s) + s = hash(cb.reinitializealg, s) hash(cb.rootfind, s) end +function Base.show(io::IO, cb::SymbolicContinuousCallback) + indent = get(io, :indent, 0) + iio = IOContext(io, :indent => indent + 1) + print(io, "SymbolicContinuousCallback(") + print(iio, "Equations:") + show(iio, equations(cb)) + print(iio, "; ") + if affects(cb) != NULL_AFFECT + print(iio, "Affect:") + show(iio, affects(cb)) + print(iio, ", ") + end + if affect_negs(cb) != NULL_AFFECT + print(iio, "Negative-edge affect:") + show(iio, affect_negs(cb)) + print(iio, ", ") + end + if initialize_affects(cb) != NULL_AFFECT + print(iio, "Initialization affect:") + show(iio, initialize_affects(cb)) + print(iio, ", ") + end + if finalize_affects(cb) != NULL_AFFECT + print(iio, "Finalization affect:") + show(iio, finalize_affects(cb)) + end + print(iio, ")") +end + +function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback) + indent = get(io, :indent, 0) + iio = IOContext(io, :indent => indent + 1) + println(io, "SymbolicContinuousCallback:") + println(iio, "Equations:") + show(iio, mime, equations(cb)) + print(iio, "\n") + if affects(cb) != NULL_AFFECT + println(iio, "Affect:") + show(iio, mime, affects(cb)) + print(iio, "\n") + end + if affect_negs(cb) != NULL_AFFECT + println(iio, "Negative-edge affect:") + show(iio, mime, affect_negs(cb)) + print(iio, "\n") + end + if initialize_affects(cb) != NULL_AFFECT + println(iio, "Initialization affect:") + show(iio, mime, initialize_affects(cb)) + print(iio, "\n") + end + if finalize_affects(cb) != NULL_AFFECT + println(iio, "Finalization affect:") + show(iio, mime, finalize_affects(cb)) + print(iio, "\n") + end +end + to_equation_vector(eq::Equation) = [eq] to_equation_vector(eqs::Vector{Equation}) = eqs function to_equation_vector(eqs::Vector{Any}) @@ -157,14 +228,18 @@ end # wrap eq in vector SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; + initialize = NULL_AFFECT, finalize = NULL_AFFECT, affect_neg = affect, rootfind = SciMLBase.LeftRootFind) SymbolicContinuousCallback( - eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind) + eqs = [eqs], affect = affect, affect_neg = affect_neg, + initialize = initialize, finalize = finalize, rootfind = rootfind) end function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; - affect_neg = affect, rootfind = SciMLBase.LeftRootFind) + affect_neg = affect, initialize = NULL_AFFECT, finalize = NULL_AFFECT, + rootfind = SciMLBase.LeftRootFind) SymbolicContinuousCallback( - eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind) + eqs = eqs, affect = affect, affect_neg = affect_neg, + initialize = initialize, finalize = finalize, rootfind = rootfind) end SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] @@ -199,15 +274,28 @@ function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) end +initialize_affects(cb::SymbolicContinuousCallback) = cb.initialize +function initialize_affects(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(initialize_affects, vcat, cbs, init = Equation[]) +end + +finalize_affects(cb::SymbolicContinuousCallback) = cb.finalize +function finalize_affects(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(finalize_affects, vcat, cbs, init = Equation[]) +end + namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) namespace_affects(::Nothing, s) = nothing function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback - SymbolicContinuousCallback( - namespace_equation.(equations(cb), (s,)), - namespace_affects(affects(cb), s); - affect_neg = namespace_affects(affect_negs(cb), s)) + SymbolicContinuousCallback(; + eqs = namespace_equation.(equations(cb), (s,)), + affect = namespace_affects(affects(cb), s), + affect_neg = namespace_affects(affect_negs(cb), s), + initialize = namespace_affects(initialize_affects(cb), s), + finalize = namespace_affects(finalize_affects(cb), s), + rootfind = cb.rootfind) end """ @@ -241,13 +329,17 @@ struct SymbolicDiscreteCallback # TODO: Iterative condition::Any affects::Any + initialize::Any + finalize::Any reinitializealg::SciMLBase.DAEInitializationAlgorithm function SymbolicDiscreteCallback( - condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit()) + condition, affects = NULL_AFFECT; reinitializealg = SciMLBase.CheckInit(), + initialize = NULL_AFFECT, finalize = NULL_AFFECT) c = scalarize_condition(condition) a = scalarize_affects(affects) - new(c, a, reinitializealg) + new(c, a, scalarize_affects(initialize), + scalarize_affects(finalize), reinitializealg) end # Default affect to nothing end @@ -286,11 +378,19 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback) end function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback) - isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) + isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) && + isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) end function Base.hash(cb::SymbolicDiscreteCallback, s::UInt) s = hash(cb.condition, s) - cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s) + s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : + hash(cb.affects, s) + s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) : + hash(cb.initialize, s) + s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) : + hash(cb.finalize, s) + s = hash(cb.reinitializealg, s) + return s end condition(cb::SymbolicDiscreteCallback) = cb.condition @@ -310,10 +410,25 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) end +initialize_affects(cb::SymbolicDiscreteCallback) = cb.initialize +function initialize_affects(cbs::Vector{SymbolicDiscreteCallback}) + mapreduce(initialize_affects, vcat, cbs, init = Equation[]) +end + +finalize_affects(cb::SymbolicDiscreteCallback) = cb.finalize +function finalize_affects(cbs::Vector{SymbolicDiscreteCallback}) + mapreduce(finalize_affects, vcat, cbs, init = Equation[]) +end + function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback - af = affects(cb) - af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s) - SymbolicDiscreteCallback(namespace_condition(condition(cb), s), af) + function namespace_affects(af) + return af isa AbstractVector ? namespace_affect.(af, Ref(s)) : + namespace_affect(af, s) + end + SymbolicDiscreteCallback( + namespace_condition(condition(cb), s), namespace_affects(affects(cb)), + reinitializealg = cb.reinitializealg, initialize = namespace_affects(initialize_affects(cb)), + finalize = namespace_affects(finalize_affects(cb))) end SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)] @@ -589,22 +704,26 @@ function generate_single_rootfinding_callback( rf_oop(u, parameter_values(integ), t) end end - + user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : + (c, u, t, i) -> affect_function.initialize(i) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing initfn = let save_idxs = save_idxs function (cb, u, t, integrator) + user_initfun(cb, u, t, integrator) for idx in save_idxs SciMLBase.save_discretes!(integrator, idx) end end end else - initfn = SciMLBase.INITIALIZE_DEFAULT + initfn = user_initfun end return ContinuousCallback( cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, initialize = initfn, + finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : + (c, u, t, i) -> affect_function.finalize(i), initializealg = reinitialization_alg(cb)) end @@ -626,13 +745,14 @@ function generate_vector_rootfinding_callback( _, rf_ip = generate_custom_function( sys, rhss, dvs, ps; expression = Val{false}, kwargs...) - affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn( - cb, - sys, - dvs, - ps, - kwargs) - for cb in cbs] + affect_functions = @NamedTuple{ + affect::Function, + affect_neg::Union{Function, Nothing}, + initialize::Union{Function, Nothing}, + finalize::Union{Function, Nothing}}[ + compile_affect_fn(cb, sys, dvs, ps, kwargs) + for cb in cbs] + cond = function (out, u, t, integ) rf_ip(out, u, parameter_values(integ), t) end @@ -658,26 +778,63 @@ function generate_vector_rootfinding_callback( affect_neg(integ) end end - if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - save_idxs = mapreduce( - cb -> get(ic.callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[]) - initfn = if isempty(save_idxs) - SciMLBase.INITIALIZE_DEFAULT + function handle_optional_setup_fn(funs, default) + if all(isnothing, funs) + return default else - let save_idxs = save_idxs - function (cb, u, t, integrator) - for idx in save_idxs - SciMLBase.save_discretes!(integrator, idx) + return let funs = funs + function (cb, u, t, integ) + for func in funs + if isnothing(func) + continue + else + func(integ) + end end end end end + end + initialize = nothing + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + initialize = handle_optional_setup_fn( + map( + (cb, fn) -> begin + if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing + let save_idxs = save_idxs + if !isnothing(fn.initialize) + (i) -> begin + fn.initialize(i) + for idx in save_idxs + SciMLBase.save_discretes!(i, idx) + end + end + else + (i) -> begin + for idx in save_idxs + SciMLBase.save_discretes!(i, idx) + end + end + end + end + else + fn.initialize + end + end, + cbs, + affect_functions), + SciMLBase.INITIALIZE_DEFAULT) + else - initfn = SciMLBase.INITIALIZE_DEFAULT + initialize = handle_optional_setup_fn( + map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) end + + finalize = handle_optional_setup_fn( + map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT) return VectorContinuousCallback( cond, affect, affect_neg, length(eqs), rootfind = rootfind, - initialize = initfn, initializealg = reinitialization) + initialize = initialize, finalize = finalize, initializealg = reinitialization) end """ @@ -689,13 +846,15 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) if eq_neg_aff === eq_aff affect_neg = affect - elseif isnothing(eq_neg_aff) - affect_neg = nothing else - affect_neg = compile_affect( - eq_neg_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) + affect_neg = _compile_optional_affect( + NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...) end - (affect = affect, affect_neg = affect_neg) + initialize = _compile_optional_affect( + NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) + finalize = _compile_optional_affect( + NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) + (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize) end function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys), @@ -786,31 +945,53 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end +function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...) + if isnothing(aff) || aff == default + return nothing + else + return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) + end +end function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing, kwargs...) cond = condition(cb) as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false}, postprocess_affect_expr!, kwargs...) + + user_initfun = _compile_optional_affect( + NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) + user_finfun = _compile_optional_affect( + NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing - initfn = let save_idxs = save_idxs + initfn = let + save_idxs = save_idxs + initfun = user_initfun function (cb, u, t, integrator) + if !isnothing(initfun) + initfun(integrator) + end for idx in save_idxs SciMLBase.save_discretes!(integrator, idx) end end end else - initfn = SciMLBase.INITIALIZE_DEFAULT + initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : + (_, _, _, i) -> user_initfun(i) end + finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : + (_, _, _, i) -> user_finfun(i) if cond isa AbstractVector # Preset Time return PresetTimeCallback( - cond, as; initialize = initfn, initializealg = reinitialization_alg(cb)) + cond, as; initialize = initfn, finalize = finfun, + initializealg = reinitialization_alg(cb)) else # Periodic return PeriodicCallback( - as, cond; initialize = initfn, initializealg = reinitialization_alg(cb)) + as, cond; initialize = initfn, finalize = finfun, + initializealg = reinitialization_alg(cb)) end end @@ -823,20 +1004,32 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...) as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false}, postprocess_affect_expr!, kwargs...) + + user_initfun = _compile_optional_affect( + NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) + user_finfun = _compile_optional_affect( + NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing - initfn = let save_idxs = save_idxs + initfn = let save_idxs = save_idxs, initfun = user_initfun function (cb, u, t, integrator) + if !isnothing(initfun) + initfun(integrator) + end for idx in save_idxs SciMLBase.save_discretes!(integrator, idx) end end end else - initfn = SciMLBase.INITIALIZE_DEFAULT + initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : + (_, _, _, i) -> user_initfun(i) end + finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : + (_, _, _, i) -> user_finfun(i) return DiscreteCallback( - c, as; initialize = initfn, initializealg = reinitialization_alg(cb)) + c, as; initialize = initfn, finalize = finfun, + initializealg = reinitialization_alg(cb)) end end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 26593f980c..b58d5911f4 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -970,6 +970,93 @@ end @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end +@testset "Initialization" begin + @variables x(t) + seen = false + f = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> seen = true, sts = [], pars = [], discretes = []) + cb1 = ModelingToolkit.SymbolicContinuousCallback( + [x ~ 0], Equation[], initialize = [x ~ 1.5], finalize = f) + @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; continuous_events = [cb1]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) + sol = solve(prob, Tsit5(); dtmax = 0.01) + @test sol[x][1] ≈ 1.0 + @test sol[x][2] ≈ 1.5 # the initialize affect has been applied + @test seen == true + + @variables x(t) + seen = false + f = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> seen = true, sts = [], pars = [], discretes = []) + cb1 = ModelingToolkit.SymbolicContinuousCallback( + [x ~ 0], Equation[], initialize = [x ~ 1.5], finalize = f) + inited = false + finaled = false + a = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> inited = true, sts = [], pars = [], discretes = []) + b = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> finaled = true, sts = [], pars = [], discretes = []) + cb2 = ModelingToolkit.SymbolicContinuousCallback( + [x ~ 0.1], Equation[], initialize = a, finalize = b) + @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; continuous_events = [cb1, cb2]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) + sol = solve(prob, Tsit5()) + @test sol[x][1] ≈ 1.0 + @test sol[x][2] ≈ 1.5 # the initialize affect has been applied + @test seen == true + @test inited == true + @test finaled == true + + #periodic + inited = false + finaled = false + cb3 = ModelingToolkit.SymbolicDiscreteCallback( + 1.0, [x ~ 2], initialize = a, finalize = b) + @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) + sol = solve(prob, Tsit5()) + @test inited == true + @test finaled == true + @test isapprox(sol[x][3], 0.0, atol = 1e-9) + @test sol[x][4] ≈ 2.0 + @test sol[x][5] ≈ 1.0 + + seen = false + inited = false + finaled = false + cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, f, initialize = a, finalize = b) + @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) + sol = solve(prob, Tsit5()) + @test seen == true + @test inited == true + + #preset + seen = false + inited = false + finaled = false + cb3 = ModelingToolkit.SymbolicDiscreteCallback([1.0], f, initialize = a, finalize = b) + @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) + sol = solve(prob, Tsit5()) + @test seen == true + @test inited == true + @test finaled == true + + #equational + seen = false + inited = false + finaled = false + cb3 = ModelingToolkit.SymbolicDiscreteCallback( + t == 1.0, f, initialize = a, finalize = b) + @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) + sol = solve(prob, Tsit5(); tstops = 1.0) + @test seen == true + @test inited == true + @test finaled == true +end + @testset "Bump" begin @variables x(t) [irreducible = true] y(t) [irreducible = true] eqs = [x ~ y, D(x) ~ -1]