Skip to content

Commit

Permalink
fix: error when interpolating derivatives of observed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 6, 2024
1 parent 289d9e2 commit 5860b14
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense
dense, tslocation, stats, alg_choice, retcode, resid, original)
end

error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
function error_if_observed_derivative(sys, idx, ::Type)
if symbolic_type(idx) != NotSymbolic() && is_observed(sys, idx) ||
symbolic_type(idx) == NotSymbolic() && any(x -> is_observed(sys, x), idx)
error("""

Check warning on line 153 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L149-L153

Added lines #L149 - L153 were not covered by tests
Cannot interpolate derivatives of observed variables. A possible solution could be
interpolating the symbolic expression that evaluates to the derivative of the
observed variable or using DataInterpolations.jl.
""")
end
end

function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
continuity = :left) where {deriv}
sol(t, deriv, idxs, continuity)
Expand Down Expand Up @@ -197,6 +209,7 @@ end
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 212 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L212

Added line #L212 was not covered by tests
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
Expand All @@ -210,13 +223,15 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
any(isequal(NotSymbolic()), symbolic_type.(idxs))
error("Incorrect specification of `idxs`")
end
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 226 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L226

Added line #L226 was not covered by tests
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
first(interp_sol[idxs])
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 234 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L234

Added line #L234 was not covered by tests
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
else
Expand All @@ -230,6 +245,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector, continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)

Check warning on line 248 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L248

Added line #L248 was not covered by tests
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
indexed_sol = interp_sol[idxs]
Expand Down
15 changes: 15 additions & 0 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,18 @@ sol = solve(prob, Tsit5())
@test sol.ps[a] 1
@test sol.ps[b] 100
end

# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2697
@testset "Interpolation of derivative of observed variables" begin
@variables x(t) y(t) z(t) w(t)[1:2]
@named sys = ODESystem(
[D(x) ~ 1, y ~ x^2, z ~ 2y^2 + 3x, w[1] ~ x + y + z, w[2] ~ z * x * y], t)
sys = structural_simplify(sys)
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0))
sol = solve(prob, Tsit5())
@test_throws ErrorException sol(1.0, Val{1}, idxs = y)
@test_throws ErrorException sol(1.0, Val{1}, idxs = [y, z])
@test_throws ErrorException sol(1.0, Val{1}, idxs = w)
@test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w])
@test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y])
end

0 comments on commit 5860b14

Please sign in to comment.