diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 18497a0d1e..0b758a6e86 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -131,7 +131,7 @@ struct SymbolicContinuousCallback finalize = NULL_AFFECT, rootfind = SciMLBase.LeftRootFind, reinitializealg = SciMLBase.CheckInit()) - new(eqs, initialize, finalize, make_affect(affect), + new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg) end # Default affect to nothing end @@ -227,18 +227,19 @@ function SymbolicContinuousCallback(args...) end # wrap eq in vector SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough -function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; - initialize=NULL_AFFECT, finalize=NULL_AFFECT, - affect_neg = affect, rootfind = SciMLBase.LeftRootFind) +function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; + initialize = NULL_AFFECT, finalize = NULL_AFFECT, + affect_neg = affect, rootfind = SciMLBase.LeftRootFind) SymbolicContinuousCallback( - eqs = [eqs], affect = affect, affect_neg = affect_neg, - initialize=initialize, finalize=finalize, rootfind = rootfind) + eqs = [eqs], affect = affect, affect_neg = affect_neg, + initialize = initialize, finalize = finalize, rootfind = rootfind) end function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; - affect_neg = affect, initialize=NULL_AFFECT, finalize=NULL_AFFECT, + affect_neg = affect, initialize = NULL_AFFECT, finalize = NULL_AFFECT, rootfind = SciMLBase.LeftRootFind) SymbolicContinuousCallback( - eqs = eqs, affect = affect, affect_neg = affect_neg, initialize=initialize, finalize=finalize, rootfind = rootfind) + eqs = eqs, affect = affect, affect_neg = affect_neg, + initialize = initialize, finalize = finalize, rootfind = rootfind) end SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] @@ -334,10 +335,11 @@ struct SymbolicDiscreteCallback function SymbolicDiscreteCallback( condition, affects = NULL_AFFECT; reinitializealg = SciMLBase.CheckInit(), - initialize=NULL_AFFECT, finalize=NULL_AFFECT) + initialize = NULL_AFFECT, finalize = NULL_AFFECT) c = scalarize_condition(condition) a = scalarize_affects(affects) - new(c, a, scalarize_affects(initialize), scalarize_affects(finalize), reinitializealg) + new(c, a, scalarize_affects(initialize), + scalarize_affects(finalize), reinitializealg) end # Default affect to nothing end @@ -376,14 +378,17 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback) end function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback) - isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) && - isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) + isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) && + isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) end function Base.hash(cb::SymbolicDiscreteCallback, s::UInt) s = hash(cb.condition, s) - s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s) - s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) : hash(cb.initialize, s) - s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) : hash(cb.finalize, s) + s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : + hash(cb.affects, s) + s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) : + hash(cb.initialize, s) + s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) : + hash(cb.finalize, s) s = hash(cb.reinitializealg, s) return s end @@ -405,7 +410,6 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) end - initialize_affects(cb::SymbolicDiscreteCallback) = cb.initialize function initialize_affects(cbs::Vector{SymbolicDiscreteCallback}) mapreduce(initialize_affects, vcat, cbs, init = Equation[]) @@ -418,10 +422,13 @@ end function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback function namespace_affects(af) - return af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s) + return af isa AbstractVector ? namespace_affect.(af, Ref(s)) : + namespace_affect(af, s) end - SymbolicDiscreteCallback(namespace_condition(condition(cb), s), namespace_affects(affects(cb)), - reinitializealg=cb.reinitializealg, initialize=namespace_affects(initialize_affects(cb)), finalize=namespace_affects(finalize_affects(cb))) + SymbolicDiscreteCallback( + namespace_condition(condition(cb), s), namespace_affects(affects(cb)), + reinitializealg = cb.reinitializealg, initialize = namespace_affects(initialize_affects(cb)), + finalize = namespace_affects(finalize_affects(cb))) end SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)] @@ -698,7 +705,7 @@ function generate_single_rootfinding_callback( end end user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : - (c, u, t, i) -> affect_function.initialize(i) + (c, u, t, i) -> affect_function.initialize(i) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing initfn = let save_idxs = save_idxs @@ -715,7 +722,8 @@ function generate_single_rootfinding_callback( return ContinuousCallback( cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, initialize = initfn, - finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i), + finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : + (c, u, t, i) -> affect_function.finalize(i), initializealg = reinitialization_alg(cb)) end @@ -742,8 +750,8 @@ function generate_vector_rootfinding_callback( affect_neg::Union{Function, Nothing}, initialize::Union{Function, Nothing}, finalize::Union{Function, Nothing}}[ - compile_affect_fn(cb, sys, dvs, ps, kwargs) - for cb in cbs] + compile_affect_fn(cb, sys, dvs, ps, kwargs) + for cb in cbs] cond = function (out, u, t, integ) rf_ip(out, u, parameter_values(integ), t) @@ -789,31 +797,37 @@ function generate_vector_rootfinding_callback( end initialize = nothing if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - initialize = handle_optional_setup_fn(map((cb, fn) -> begin - if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing - let save_idxs = save_idxs - if !isnothing(fn.initialize) - (i) -> begin - fn.initialize(i) - for idx in save_idxs - SciMLBase.save_discretes!(i, idx) - end - end - else - (i) -> begin - for idx in save_idxs - SciMLBase.save_discretes!(i, idx) + initialize = handle_optional_setup_fn( + map( + (cb, fn) -> begin + if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing + let save_idxs = save_idxs + if !isnothing(fn.initialize) + (i) -> begin + fn.initialize(i) + for idx in save_idxs + SciMLBase.save_discretes!(i, idx) + end + end + else + (i) -> begin + for idx in save_idxs + SciMLBase.save_discretes!(i, idx) + end + end end end + else + fn.initialize end - end - else - fn.initialize - end - end, cbs, affect_functions), SciMLBase.INITIALIZE_DEFAULT) - + end, + cbs, + affect_functions), + SciMLBase.INITIALIZE_DEFAULT) + else - initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) + initialize = handle_optional_setup_fn( + map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) end finalize = handle_optional_setup_fn( @@ -833,10 +847,13 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) if eq_neg_aff === eq_aff affect_neg = affect else - affect_neg = _compile_optional_affect(NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...) + affect_neg = _compile_optional_affect( + NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...) end - initialize = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) - finalize = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) + initialize = _compile_optional_affect( + NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) + finalize = _compile_optional_affect( + NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize) end @@ -928,7 +945,6 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end - function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...) if isnothing(aff) || aff == default return nothing @@ -942,13 +958,15 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false}, postprocess_affect_expr!, kwargs...) - user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) - user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) + user_initfun = _compile_optional_affect( + NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) + user_finfun = _compile_optional_affect( + NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing - initfn = let + initfn = let save_idxs = save_idxs - initfun=user_initfun + initfun = user_initfun function (cb, u, t, integrator) if !isnothing(initfun) initfun(integrator) @@ -959,17 +977,21 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no end end else - initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i) + initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : + (_, _, _, i) -> user_initfun(i) end - finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i) + finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : + (_, _, _, i) -> user_finfun(i) if cond isa AbstractVector # Preset Time return PresetTimeCallback( - cond, as; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb)) + cond, as; initialize = initfn, finalize = finfun, + initializealg = reinitialization_alg(cb)) else # Periodic return PeriodicCallback( - as, cond; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb)) + as, cond; initialize = initfn, finalize = finfun, + initializealg = reinitialization_alg(cb)) end end @@ -983,8 +1005,10 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false}, postprocess_affect_expr!, kwargs...) - user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) - user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) + user_initfun = _compile_optional_affect( + NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...) + user_finfun = _compile_optional_affect( + NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing initfn = let save_idxs = save_idxs, initfun = user_initfun @@ -998,11 +1022,14 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = end end else - initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i) + initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : + (_, _, _, i) -> user_initfun(i) end - finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i) + finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : + (_, _, _, i) -> user_finfun(i) return DiscreteCallback( - c, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg(cb)) + c, as; initialize = initfn, finalize = finfun, + initializealg = reinitialization_alg(cb)) end end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 54180ad7c6..b58d5911f4 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -970,84 +970,90 @@ end @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end - @testset "Initialization" begin @variables x(t) seen = false - f = ModelingToolkit.FunctionalAffect(f=(i,u,p,c)->seen=true, sts=[], pars=[], discretes=[]) - cb1 = ModelingToolkit.SymbolicContinuousCallback([x ~ 0], Equation[], initialize=[x~1.5], finalize=f) + f = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> seen = true, sts = [], pars = [], discretes = []) + cb1 = ModelingToolkit.SymbolicContinuousCallback( + [x ~ 0], Equation[], initialize = [x ~ 1.5], finalize = f) @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; continuous_events = [cb1]) prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) @test sol[x][1] ≈ 1.0 @test sol[x][2] ≈ 1.5 # the initialize affect has been applied @test seen == true - @variables x(t) seen = false - f = ModelingToolkit.FunctionalAffect(f=(i,u,p,c)->seen=true, sts=[], pars=[], discretes=[]) - cb1 = ModelingToolkit.SymbolicContinuousCallback([x ~ 0], Equation[], initialize=[x~1.5], finalize=f) - inited = false + f = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> seen = true, sts = [], pars = [], discretes = []) + cb1 = ModelingToolkit.SymbolicContinuousCallback( + [x ~ 0], Equation[], initialize = [x ~ 1.5], finalize = f) + inited = false finaled = false - a = ModelingToolkit.FunctionalAffect(f=(i,u,p,c)->inited=true, sts=[], pars=[], discretes=[]) - b = ModelingToolkit.FunctionalAffect(f=(i,u,p,c)->finaled=true, sts=[], pars=[], discretes=[]) - cb2= ModelingToolkit.SymbolicContinuousCallback([x ~ 0.1], Equation[], initialize=a, finalize=b) + a = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> inited = true, sts = [], pars = [], discretes = []) + b = ModelingToolkit.FunctionalAffect( + f = (i, u, p, c) -> finaled = true, sts = [], pars = [], discretes = []) + cb2 = ModelingToolkit.SymbolicContinuousCallback( + [x ~ 0.1], Equation[], initialize = a, finalize = b) @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; continuous_events = [cb1, cb2]) prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) sol = solve(prob, Tsit5()) @test sol[x][1] ≈ 1.0 @test sol[x][2] ≈ 1.5 # the initialize affect has been applied @test seen == true - @test inited == true + @test inited == true @test finaled == true #periodic - inited = false + inited = false finaled = false - cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, [x ~ 2], initialize=a, finalize=b) + cb3 = ModelingToolkit.SymbolicDiscreteCallback( + 1.0, [x ~ 2], initialize = a, finalize = b) @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) sol = solve(prob, Tsit5()) - @test inited == true + @test inited == true @test finaled == true - @test isapprox(sol[x][3], 0.0, atol=1e-9) + @test isapprox(sol[x][3], 0.0, atol = 1e-9) @test sol[x][4] ≈ 2.0 @test sol[x][5] ≈ 1.0 - seen = false - inited = false + inited = false finaled = false - cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, f, initialize=a, finalize=b) + cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, f, initialize = a, finalize = b) @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) sol = solve(prob, Tsit5()) @test seen == true - @test inited == true + @test inited == true #preset seen = false - inited = false + inited = false finaled = false - cb3 = ModelingToolkit.SymbolicDiscreteCallback([1.0], f, initialize=a, finalize=b) + cb3 = ModelingToolkit.SymbolicDiscreteCallback([1.0], f, initialize = a, finalize = b) @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) sol = solve(prob, Tsit5()) @test seen == true - @test inited == true + @test inited == true @test finaled == true #equational seen = false - inited = false + inited = false finaled = false - cb3 = ModelingToolkit.SymbolicDiscreteCallback(t == 1.0, f, initialize=a, finalize=b) + cb3 = ModelingToolkit.SymbolicDiscreteCallback( + t == 1.0, f, initialize = a, finalize = b) @mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3]) prob = ODEProblem(sys, [x => 1.0], (0.0, 2), []) - sol = solve(prob, Tsit5(); tstops=1.0) + sol = solve(prob, Tsit5(); tstops = 1.0) @test seen == true - @test inited == true + @test inited == true @test finaled == true end