diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index dc1d29e9a..012d0ef93 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -146,6 +146,18 @@ function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense dense, tslocation, stats, alg_choice, retcode, resid, original) end +error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing +function error_if_observed_derivative(sys, idx, ::Type) + if symbolic_type(idx) != NotSymbolic() && is_observed(sys, idx) || + symbolic_type(idx) == NotSymbolic() && any(x -> is_observed(sys, x), idx) + error(""" + Cannot interpolate derivatives of observed variables. A possible solution could be + interpolating the symbolic expression that evaluates to the derivative of the + observed variable or using DataInterpolations.jl. + """) + end +end + function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} sol(t, deriv, idxs, continuity) @@ -197,6 +209,7 @@ end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") + error_if_observed_derivative(sol, idxs, deriv) if is_parameter(sol, idxs) return getp(sol, idxs)(sol) else @@ -210,6 +223,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect any(isequal(NotSymbolic()), symbolic_type.(idxs)) error("Incorrect specification of `idxs`") end + error_if_observed_derivative(sol, idxs, deriv) interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol) first(interp_sol[idxs]) end @@ -217,6 +231,7 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") + error_if_observed_derivative(sol, idxs, deriv) if is_parameter(sol, idxs) return getp(sol, idxs)(sol) else @@ -230,6 +245,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector, continuity) where {deriv} all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") + error_if_observed_derivative(sol, idxs, deriv) interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing indexed_sol = interp_sol[idxs] diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index f30922a62..ca4823b55 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -371,3 +371,18 @@ sol = solve(prob, Tsit5()) @test sol.ps[a] ≈ 1 @test sol.ps[b] ≈ 100 end + +# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2697 +@testset "Interpolation of derivative of observed variables" begin + @variables x(t) y(t) z(t) w(t)[1:2] + @named sys = ODESystem( + [D(x) ~ 1, y ~ x^2, z ~ 2y^2 + 3x, w[1] ~ x + y + z, w[2] ~ z * x * y], t) + sys = structural_simplify(sys) + prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0)) + sol = solve(prob, Tsit5()) + @test_throws ErrorException sol(1.0, Val{1}, idxs = y) + @test_throws ErrorException sol(1.0, Val{1}, idxs = [y, z]) + @test_throws ErrorException sol(1.0, Val{1}, idxs = w) + @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w]) + @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y]) +end