diff --git a/Project.toml b/Project.toml index dcd25c4a..b8b9127b 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ StaticArrays = "1.6" StaticArraysCore = "1.4" Statistics = "1.10" StructArrays = "0.6.11" -SymbolicIndexingInterface = "0.3.23" +SymbolicIndexingInterface = "0.3.20" Tables = "1.11" Test = "1" Tracker = "0.2.15" diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 36f95ed2..2c13a573 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -60,13 +60,11 @@ A[1, :] # all time periods for f(t) A.t ``` """ -mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <: - AbstractDiffEqArray{T, N, A} +mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A} u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}} t::B p::F sys::S - discretes::D end ### Abstract Interface struct AllObserved @@ -176,32 +174,29 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p = nothing, - sys = nothing; discretes = nothing) where {T, N} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, + sys = nothing) where {T, N} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec, ts, p, - sys, - discretes) + sys) end # ambiguity resolution function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec, + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec, ts, nothing, - nothing, nothing) end function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, - ::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec, + ::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec, ts, p, - nothing, - discretes) + nothing) end # Assume that the first element is representative of all other elements @@ -211,8 +206,7 @@ function DiffEqArray(vec::AbstractVector, sys = nothing; variables = nothing, parameters = nothing, - independent_variables = nothing, - discretes = nothing) + independent_variables = nothing) sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -225,13 +219,11 @@ function DiffEqArray(vec::AbstractVector, typeof(vec), typeof(ts), typeof(p), - typeof(sys), - typeof(discretes) + typeof(sys) }(vec, ts, p, - sys, - discretes) + sys) end function DiffEqArray(vec::AbstractVector{VT}, @@ -240,8 +232,7 @@ function DiffEqArray(vec::AbstractVector{VT}, sys = nothing; variables = nothing, parameters = nothing, - independent_variables = nothing, - discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} + independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -252,30 +243,18 @@ function DiffEqArray(vec::AbstractVector{VT}, typeof(vec), typeof(ts), typeof(p), - typeof(sys), - typeof(discretes), + typeof(sys) }(vec, ts, p, - sys, - discretes) + sys) end -has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes) -get_discretes(x) = getfield(x, :discretes) - SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries() -function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B, - F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy} - Timeseries() -end SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys -function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray) - return get_discretes(A) -end Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A)) Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian() @@ -384,18 +363,39 @@ end # Symbolic Indexing Methods for (symtype, elsymtype, valtype, errcheck) in [ - (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))), - (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))), + (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), + (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), (NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray}, - :(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))), + :(all(x -> is_parameter(A, x), sym))), ] - @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, arg...) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym)(A, arg...) +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype) + if $errcheck + throw(ParameterIndexingError(sym)) end + getu(A, sym)(A) +end +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, arg) + if $errcheck + throw(ParameterIndexingError(sym)) + end + getu(A, sym)(A, arg) +end +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}}) + if $errcheck + throw(ParameterIndexingError(sym)) + end + getu(A, sym).((A,), arg) +end +@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, ::Colon) + if $errcheck + throw(ParameterIndexingError(sym)) + end + getu(A, sym)(A) +end end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,