Skip to content

Commit

Permalink
Merge pull request SciML#2355 from SciML/fb/discrete_timing
Browse files Browse the repository at this point in the history
add component-based hybrid system test
  • Loading branch information
YingboMa authored Jan 25, 2024
2 parents 7685996 + b91fafe commit eb33a8a
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 51 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ julia = "1.9"
[extras]
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
ControlSystemsMTK = "687d7614-c7e5-45fc-bfc3-9ee385575c88"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -136,4 +137,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]
42 changes: 35 additions & 7 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
offset = length(appended_parameters)
affect_funs = []
init_funs = []
svs = []
clocks = TimeDomain[]
for (i, (sys, input)) in enumerate(zip(syss, inputs))
Expand Down Expand Up @@ -202,6 +203,18 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
push!(save_vec.args, :(p[$(input_offset + i)]))
end
empty_disc = isempty(disc_range)

disc_init = :(function (p, t)
d2c_obs = $disc_to_cont_obs
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
copyto!(d2c_view, d2c_obs(disc_state, p, t))
end)

# @show disc_to_cont_idxs
# @show cont_to_disc_idxs
# @show disc_range

affect! = :(function (integrator, saved_values)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
Expand All @@ -212,27 +225,42 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
disc = $disc
# Write continuous into to discrete: handles `Sample`
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
# Write discrete into to continuous
# get old discrete states
copyto!(d2c_view, d2c_obs(disc_state, p, t))

push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)
# update discrete states

# Write continuous into to discrete: handles `Sample`
# Write discrete into to continuous
# Update discrete states

# At a tick, c2d must come first
# state update comes in the middle
# d2c comes last
# @show t
# @show "incoming", p
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
# @show "after c2d", p
$empty_disc || disc(disc_state, disc_state, p, t)
# @show "after state update", p
copyto!(d2c_view, d2c_obs(disc_state, p, t))
# @show "after d2c", p
end)
sv = SavedValues(Float64, Vector{Float64})
push!(affect_funs, affect!)
push!(init_funs, disc_init)
push!(svs, sv)
end
if eval_expression
affects = map(affect_funs) do a
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
end
inits = map(init_funs) do a
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
end
else
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
end
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
return affects, clocks, svs, appended_parameters, defaults
return affects, inits, clocks, svs, appended_parameters, defaults
end
34 changes: 28 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -945,8 +945,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
has_difference = has_difference,
check_length, kwargs...)
cbs = process_events(sys; callback, has_difference, kwargs...)
inits = []
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
Expand Down Expand Up @@ -976,7 +977,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
if !isempty(inits)
for init in inits
init(prob.p, tspan[1])
end
end
prob
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

Expand Down Expand Up @@ -1045,8 +1052,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h = h_oop
u0 = h(p, tspan[1])
cbs = process_events(sys; callback, has_difference, kwargs...)
inits = []
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
Expand Down Expand Up @@ -1075,7 +1083,13 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
if !isempty(inits)
for init in inits
init(prob.p, tspan[1])
end
end
prob
end

function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -1099,8 +1113,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p, t) = h_oop(p, t)
u0 = h(p, tspan[1])
cbs = process_events(sys; callback, has_difference, kwargs...)
inits = []
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
Expand Down Expand Up @@ -1140,8 +1155,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
else
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
end
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
noise_rate_prototype =
noise_rate_prototype, kwargs1..., kwargs...)
if !isempty(inits)
for init in inits
init(prob.p, tspan[1])
end
end
prob
end

"""
Expand Down
Loading

0 comments on commit eb33a8a

Please sign in to comment.