Skip to content

Commit

Permalink
refactor: update DiffEqArray to use new constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 27, 2023
1 parent 6785920 commit bac2ca4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 54 deletions.
47 changes: 4 additions & 43 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,39 +92,15 @@ end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::Integer, continuity) where {deriv}
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if has_sys(sol.prob.f)
DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed),
typeof(p)}(A.u,
A.t,
sol.prob.f.sys,
observed,
p)
else
syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ?
[sol.prob.f.syms[idxs]] : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
end
return DiffEqArray(A.u, A.t, p, sol)
end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector{<:Integer},
continuity) where {deriv}
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if has_sys(sol.prob.f)
DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed),
typeof(p)}(A.u,
A.t,
sol.prob.f.sys,
observed,
p)
else
syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ?
sol.prob.f.syms[idxs] : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
end
return DiffEqArray(A.u, A.t, p, sol)
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
Expand All @@ -144,31 +120,16 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if has_sys(sol.prob.f)
return DiffEqArray(interp_sol[idxs], t, [idxs],
independent_variables(sol.prob.f.sys), observed, p)
else
return DiffEqArray(interp_sol[idxs], t, [idxs], getindepsym(sol), observed, p)
end
return DiffEqArray(interp_sol[idxs], t, p, sol)
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector, continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if has_sys(sol.prob.f)
return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t,
idxs,
independent_variables(sol.prob.f.sys), observed, p)
else
return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t,
idxs,
getindepsym(sol), observed, p)
end
return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol)
end

function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
Expand Down
12 changes: 1 addition & 11 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,8 @@ end

# For augmenting system information to enable symbol based indexing of interpolated solutions
function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution) where {T, N, Q, B}
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if has_sys(sol.prob.f)
DiffEqArray{T, N, Q, B, typeof(sol.prob.f.sys), typeof(observed), typeof(p)}(A.u,
A.t,
sol.prob.f.sys,
observed,
p)
else
syms = hasproperty(sol.prob.f, :syms) ? sol.prob.f.syms : nothing
DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p)
end
return DiffEqArray(A.u, A.t, p, sol)
end

# SymbolicIndexingInterface.jl
Expand Down

0 comments on commit bac2ca4

Please sign in to comment.