From dfdd9ad01735d6df52d8c09dbcfd948494c150dc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Sep 2024 17:02:08 +0530 Subject: [PATCH 1/2] fix: support `remake` with array symbolic whose index is an array of indices --- src/remake.jl | 26 ++++++++++++++++----- test/downstream/modelingtoolkit_remake.jl | 28 +++++++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index a078b3060..fcececb79 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -549,9 +549,16 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) for (k, v) in u0 idx = variable_index(prob, k) idx === nothing && continue - sym_to_idx[k] = idx - idx_to_sym[idx] = k - idx_to_val[idx] = v + if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() + idx = (idx,) + k = (k,) + v = (v,) + end + for (kk, vv, ii) in zip(k, v, idx) + sym_to_idx[kk] = ii + idx_to_sym[ii] = kk + idx_to_val[ii] = vv + end end for sym in vsyms haskey(sym_to_idx, sym) && continue @@ -586,9 +593,16 @@ function fill_p(prob, p; defs = nothing, use_defaults = false) for (k, v) in p idx = parameter_index(prob, k) idx === nothing && continue - sym_to_idx[k] = idx - idx_to_sym[idx] = k - idx_to_val[idx] = v + if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() + idx = (idx,) + k = (k,) + v = (v,) + end + for (kk, vv, ii) in zip(k, v, idx) + sym_to_idx[kk] = ii + idx_to_sym[ii] = kk + idx_to_val[ii] = vv + end end for sym in psyms haskey(sym_to_idx, sym) && continue diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 184fd3d50..eec0b6baf 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -191,3 +191,31 @@ sol = solve(prob, BFGS()) @test prob2.ps[P] == sign * 2.0 end end + +@testset "remake with Vector{Int} as index of array variable/parameter" begin + @parameters k[1:4] + @variables (V(t))[1:2] + function rhs!(du, u, p, t) + du[1] = p[1] - p[2] * u[1] + du[2] = p[3] - p[4] * u[2] + nothing + end + sys = SymbolCache(Dict(V => 1:2, V[1] => 1, V[2] => 2), + Dict(k => 1:4, k[1] => 1, k[2] => 2, k[3] => 3, k[4] => 4), t) + struct SCWrapper{S} + sys::S + end + SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sys + SymbolicIndexingInterface.variable_symbols(s::SCWrapper) = filter( + x -> symbolic_type(x) != ArraySymbolic(), variable_symbols(s.sys)) + SymbolicIndexingInterface.parameter_symbols(s::SCWrapper) = filter( + x -> symbolic_type(x) != ArraySymbolic(), parameter_symbols(s.sys)) + sys = SCWrapper(sys) + fn = ODEFunction(rhs!; sys) + oprob_scal_scal = ODEProblem(fn, [10.0, 20.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0]) + ps_vec = [k => [2.0, 3.0, 4.0, 5.0]] + u0_vec = [V => [1.5, 2.5]] + newoprob = remake(oprob_scal_scal; u0 = u0_vec, p = ps_vec) + @test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0] + @test newoprob[V] == [1.5, 2.5] +end From 5c029b71b9d4eee437a25a810c80928f44744b5e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Sep 2024 17:10:25 +0530 Subject: [PATCH 2/2] test: remove MTK as test dependency --- Project.toml | 4 +--- test/traits.jl | 17 ++++++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 215eb7527..4fb7694cd 100644 --- a/Project.toml +++ b/Project.toml @@ -71,7 +71,6 @@ LinearAlgebra = "1.10" Logging = "1.10" Makie = "0.20, 0.21" Markdown = "1.10" -ModelingToolkit = "8.75, 9" PartialFunctions = "1.1" PrecompileTools = "1.2" Preferences = "1.3" @@ -100,7 +99,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -118,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"] +test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff"] diff --git a/test/traits.jl b/test/traits.jl index b08e13a72..e67204200 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -1,6 +1,5 @@ using SciMLBase, Test -using ModelingToolkit, OrdinaryDiffEq, DataFrames -using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq, DataFrames, SymbolicIndexingInterface @test SciMLBase.Tables.isrowtable(ODESolution) @test SciMLBase.Tables.isrowtable(RODESolution) @@ -10,13 +9,13 @@ using ModelingToolkit: t_nounits as t, D_nounits as D @test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution) @test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution) -@variables x(t) = 1 -eqs = [D(x) ~ -x] -@named sys = ODESystem(eqs, t) -sys = complete(sys) -prob = ODEProblem(sys) -sol = solve(prob, Tsit5(), tspan = (0.0, 1.0)) +function rhs(u, p, t) + return -u +end +sys = SymbolCache([:x], Symbol[], :t) +prob = ODEProblem(ODEFunction(rhs; sys), [1.0], (0.0, 1.0)) +sol = solve(prob, Tsit5()) df = DataFrame(sol) @test size(df) == (length(sol.u), 2) @test df.timestamp == sol.t -@test df.x == sol[x] +@test df.x == sol[:x]