Skip to content

Commit

Permalink
fix: support remake with array symbolic whose index is an array of …
Browse files Browse the repository at this point in the history
…indices
  • Loading branch information
AayushSabharwal committed Sep 27, 2024
1 parent 06864fd commit dfdd9ad
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,16 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
for (k, v) in u0
idx = variable_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in vsyms
haskey(sym_to_idx, sym) && continue
Expand Down Expand Up @@ -586,9 +593,16 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
for (k, v) in p
idx = parameter_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in psyms
haskey(sym_to_idx, sym) && continue
Expand Down
28 changes: 28 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,31 @@ sol = solve(prob, BFGS())
@test prob2.ps[P] == sign * 2.0
end
end

@testset "remake with Vector{Int} as index of array variable/parameter" begin
@parameters k[1:4]
@variables (V(t))[1:2]
function rhs!(du, u, p, t)
du[1] = p[1] - p[2] * u[1]
du[2] = p[3] - p[4] * u[2]
nothing
end
sys = SymbolCache(Dict(V => 1:2, V[1] => 1, V[2] => 2),
Dict(k => 1:4, k[1] => 1, k[2] => 2, k[3] => 3, k[4] => 4), t)
struct SCWrapper{S}
sys::S
end
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sys
SymbolicIndexingInterface.variable_symbols(s::SCWrapper) = filter(
x -> symbolic_type(x) != ArraySymbolic(), variable_symbols(s.sys))
SymbolicIndexingInterface.parameter_symbols(s::SCWrapper) = filter(
x -> symbolic_type(x) != ArraySymbolic(), parameter_symbols(s.sys))
sys = SCWrapper(sys)
fn = ODEFunction(rhs!; sys)
oprob_scal_scal = ODEProblem(fn, [10.0, 20.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0])
ps_vec = [k => [2.0, 3.0, 4.0, 5.0]]
u0_vec = [V => [1.5, 2.5]]
newoprob = remake(oprob_scal_scal; u0 = u0_vec, p = ps_vec)
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
@test newoprob[V] == [1.5, 2.5]
end

0 comments on commit dfdd9ad

Please sign in to comment.