Skip to content

Commit

Permalink
Merge pull request #801 from AayushSabharwal/as/remake-arridx
Browse files Browse the repository at this point in the history
fix: support remake with array symbolic whose index is an array of indices
  • Loading branch information
ChrisRackauckas authored Sep 27, 2024
2 parents 06864fd + 5c029b7 commit 561c781
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 18 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ LinearAlgebra = "1.10"
Logging = "1.10"
Makie = "0.20, 0.21"
Markdown = "1.10"
ModelingToolkit = "8.75, 9"
PartialFunctions = "1.1"
PrecompileTools = "1.2"
Preferences = "1.3"
Expand Down Expand Up @@ -100,7 +99,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -118,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff"]
26 changes: 20 additions & 6 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,16 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
for (k, v) in u0
idx = variable_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in vsyms
haskey(sym_to_idx, sym) && continue
Expand Down Expand Up @@ -586,9 +593,16 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
for (k, v) in p
idx = parameter_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in psyms
haskey(sym_to_idx, sym) && continue
Expand Down
28 changes: 28 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,31 @@ sol = solve(prob, BFGS())
@test prob2.ps[P] == sign * 2.0
end
end

@testset "remake with Vector{Int} as index of array variable/parameter" begin
@parameters k[1:4]
@variables (V(t))[1:2]
function rhs!(du, u, p, t)
du[1] = p[1] - p[2] * u[1]
du[2] = p[3] - p[4] * u[2]
nothing
end
sys = SymbolCache(Dict(V => 1:2, V[1] => 1, V[2] => 2),
Dict(k => 1:4, k[1] => 1, k[2] => 2, k[3] => 3, k[4] => 4), t)
struct SCWrapper{S}
sys::S
end
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sys
SymbolicIndexingInterface.variable_symbols(s::SCWrapper) = filter(
x -> symbolic_type(x) != ArraySymbolic(), variable_symbols(s.sys))
SymbolicIndexingInterface.parameter_symbols(s::SCWrapper) = filter(
x -> symbolic_type(x) != ArraySymbolic(), parameter_symbols(s.sys))
sys = SCWrapper(sys)
fn = ODEFunction(rhs!; sys)
oprob_scal_scal = ODEProblem(fn, [10.0, 20.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0])
ps_vec = [k => [2.0, 3.0, 4.0, 5.0]]
u0_vec = [V => [1.5, 2.5]]
newoprob = remake(oprob_scal_scal; u0 = u0_vec, p = ps_vec)
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
@test newoprob[V] == [1.5, 2.5]
end
17 changes: 8 additions & 9 deletions test/traits.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using SciMLBase, Test
using ModelingToolkit, OrdinaryDiffEq, DataFrames
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq, DataFrames, SymbolicIndexingInterface

@test SciMLBase.Tables.isrowtable(ODESolution)
@test SciMLBase.Tables.isrowtable(RODESolution)
Expand All @@ -10,13 +9,13 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
@test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution)
@test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution)

@variables x(t) = 1
eqs = [D(x) ~ -x]
@named sys = ODESystem(eqs, t)
sys = complete(sys)
prob = ODEProblem(sys)
sol = solve(prob, Tsit5(), tspan = (0.0, 1.0))
function rhs(u, p, t)
return -u
end
sys = SymbolCache([:x], Symbol[], :t)
prob = ODEProblem(ODEFunction(rhs; sys), [1.0], (0.0, 1.0))
sol = solve(prob, Tsit5())
df = DataFrame(sol)
@test size(df) == (length(sol.u), 2)
@test df.timestamp == sol.t
@test df.x == sol[x]
@test df.x == sol[:x]

0 comments on commit 561c781

Please sign in to comment.