Skip to content

Commit

Permalink
forgot a file
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Feb 27, 2024
1 parent 0501bf5 commit 60cdf22
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, SymbolicIndexingInterface, Test
using Optimization, OptimizationOptimJL
# compat for MTKv8 and v9
unknowns = isdefined(ModelingToolkit, :states) ? ModelingToolkit.states :
ModelingToolkit.unknowns
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters t σ ρ β
@parameters σ ρ β
@variables x(t) y(t) z(t)
D = Differential(t)

Expand All @@ -19,8 +17,7 @@ eqs = [D(x) ~ σ * (y - x),
@variables a(t) α(t)
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
α ~ 2lorenz1.x + a * γ]
@named sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])
sys_simplified = structural_simplify(sys)
@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])

u0 = [lorenz1.x => 1.0,
lorenz1.y => 0.0,
Expand All @@ -39,7 +36,7 @@ p = [lorenz1.σ => 10.0,
γ => 2.0]

tspan = (0.0, 100.0)
prob = ODEProblem(sys_simplified, u0, tspan, p)
prob = ODEProblem(sys, u0, tspan, p)
integ = init(prob, Rodas4())
sol = solve(prob, Rodas4())

Expand Down Expand Up @@ -135,7 +132,7 @@ sol1 = sol(0.0:1.0:10.0)

sol2 = sol(0.1)
@test sol2 isa Vector
@test length(sol2) == length(unknowns(sys_simplified))
@test length(sol2) == length(unknowns(sys))
@test first(sol2) isa Real

sol3 = sol(0.0:1.0:10.0, idxs = [lorenz1.x, lorenz2.x])
Expand Down Expand Up @@ -191,9 +188,9 @@ sol10 = sol(0.1, idxs = 2)
@test sol10 isa Real

@test is_timeseries(sol) == Timeseries()
getx = getu(sys_simplified, lorenz1.x)
get_arr = getu(sys_simplified, [lorenz1.x, lorenz2.x])
get_tuple = getu(sys_simplified, (lorenz1.x, lorenz2.x))
getx = getu(sys, lorenz1.x)
get_arr = getu(sys, [lorenz1.x, lorenz2.x])
get_tuple = getu(sys, (lorenz1.x, lorenz2.x))
get_obs = getu(sol, lorenz1.x + lorenz2.x) # can't use sys for observed
get_obs_arr = getu(sol, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y])
l1x_idx = variable_index(sol, lorenz1.x)
Expand Down

0 comments on commit 60cdf22

Please sign in to comment.