Skip to content

Commit

Permalink
fixup! feat: add parameter timeseries support to AbstractDiffEqArray
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 21, 2024
1 parent 6382fab commit 72ceb82
Showing 1 changed file with 13 additions and 92 deletions.
105 changes: 13 additions & 92 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ A[1, :] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, AbstractDiffEqArray}} <:
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
Expand Down Expand Up @@ -264,73 +264,15 @@ get_discretes(x) = getfield(x, :discretes)

Check warning on line 264 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L264

Added line #L264 was not covered by tests
SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries()
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B,
F, S, D}}) where {T, N, A, B, F, S, D <: AbstractDiffEqArray}
F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy}
Timeseries()

Check warning on line 268 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L267-L268

Added lines #L267 - L268 were not covered by tests
end
SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u
SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t
SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p
SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys

_parameter_timeseries(::Timeseries, A::AbstractDiffEqArray) = get_discretes(A).t
_parameter_timeseries(::NotTimeseries, A::AbstractDiffEqArray) = [0]

function SymbolicIndexingInterface.parameter_timeseries(A::AbstractDiffEqArray)
_parameter_timeseries(is_parameter_timeseries(A), A)
end

function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray, i)
ps = parameter_values(A)
discretes = get_discretes(A)
return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[i])
end
function _parameter_values_at_time(::Timeseries, A::AbstractDiffEqArray)
ps = parameter_values(A)
discretes = get_discretes(A)
return SciMLStructures.replace.((SciMLStructures.Discrete(),), (ps,), discretes.u)
end
_parameter_values_at_time(::NotTimeseries, A::DiffEqArray, _...) = parameter_values(A)

function SymbolicIndexingInterface.parameter_values_at_time(A::DiffEqArray, i...)
return _parameter_values_at_time(is_parameter_timeseries(A), A, i...)
end

function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray, i)
ps = parameter_values(A)
discretes = get_discretes(A)
t = A.t[i]
idx = searchsortedfirst(discretes.t, t; lt = <=)
if idx == firstindex(discretes.t)
error("This should never happen: there is no discrete parameter value before the current time")
end
return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[idx - 1])

end
function _parameter_values_at_state_time(::Timeseries, A::AbstractDiffEqArray)
ps = parameter_values(A)
discretes = get_discretes(A)
ps_arr = typeof(ps)[]
sizehint!(ps_arr, length(A.t))

if first(A.t) < first(discretes.t)
error("This should never happen: there is no discrete parameter value before the current time")
end

A_idx = firstindex(A.t)
while checkbounds(Bool, A.t, A_idx)
disc_idx = searchsortedfirst(discretes.t, A.t[A_idx]; lt = <=)
newps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[disc_idx - 1])
while checkbounds(Bool, A.t, A_idx) && A.t[A_idx] <= discretes.t[disc_idx]
push!(ps_arr, newps)
A_idx += 1
end
end
return ps_arr
end

_parameter_values_at_state_time(::NotTimeseries, A::AbstractDiffEqArray, _...) = parameter_values(A)
function SymbolicIndexingInterface.parameter_values_at_state_time(A::AbstractDiffEqArray, i...)
return _parameter_values_at_time(is_parameter_timeseries(A), A, i...)
function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray)
return get_discretes(A)
end

Check warning on line 276 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L274-L276

Added lines #L274 - L276 were not covered by tests

Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A))
Expand Down Expand Up @@ -440,39 +382,18 @@ end

# Symbolic Indexing Methods
for (symtype, elsymtype, valtype, errcheck) in [
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x), sym))),
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype)
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg)
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A, arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym).((A,), arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, ::Colon)
if $errcheck
throw(ParameterIndexingError(sym))
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,

Check warning on line 390 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L390

Added line #L390 was not covered by tests
::$elsymtype, sym::$valtype, arg...)
if $errcheck

Check warning on line 392 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L392

Added line #L392 was not covered by tests
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A, arg...)

Check warning on line 395 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L394-L395

Added lines #L394 - L395 were not covered by tests
end
getu(A, sym)(A)
end
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
Expand Down

0 comments on commit 72ceb82

Please sign in to comment.