From bac2ca45dcbd15f1d3ca6dbfba4ca8923c445d7a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Oct 2023 13:25:53 +0530 Subject: [PATCH] refactor: update DiffEqArray to use new constructors --- src/solutions/ode_solutions.jl | 47 +++-------------------------- src/solutions/solution_interface.jl | 12 +------- 2 files changed, 5 insertions(+), 54 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index de46944b7e..8a763f62b8 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -92,39 +92,15 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Integer, continuity) where {deriv} A = sol.interp(t, idxs, deriv, sol.prob.p, continuity) - observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if has_sys(sol.prob.f) - DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed), - typeof(p)}(A.u, - A.t, - sol.prob.f.sys, - observed, - p) - else - syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ? - [sol.prob.f.syms[idxs]] : nothing - DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) - end + return DiffEqArray(A.u, A.t, p, sol) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} A = sol.interp(t, idxs, deriv, sol.prob.p, continuity) - observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if has_sys(sol.prob.f) - DiffEqArray{typeof(A).parameters[1:4]..., typeof(sol.prob.f.sys), typeof(observed), - typeof(p)}(A.u, - A.t, - sol.prob.f.sys, - observed, - p) - else - syms = hasproperty(sol.prob.f, :syms) && sol.prob.f.syms !== nothing ? - sol.prob.f.syms[idxs] : nothing - DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) - end + return DiffEqArray(A.u, A.t, p, sol) end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, @@ -144,31 +120,16 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) - observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if has_sys(sol.prob.f) - return DiffEqArray(interp_sol[idxs], t, [idxs], - independent_variables(sol.prob.f.sys), observed, p) - else - return DiffEqArray(interp_sol[idxs], t, [idxs], getindepsym(sol), observed, p) - end + return DiffEqArray(interp_sol[idxs], t, p, sol) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector, continuity) where {deriv} all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) - observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if has_sys(sol.prob.f) - return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, - idxs, - independent_variables(sol.prob.f.sys), observed, p) - else - return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, - idxs, - getindepsym(sol), observed, p) - end + return DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol) end function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 8abc37be2d..afe4038cfc 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -19,18 +19,8 @@ end # For augmenting system information to enable symbol based indexing of interpolated solutions function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution) where {T, N, Q, B} - observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if has_sys(sol.prob.f) - DiffEqArray{T, N, Q, B, typeof(sol.prob.f.sys), typeof(observed), typeof(p)}(A.u, - A.t, - sol.prob.f.sys, - observed, - p) - else - syms = hasproperty(sol.prob.f, :syms) ? sol.prob.f.syms : nothing - DiffEqArray(A.u, A.t, syms, getindepsym(sol), observed, p) - end + return DiffEqArray(A.u, A.t, p, sol) end # SymbolicIndexingInterface.jl