Skip to content

Commit

Permalink
fix: fix interpolation of array symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 6, 2024
1 parent a0fab7a commit 289d9e2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,10 @@ end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
if symbolic_type(idxs) == NotSymbolic() &&
any(isequal(NotSymbolic()), symbolic_type.(idxs))
error("Incorrect specification of `idxs`")
end
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
first(interp_sol[idxs])
end
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ sol = solve(prob, Tsit5())
@test sol[x] isa Vector{<:Vector}
@test sol[@nonamespace sys.x] isa Vector{<:Vector}
@test sol.ps[p] == [1, 2, 3]
# interpolation of array variables
@test sol(1.0, idxs = x) == [sol(1.0, idxs = x[i]) for i in 1:3]

x_idx = variable_index.((sys,), [x[1], x[2], x[3]])
y_idx = variable_index(sys, y)
Expand Down

0 comments on commit 289d9e2

Please sign in to comment.