Skip to content

Commit

Permalink
feat: add parameter timeseries support to AbstractDiffEqArray
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 3, 2024
1 parent eacbe3f commit 6382fab
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 15 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -22,9 +23,9 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
Expand Down Expand Up @@ -54,6 +55,7 @@ Random = "1"
RecipesBase = "1.1"
ReverseDiff = "1.15"
SafeTestsets = "0.1"
SciMLStructures = "1.1"
SparseArrays = "1.10"
StaticArrays = "1.6"
StaticArraysCore = "1.4"
Expand Down
107 changes: 93 additions & 14 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ A[1, :] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A}
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, AbstractDiffEqArray}} <:
AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
p::F
sys::S
discretes::D
end
### Abstract Interface
struct AllObserved
Expand Down Expand Up @@ -172,29 +174,32 @@ function DiffEqArray(vec::AbstractVector{T},
ts::AbstractVector,
::NTuple{N, Int},
p = nothing,
sys = nothing) where {T, N}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec,
sys = nothing; discretes = nothing) where {T, N}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,

Check warning on line 178 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L178

Added line #L178 was not covered by tests
ts,
p,
sys)
sys,
discretes)
end

# ambiguity resolution
function DiffEqArray(vec::AbstractVector{VT},
ts::AbstractVector,
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec,

Check warning on line 189 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L189

Added line #L189 was not covered by tests
ts,
nothing,
nothing,
nothing)
end
function DiffEqArray(vec::AbstractVector{VT},
ts::AbstractVector,
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,

Check warning on line 198 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L198

Added line #L198 was not covered by tests
ts,
p,
nothing)
nothing,
discretes)
end
# Assume that the first element is representative of all other elements

Expand All @@ -204,7 +209,8 @@ function DiffEqArray(vec::AbstractVector,
sys = nothing;
variables = nothing,
parameters = nothing,
independent_variables = nothing)
independent_variables = nothing,
discretes = nothing)
sys = something(sys,
SymbolCache(something(variables, []),
something(parameters, []),
Expand All @@ -217,11 +223,13 @@ function DiffEqArray(vec::AbstractVector,
typeof(vec),
typeof(ts),
typeof(p),
typeof(sys)
typeof(sys),
typeof(discretes)
}(vec,
ts,
p,
sys)
sys,
discretes)
end

function DiffEqArray(vec::AbstractVector{VT},
Expand All @@ -230,7 +238,8 @@ function DiffEqArray(vec::AbstractVector{VT},
sys = nothing;
variables = nothing,
parameters = nothing,
independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
independent_variables = nothing,
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
sys = something(sys,
SymbolCache(something(variables, []),
something(parameters, []),
Expand All @@ -241,19 +250,89 @@ function DiffEqArray(vec::AbstractVector{VT},
typeof(vec),
typeof(ts),
typeof(p),
typeof(sys)
typeof(sys),
typeof(discretes),
}(vec,
ts,
p,
sys)
sys,
discretes)
end

has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes)
get_discretes(x) = getfield(x, :discretes)

Check warning on line 263 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L262-L263

Added lines #L262 - L263 were not covered by tests

SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries()
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B,

Check warning on line 266 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L266

Added line #L266 was not covered by tests
F, S, D}}) where {T, N, A, B, F, S, D <: AbstractDiffEqArray}
Timeseries()

Check warning on line 268 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L268

Added line #L268 was not covered by tests
end
SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u
SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t
SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p
SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys

_parameter_timeseries(::Timeseries, A::AbstractDiffEqArray) = get_discretes(A).t
_parameter_timeseries(::NotTimeseries, A::AbstractDiffEqArray) = [0]

Check warning on line 276 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L275-L276

Added lines #L275 - L276 were not covered by tests

function SymbolicIndexingInterface.parameter_timeseries(A::AbstractDiffEqArray)
_parameter_timeseries(is_parameter_timeseries(A), A)

Check warning on line 279 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L278-L279

Added lines #L278 - L279 were not covered by tests
end

function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray, i)
ps = parameter_values(A)
discretes = get_discretes(A)
return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[i])

Check warning on line 285 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L282-L285

Added lines #L282 - L285 were not covered by tests
end
function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray)
ps = parameter_values(A)
discretes = get_discretes(A)
return SciMLStructures.replace.((SciMLStructures.Discrete(),), (ps,), discretes.u)

Check warning on line 290 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L287-L290

Added lines #L287 - L290 were not covered by tests
end
_parameter_values_at_time(::NotTimeseries, A::DiffEqArray, _...) = parameter_values(A)

Check warning on line 292 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L292

Added line #L292 was not covered by tests

function SymbolicIndexingInterface.parameter_values_at_time(A::DiffEqArray, i...)
return _parameter_values_at_time(is_parameter_timeseries(A), A, i...)

Check warning on line 295 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L294-L295

Added lines #L294 - L295 were not covered by tests
end

function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray, i)
ps = parameter_values(A)
discretes = get_discretes(A)
t = A.t[i]
idx = searchsortedfirst(discretes.t, t; lt = <=)
if idx == firstindex(discretes.t)
error("This should never happen: there is no discrete parameter value before the current time")

Check warning on line 304 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L298-L304

Added lines #L298 - L304 were not covered by tests
end
return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[idx - 1])

Check warning on line 306 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L306

Added line #L306 was not covered by tests

end
function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray)
ps = parameter_values(A)
discretes = get_discretes(A)
ps_arr = typeof(ps)[]
sizehint!(ps_arr, length(A.t))

Check warning on line 313 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L309-L313

Added lines #L309 - L313 were not covered by tests

if first(A.t) < first(discretes.t)
error("This should never happen: there is no discrete parameter value before the current time")

Check warning on line 316 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L315-L316

Added lines #L315 - L316 were not covered by tests
end

A_idx = firstindex(A.t)
while checkbounds(Bool, A.t, A_idx)
disc_idx = searchsortedfirst(discretes.t, A.t[A_idx]; lt = <=)
newps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[disc_idx - 1])
while checkbounds(Bool, A.t, A_idx) && A.t[A_idx] <= discretes.t[disc_idx]
push!(ps_arr, newps)
A_idx += 1
end
end
return ps_arr

Check warning on line 328 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L319-L328

Added lines #L319 - L328 were not covered by tests
end

_parameter_values_at_state_time(::NotTimeseries, A::AbstractDiffEqArray, _...) = parameter_values(A)
function SymbolicIndexingInterface.parameter_values_at_state_time(A::AbstractDiffEqArray, i...)
return _parameter_values_at_time(is_parameter_timeseries(A), A, i...)

Check warning on line 333 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L331-L333

Added lines #L331 - L333 were not covered by tests
end

Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A))
Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()

Expand Down

0 comments on commit 6382fab

Please sign in to comment.