Skip to content

Commit

Permalink
refactor: fix solution interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 28, 2023
1 parent 2705271 commit a6d38fa
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,30 @@ end
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)

Check warning on line 183 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
else
return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]

Check warning on line 185 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L185

Added line #L185 was not covered by tests
end
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
[first(interp_sol[idx]) for idx in idxs]
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]

Check warning on line 193 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L193

Added line #L193 was not covered by tests
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)

Check warning on line 200 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L199-L200

Added lines #L199 - L200 were not covered by tests
else
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)

Check warning on line 204 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L202-L204

Added lines #L202 - L204 were not covered by tests
end
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
Expand Down

0 comments on commit a6d38fa

Please sign in to comment.