diff --git a/src/remake.jl b/src/remake.jl index a078b3060..fcececb79 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -549,9 +549,16 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) for (k, v) in u0 idx = variable_index(prob, k) idx === nothing && continue - sym_to_idx[k] = idx - idx_to_sym[idx] = k - idx_to_val[idx] = v + if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() + idx = (idx,) + k = (k,) + v = (v,) + end + for (kk, vv, ii) in zip(k, v, idx) + sym_to_idx[kk] = ii + idx_to_sym[ii] = kk + idx_to_val[ii] = vv + end end for sym in vsyms haskey(sym_to_idx, sym) && continue @@ -586,9 +593,16 @@ function fill_p(prob, p; defs = nothing, use_defaults = false) for (k, v) in p idx = parameter_index(prob, k) idx === nothing && continue - sym_to_idx[k] = idx - idx_to_sym[idx] = k - idx_to_val[idx] = v + if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() + idx = (idx,) + k = (k,) + v = (v,) + end + for (kk, vv, ii) in zip(k, v, idx) + sym_to_idx[kk] = ii + idx_to_sym[ii] = kk + idx_to_val[ii] = vv + end end for sym in psyms haskey(sym_to_idx, sym) && continue diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 184fd3d50..eec0b6baf 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -191,3 +191,31 @@ sol = solve(prob, BFGS()) @test prob2.ps[P] == sign * 2.0 end end + +@testset "remake with Vector{Int} as index of array variable/parameter" begin + @parameters k[1:4] + @variables (V(t))[1:2] + function rhs!(du, u, p, t) + du[1] = p[1] - p[2] * u[1] + du[2] = p[3] - p[4] * u[2] + nothing + end + sys = SymbolCache(Dict(V => 1:2, V[1] => 1, V[2] => 2), + Dict(k => 1:4, k[1] => 1, k[2] => 2, k[3] => 3, k[4] => 4), t) + struct SCWrapper{S} + sys::S + end + SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sys + SymbolicIndexingInterface.variable_symbols(s::SCWrapper) = filter( + x -> symbolic_type(x) != ArraySymbolic(), variable_symbols(s.sys)) + SymbolicIndexingInterface.parameter_symbols(s::SCWrapper) = filter( + x -> symbolic_type(x) != ArraySymbolic(), parameter_symbols(s.sys)) + sys = SCWrapper(sys) + fn = ODEFunction(rhs!; sys) + oprob_scal_scal = ODEProblem(fn, [10.0, 20.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0]) + ps_vec = [k => [2.0, 3.0, 4.0, 5.0]] + u0_vec = [V => [1.5, 2.5]] + newoprob = remake(oprob_scal_scal; u0 = u0_vec, p = ps_vec) + @test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0] + @test newoprob[V] == [1.5, 2.5] +end