From 9ea5b75249f2bc93101bbeb23cd91dcf2e304da7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 2 Jan 2024 14:38:17 +0530 Subject: [PATCH] fix: interpolation using symbolic idxs for RODESolution --- src/solutions/rode_solutions.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index b6ec13358..14ca2f200 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -63,6 +63,22 @@ TruncatedStacktraces.@truncate_stacktrace RODESolution 1 2 function (sol::RODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} + if idxs !== nothing + if !(idxs isa Union{<:AbstractArray, <:Tuple}) + idxs = [idxs] + end + idxs = map(idxs) do idx + if symbolic_type(idx) === NotSymbolic() + return idx + elseif symbolic_type(idx) === ScalarSymbolic() + return variable_index(sol, idx) + else + return variable_index.((sol,), collect(idx)) + end + end + any(i === nothing for i in idxs) && error("All idxs must be variables") + idxs = reduce(vcat, idxs) + end sol.interp(t, idxs, deriv, sol.prob.p, continuity) end function (sol::RODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing,