From 72ceb8289f06d285fdf97d66908284bc6a30943b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:47:02 +0530 Subject: [PATCH] fixup! feat: add parameter timeseries support to `AbstractDiffEqArray` --- src/vector_of_array.jl | 105 +++++------------------------------------ 1 file changed, 13 insertions(+), 92 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index f95bcc9a..9c80dd4f 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -61,7 +61,7 @@ A[1, :] # all time periods for f(t) A.t ``` """ -mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, AbstractDiffEqArray}} <: +mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <: AbstractDiffEqArray{T, N, A} u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}} t::B @@ -264,73 +264,15 @@ get_discretes(x) = getfield(x, :discretes) SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries() function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B, - F, S, D}}) where {T, N, A, B, F, S, D <: AbstractDiffEqArray} + F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy} Timeseries() end SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys - -_parameter_timeseries(::Timeseries, A::AbstractDiffEqArray) = get_discretes(A).t -_parameter_timeseries(::NotTimeseries, A::AbstractDiffEqArray) = [0] - -function SymbolicIndexingInterface.parameter_timeseries(A::AbstractDiffEqArray) - _parameter_timeseries(is_parameter_timeseries(A), A) -end - -function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray, i) - ps = parameter_values(A) - discretes = get_discretes(A) - return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[i]) -end -function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray) - ps = parameter_values(A) - discretes = get_discretes(A) - return SciMLStructures.replace.((SciMLStructures.Discrete(),), (ps,), discretes.u) -end -_parameter_values_at_time(::NotTimeseries, A::DiffEqArray, _...) = parameter_values(A) - -function SymbolicIndexingInterface.parameter_values_at_time(A::DiffEqArray, i...) - return _parameter_values_at_time(is_parameter_timeseries(A), A, i...) -end - -function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray, i) - ps = parameter_values(A) - discretes = get_discretes(A) - t = A.t[i] - idx = searchsortedfirst(discretes.t, t; lt = <=) - if idx == firstindex(discretes.t) - error("This should never happen: there is no discrete parameter value before the current time") - end - return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[idx - 1]) - -end -function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray) - ps = parameter_values(A) - discretes = get_discretes(A) - ps_arr = typeof(ps)[] - sizehint!(ps_arr, length(A.t)) - - if first(A.t) < first(discretes.t) - error("This should never happen: there is no discrete parameter value before the current time") - end - - A_idx = firstindex(A.t) - while checkbounds(Bool, A.t, A_idx) - disc_idx = searchsortedfirst(discretes.t, A.t[A_idx]; lt = <=) - newps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[disc_idx - 1]) - while checkbounds(Bool, A.t, A_idx) && A.t[A_idx] <= discretes.t[disc_idx] - push!(ps_arr, newps) - A_idx += 1 - end - end - return ps_arr -end - -_parameter_values_at_state_time(::NotTimeseries, A::AbstractDiffEqArray, _...) = parameter_values(A) -function SymbolicIndexingInterface.parameter_values_at_state_time(A::AbstractDiffEqArray, i...) - return _parameter_values_at_time(is_parameter_timeseries(A), A, i...) +function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray) + return get_discretes(A) end Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A)) @@ -440,39 +382,18 @@ end # Symbolic Indexing Methods for (symtype, elsymtype, valtype, errcheck) in [ - (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), - (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))), + (ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))), + (ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))), (NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray}, - :(all(x -> is_parameter(A, x), sym))), + :(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))), ] -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym)(A) -end -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, arg) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym)(A, arg) -end -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}}) - if $errcheck - throw(ParameterIndexingError(sym)) - end - getu(A, sym).((A,), arg) -end -@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, - ::$elsymtype, sym::$valtype, ::Colon) - if $errcheck - throw(ParameterIndexingError(sym)) + @eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, + ::$elsymtype, sym::$valtype, arg...) + if $errcheck + throw(ParameterIndexingError(sym)) + end + getu(A, sym)(A, arg...) end - getu(A, sym)(A) -end end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,