Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
BenChung committed Oct 29, 2024
1 parent eb14e35 commit 95b0ecc
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 88 deletions.
149 changes: 88 additions & 61 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ struct SymbolicContinuousCallback
finalize = NULL_AFFECT,
rootfind = SciMLBase.LeftRootFind,
reinitializealg = SciMLBase.CheckInit())
new(eqs, initialize, finalize, make_affect(affect),
new(eqs, initialize, finalize, make_affect(affect),
make_affect(affect_neg), rootfind, reinitializealg)
end # Default affect to nothing
end
Expand Down Expand Up @@ -227,18 +227,19 @@ function SymbolicContinuousCallback(args...)
end # wrap eq in vector
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
initialize=NULL_AFFECT, finalize=NULL_AFFECT,
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
initialize = NULL_AFFECT, finalize = NULL_AFFECT,
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
SymbolicContinuousCallback(
eqs = [eqs], affect = affect, affect_neg = affect_neg,
initialize=initialize, finalize=finalize, rootfind = rootfind)
eqs = [eqs], affect = affect, affect_neg = affect_neg,
initialize = initialize, finalize = finalize, rootfind = rootfind)
end
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
affect_neg = affect, initialize=NULL_AFFECT, finalize=NULL_AFFECT,
affect_neg = affect, initialize = NULL_AFFECT, finalize = NULL_AFFECT,
rootfind = SciMLBase.LeftRootFind)
SymbolicContinuousCallback(
eqs = eqs, affect = affect, affect_neg = affect_neg, initialize=initialize, finalize=finalize, rootfind = rootfind)
eqs = eqs, affect = affect, affect_neg = affect_neg,
initialize = initialize, finalize = finalize, rootfind = rootfind)
end

SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
Expand Down Expand Up @@ -334,10 +335,11 @@ struct SymbolicDiscreteCallback

function SymbolicDiscreteCallback(
condition, affects = NULL_AFFECT; reinitializealg = SciMLBase.CheckInit(),
initialize=NULL_AFFECT, finalize=NULL_AFFECT)
initialize = NULL_AFFECT, finalize = NULL_AFFECT)
c = scalarize_condition(condition)
a = scalarize_affects(affects)
new(c, a, scalarize_affects(initialize), scalarize_affects(finalize), reinitializealg)
new(c, a, scalarize_affects(initialize),
scalarize_affects(finalize), reinitializealg)
end # Default affect to nothing
end

Expand Down Expand Up @@ -376,14 +378,17 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback)
end

function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback)
isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) &&
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)
isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) &&
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)
end
function Base.hash(cb::SymbolicDiscreteCallback, s::UInt)
s = hash(cb.condition, s)
s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s)
s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) : hash(cb.initialize, s)
s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) : hash(cb.finalize, s)
s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) :
hash(cb.affects, s)
s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) :
hash(cb.initialize, s)
s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) :
hash(cb.finalize, s)
s = hash(cb.reinitializealg, s)
return s
end
Expand All @@ -405,7 +410,6 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
end


initialize_affects(cb::SymbolicDiscreteCallback) = cb.initialize
function initialize_affects(cbs::Vector{SymbolicDiscreteCallback})
mapreduce(initialize_affects, vcat, cbs, init = Equation[])
Expand All @@ -418,10 +422,13 @@ end

function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
function namespace_affects(af)
return af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
return af isa AbstractVector ? namespace_affect.(af, Ref(s)) :
namespace_affect(af, s)
end
SymbolicDiscreteCallback(namespace_condition(condition(cb), s), namespace_affects(affects(cb)),
reinitializealg=cb.reinitializealg, initialize=namespace_affects(initialize_affects(cb)), finalize=namespace_affects(finalize_affects(cb)))
SymbolicDiscreteCallback(
namespace_condition(condition(cb), s), namespace_affects(affects(cb)),
reinitializealg = cb.reinitializealg, initialize = namespace_affects(initialize_affects(cb)),
finalize = namespace_affects(finalize_affects(cb)))
end

SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)]
Expand Down Expand Up @@ -698,7 +705,7 @@ function generate_single_rootfinding_callback(
end
end
user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT :
(c, u, t, i) -> affect_function.initialize(i)
(c, u, t, i) -> affect_function.initialize(i)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
initfn = let save_idxs = save_idxs
Expand All @@ -715,7 +722,8 @@ function generate_single_rootfinding_callback(
return ContinuousCallback(
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
initialize = initfn,
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i),
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT :
(c, u, t, i) -> affect_function.finalize(i),
initializealg = reinitialization_alg(cb))
end

Expand All @@ -742,8 +750,8 @@ function generate_vector_rootfinding_callback(
affect_neg::Union{Function, Nothing},
initialize::Union{Function, Nothing},
finalize::Union{Function, Nothing}}[
compile_affect_fn(cb, sys, dvs, ps, kwargs)
for cb in cbs]
compile_affect_fn(cb, sys, dvs, ps, kwargs)
for cb in cbs]

cond = function (out, u, t, integ)
rf_ip(out, u, parameter_values(integ), t)
Expand Down Expand Up @@ -789,31 +797,37 @@ function generate_vector_rootfinding_callback(
end
initialize = nothing
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
initialize = handle_optional_setup_fn(map((cb, fn) -> begin
if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
let save_idxs = save_idxs
if !isnothing(fn.initialize)
(i) -> begin
fn.initialize(i)
for idx in save_idxs
SciMLBase.save_discretes!(i, idx)
end
end
else
(i) -> begin
for idx in save_idxs
SciMLBase.save_discretes!(i, idx)
initialize = handle_optional_setup_fn(
map(
(cb, fn) -> begin
if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
let save_idxs = save_idxs
if !isnothing(fn.initialize)
(i) -> begin
fn.initialize(i)
for idx in save_idxs
SciMLBase.save_discretes!(i, idx)
end
end
else
(i) -> begin
for idx in save_idxs
SciMLBase.save_discretes!(i, idx)
end
end
end
end
else
fn.initialize
end
end
else
fn.initialize
end
end, cbs, affect_functions), SciMLBase.INITIALIZE_DEFAULT)

end,
cbs,
affect_functions),
SciMLBase.INITIALIZE_DEFAULT)

else
initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
initialize = handle_optional_setup_fn(
map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
end

finalize = handle_optional_setup_fn(
Expand All @@ -833,10 +847,13 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
if eq_neg_aff === eq_aff
affect_neg = affect
else
affect_neg = _compile_optional_affect(NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...)
affect_neg = _compile_optional_affect(
NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...)
end
initialize = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
finalize = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
initialize = _compile_optional_affect(
NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
finalize = _compile_optional_affect(
NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
end

Expand Down Expand Up @@ -928,7 +945,6 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
end


function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...)
if isnothing(aff) || aff == default
return nothing
Expand All @@ -942,13 +958,15 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
postprocess_affect_expr!, kwargs...)

user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_initfun = _compile_optional_affect(
NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_finfun = _compile_optional_affect(
NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
initfn = let
initfn = let
save_idxs = save_idxs
initfun=user_initfun
initfun = user_initfun
function (cb, u, t, integrator)
if !isnothing(initfun)
initfun(integrator)
Expand All @@ -959,17 +977,21 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
end
end
else
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i)
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT :
(_, _, _, i) -> user_initfun(i)
end
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i)
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT :
(_, _, _, i) -> user_finfun(i)
if cond isa AbstractVector
# Preset Time
return PresetTimeCallback(
cond, as; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb))
cond, as; initialize = initfn, finalize = finfun,
initializealg = reinitialization_alg(cb))
else
# Periodic
return PeriodicCallback(
as, cond; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb))
as, cond; initialize = initfn, finalize = finfun,
initializealg = reinitialization_alg(cb))
end
end

Expand All @@ -983,8 +1005,10 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
postprocess_affect_expr!, kwargs...)

user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_initfun = _compile_optional_affect(
NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
user_finfun = _compile_optional_affect(
NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
initfn = let save_idxs = save_idxs, initfun = user_initfun
Expand All @@ -998,11 +1022,14 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
end
end
else
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i)
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT :
(_, _, _, i) -> user_initfun(i)
end
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i)
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT :
(_, _, _, i) -> user_finfun(i)
return DiscreteCallback(
c, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg(cb))
c, as; initialize = initfn, finalize = finfun,
initializealg = reinitialization_alg(cb))
end
end

Expand Down
Loading

0 comments on commit 95b0ecc

Please sign in to comment.