From 6382fab3d83a0b126b44765dd39e255474115d7c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 3 May 2024 20:12:45 +0530 Subject: [PATCH] feat: add parameter timeseries support to `AbstractDiffEqArray` --- Project.toml | 4 +- src/vector_of_array.jl | 107 +++++++++++++++++++++++++++++++++++------ 2 files changed, 96 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index d9b98521..bac22b7a 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -22,9 +23,9 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -54,6 +55,7 @@ Random = "1" RecipesBase = "1.1" ReverseDiff = "1.15" SafeTestsets = "0.1" +SciMLStructures = "1.1" SparseArrays = "1.10" StaticArrays = "1.6" StaticArraysCore = "1.4" diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 7a402c0f..f95bcc9a 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -61,11 +61,13 @@ A[1, :] # all time periods for f(t) A.t ``` """ -mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A} +mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, AbstractDiffEqArray}} <: + AbstractDiffEqArray{T, N, A} u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}} t::B p::F sys::S + discretes::D end ### Abstract Interface struct AllObserved @@ -172,29 +174,32 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p = nothing, - sys = nothing) where {T, N} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec, + sys = nothing; discretes = nothing) where {T, N} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, ts, p, - sys) + sys, + discretes) end # ambiguity resolution function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec, + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec, ts, nothing, + nothing, nothing) end function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, - ::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec, + ::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec, ts, p, - nothing) + nothing, + discretes) end # Assume that the first element is representative of all other elements @@ -204,7 +209,8 @@ function DiffEqArray(vec::AbstractVector, sys = nothing; variables = nothing, parameters = nothing, - independent_variables = nothing) + independent_variables = nothing, + discretes = nothing) sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -217,11 +223,13 @@ function DiffEqArray(vec::AbstractVector, typeof(vec), typeof(ts), typeof(p), - typeof(sys) + typeof(sys), + typeof(discretes) }(vec, ts, p, - sys) + sys, + discretes) end function DiffEqArray(vec::AbstractVector{VT}, @@ -230,7 +238,8 @@ function DiffEqArray(vec::AbstractVector{VT}, sys = nothing; variables = nothing, parameters = nothing, - independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} + independent_variables = nothing, + discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} sys = something(sys, SymbolCache(something(variables, []), something(parameters, []), @@ -241,19 +250,89 @@ function DiffEqArray(vec::AbstractVector{VT}, typeof(vec), typeof(ts), typeof(p), - typeof(sys) + typeof(sys), + typeof(discretes), }(vec, ts, p, - sys) + sys, + discretes) end +has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes) +get_discretes(x) = getfield(x, :discretes) + SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries() +function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B, + F, S, D}}) where {T, N, A, B, F, S, D <: AbstractDiffEqArray} + Timeseries() +end SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys +_parameter_timeseries(::Timeseries, A::AbstractDiffEqArray) = get_discretes(A).t +_parameter_timeseries(::NotTimeseries, A::AbstractDiffEqArray) = [0] + +function SymbolicIndexingInterface.parameter_timeseries(A::AbstractDiffEqArray) + _parameter_timeseries(is_parameter_timeseries(A), A) +end + +function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray, i) + ps = parameter_values(A) + discretes = get_discretes(A) + return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[i]) +end +function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray) + ps = parameter_values(A) + discretes = get_discretes(A) + return SciMLStructures.replace.((SciMLStructures.Discrete(),), (ps,), discretes.u) +end +_parameter_values_at_time(::NotTimeseries, A::DiffEqArray, _...) = parameter_values(A) + +function SymbolicIndexingInterface.parameter_values_at_time(A::DiffEqArray, i...) + return _parameter_values_at_time(is_parameter_timeseries(A), A, i...) +end + +function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray, i) + ps = parameter_values(A) + discretes = get_discretes(A) + t = A.t[i] + idx = searchsortedfirst(discretes.t, t; lt = <=) + if idx == firstindex(discretes.t) + error("This should never happen: there is no discrete parameter value before the current time") + end + return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[idx - 1]) + +end +function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray) + ps = parameter_values(A) + discretes = get_discretes(A) + ps_arr = typeof(ps)[] + sizehint!(ps_arr, length(A.t)) + + if first(A.t) < first(discretes.t) + error("This should never happen: there is no discrete parameter value before the current time") + end + + A_idx = firstindex(A.t) + while checkbounds(Bool, A.t, A_idx) + disc_idx = searchsortedfirst(discretes.t, A.t[A_idx]; lt = <=) + newps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[disc_idx - 1]) + while checkbounds(Bool, A.t, A_idx) && A.t[A_idx] <= discretes.t[disc_idx] + push!(ps_arr, newps) + A_idx += 1 + end + end + return ps_arr +end + +_parameter_values_at_state_time(::NotTimeseries, A::AbstractDiffEqArray, _...) = parameter_values(A) +function SymbolicIndexingInterface.parameter_values_at_state_time(A::AbstractDiffEqArray, i...) + return _parameter_values_at_time(is_parameter_timeseries(A), A, i...) +end + Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A)) Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()