From d4e7d2ebcb9546b7d9a818daa7ce8134c9b1ec57 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 6 Mar 2024 18:34:53 +0530 Subject: [PATCH 1/5] feat: initial idea for discrete indexing --- src/SymbolicIndexingInterface.jl | 8 +-- src/parameter_indexing.jl | 117 +++++++++++++++++++++++++++---- src/state_indexing.jl | 55 +++------------ src/trait.jl | 42 +++++++++++ 4 files changed, 159 insertions(+), 63 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 4edcb21..34d57fd 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -1,6 +1,7 @@ module SymbolicIndexingInterface -export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname +export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname, + Timeseries, NotTimeseries, is_timeseries include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, @@ -14,11 +15,10 @@ include("interface.jl") export SymbolCache include("symbol_cache.jl") -export parameter_values, set_parameter!, getp, setp +export parameter_values, set_parameter!, parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp, setp include("parameter_indexing.jl") -export Timeseries, - NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu +export state_values, set_state!, current_time, getu, setu include("state_indexing.jl") export ParameterIndexingProxy diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index bc4b660..d214fa3 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -15,6 +15,56 @@ parameter_values(arr::AbstractArray) = arr parameter_values(arr::AbstractArray, i) = arr[i] parameter_values(prob, i) = parameter_values(parameter_values(prob), i) +""" + parameter_values_at_time(p, i) + +Return an indexable collection containing the value of all parameters in `p` at time index +`i`. This is useful when parameter values change during the simulation +(such as through callbacks) and their values are saved. `i` is the time index in the +timeseries formed by these changing parameter values, obtained using +[`parameter_timeseries`](@ref). + +By default, this function returns `parameter_values(p)` regardless of `i`, and only needs +to be specialized for timeseries objects where parameter values are not constant at all +times. The resultant object should be indexable using [`parameter_values`](@ref). + +If this function is implemented, [`parameter_values_at_state_time`](@ref) must be +implemented for [`getu`](@ref) to work correctly. +""" +function parameter_values_at_time end +parameter_values_at_time(p, i) = parameter_values(p) + +""" + parameter_values_at_state_time(p, i) + +Return an indexable collection containing the value of all parameters in `p` at time +index `i`. This is useful when parameter values change during the simulation (such as +through callbacks) and their values are saved. `i` is the time index in the timeseries +formed by dependent variables (as opposed to the timeseries of the parameters, as in +[`parameter_values_at_time`](@ref)). + +By default, this function returns `parameter_values(p)` regardless of `i`, and only needs +to be specialized for timeseries objects where parameter values are not constant at +all times. The resultant object should be indexable using [`parameter_values`](@ref). + +If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for +[`getp`](@ref) to work correctly. +""" +function parameter_values_at_state_time end +parameter_values_at_state_time(p, i) = parameter_values(p) + +""" + parameter_timeseries(p) + +Return an iterable of time steps at which the parameter values are saved. This is only +required for objects where `is_timeseries(p) === Timeseries()` and the parameter values +change during the simulation (such as through callbacks). By default, this returns `[0]`. + +See also: [`parameter_values_at_time`](@ref). +""" +function parameter_timeseries end +parameter_timeseries(_) = [0] + """ set_parameter!(sys, val, idx) @@ -55,18 +105,30 @@ function getp(sys, p) end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) - return function getter(sol) - return parameter_values(sol, p) + _getter = let p = p + function _getter(::NotTimeseries, prob) + parameter_values(prob, p) + end + function _getter(::Timeseries, prob) + parameter_values(prob, p) + end + function _getter(::Timeseries, prob, i) + parameter_values(parameter_values_at_time(prob, i), p) + end + function _getter(::Timeseries, prob, ::Colon) + parameter_values.((parameter_values_at_time(prob, i) for i in eachindex(parameter_timeseries(prob))), (p,)) + end + end + return let _getter = _getter + function getter(prob, args...) + return _getter(is_timeseries(prob), prob, args...) + end end end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return let idx = idx - function getter(sol) - return parameter_values(sol, idx) - end - end + return getp(sys, idx) end for (t1, t2) in [ @@ -77,16 +139,45 @@ for (t1, t2) in [ @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) getters = getp.((sys,), p) - return let getters = getters - function getter(sol) - map(g -> g(sol), getters) + _getter = return let getters = getters + function _getter(::NotTimeseries, prob) + map(g -> g(prob), getters) + end + function _getter(::Timeseries, prob) + map(g -> g(prob), getters) + end + function _getter(::Timeseries, prob, i) + map(g -> g(prob, i), getters) end - function getter(buffer, sol) - for (i, g) in zip(eachindex(buffer), getters) - buffer[i] = g(sol) + function _getter(::Timeseries, prob, ::Colon) + [map(g -> g(prob, i), getters) for i in eachindex(parameter_timeseries(prob))] + end + function _getter(buffer, ::NotTimeseries, prob) + map!(g -> g(prob), buffer, getters) + end + function _getter(buffer, ::Timeseries, prob) + map!(g -> g(prob), buffer, getters) + end + function _getter(buffer, ::Timeseries, prob, i) + map!(g -> g(prob, i), buffer, getters) + end + function _getter(buffer, ::Timeseries, prob, ::Colon) + for (bufi, tsi) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob))) + map!(g -> g(prob, tsi), buffer[bufi], getters) end buffer end + _getter + end + + return let _getter = _getter + function getter(prob, i...) + return _getter(is_timeseries(prob), prob, i...) + end + function getter(buffer::AbstractArray, prob, i...) + return _getter(buffer, is_timeseries(prob), prob, i...) + end + getter end end end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index e35c4ca..54eef94 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -1,42 +1,3 @@ -abstract type IsTimeseriesTrait end - -""" - struct Timeseries <: IsTimeseriesTrait end - -Trait indicating a type contains timeseries data. This affects the behaviour of -functions such as [`state_values`](@ref) and [`current_time`](@ref). - -See also: [`NotTimeseries`](@ref), [`is_timeseries`](@ref) -""" -struct Timeseries <: IsTimeseriesTrait end - -""" - struct NotTimeseries <: IsTimeseriesTrait end - -Trait indicating a type does not contain timeseries data. This affects the behaviour -of functions such as [`state_values`](@ref) and [`current_time`](@ref). Note that -if a type is `NotTimeseries` this only implies that it does not _store_ timeseries -data. It may still be time-dependent. For example, an `ODEProblem` only stores -the initial state of a system, so it is `NotTimeseries`, but still time-dependent. -This is the default trait variant for all types. - -See also: [`Timeseries`](@ref), [`is_timeseries`](@ref) -""" -struct NotTimeseries <: IsTimeseriesTrait end - -""" - is_timeseries(x) = is_timeseries(typeof(x)) - is_timeseries(::Type) - -Get the timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all types. - -See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref) -""" -function is_timeseries end - -is_timeseries(x) = is_timeseries(typeof(x)) -is_timeseries(::Type) = NotTimeseries() - """ state_values(p) state_values(p, i) @@ -149,13 +110,14 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) fn = observed(sys, sym) if is_time_dependent(sys) function _getter2(::Timeseries, prob) + curtime = current_time(prob) return fn.(state_values(prob), - (parameter_values(prob),), - current_time(prob)) + (parameter_values_at_state_time(prob, i) for i in eachindex(curtime)), + curtime) end function _getter2(::Timeseries, prob, i) return fn(state_values(prob, i), - parameter_values(prob), + parameter_values_at_state_time(prob, i), current_time(prob, i)) end function _getter2(::NotTimeseries, prob) @@ -222,18 +184,19 @@ for (t1, t2) in [ obs(state_values(prob), parameter_values(prob), current_time(prob)) end function _getter2a(::Timeseries, prob) - obs.(state_values(prob), (parameter_values(prob),), - current_time(prob)) + curtime = current_time(prob) + obs.(state_values(prob), (parameter_values_at_state_time(prob, i) for i in eachindex(curtime)), + curtime) end function _getter2a(::Timeseries, prob, i) obs(state_values(prob, i), - parameter_values(prob), + parameter_values_at_state_time(prob, i), current_time(prob, i)) end _getter2a end else - let obs = obs, is_tuple = sym isa Tuple + let obs = obs function _getter2b(::NotTimeseries, prob) obs(state_values(prob), parameter_values(prob)) end diff --git a/src/trait.jl b/src/trait.jl index 4235114..fbd0f9d 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -60,3 +60,45 @@ hasname(::Any) = false Get the name of a symbolic variable as a `Symbol` """ function getname end + +############ IsTimeseriesTrait + +abstract type IsTimeseriesTrait end + +""" + struct Timeseries <: IsTimeseriesTrait end + +Trait indicating a type contains timeseries data. This affects the behaviour of +functions such as [`state_values`](@ref) and [`current_time`](@ref). + +See also: [`NotTimeseries`](@ref), [`is_timeseries`](@ref) +""" +struct Timeseries <: IsTimeseriesTrait end + +""" + struct NotTimeseries <: IsTimeseriesTrait end + +Trait indicating a type does not contain timeseries data. This affects the behaviour +of functions such as [`state_values`](@ref) and [`current_time`](@ref). Note that +if a type is `NotTimeseries` this only implies that it does not _store_ timeseries +data. It may still be time-dependent. For example, an `ODEProblem` only stores +the initial state of a system, so it is `NotTimeseries`, but still time-dependent. +This is the default trait variant for all types. + +See also: [`Timeseries`](@ref), [`is_timeseries`](@ref) +""" +struct NotTimeseries <: IsTimeseriesTrait end + +""" + is_timeseries(x) = is_timeseries(typeof(x)) + is_timeseries(::Type) + +Get the timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all types. + +See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref) +""" +function is_timeseries end + +is_timeseries(x) = is_timeseries(typeof(x)) +is_timeseries(::Type) = NotTimeseries() + From a77908a9b5b3b2754db28bac822774c5f204ec67 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 7 Mar 2024 16:44:32 +0530 Subject: [PATCH 2/5] feat: make getp type-stable, make .ps type-stable, add tests for new indexing --- src/parameter_indexing.jl | 80 ++++++++++++++++++++------------- src/parameter_indexing_proxy.jl | 4 +- test/parameter_indexing_test.jl | 78 ++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 34 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index d214fa3..7dc160c 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -105,30 +105,35 @@ function getp(sys, p) end function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) - _getter = let p = p + return let p = p function _getter(::NotTimeseries, prob) parameter_values(prob, p) end function _getter(::Timeseries, prob) parameter_values(prob, p) end - function _getter(::Timeseries, prob, i) - parameter_values(parameter_values_at_time(prob, i), p) + function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex}) + parameter_values(parameter_values_at_time(prob, only(to_indices(parameter_timeseries(prob), (i,)))), p) end - function _getter(::Timeseries, prob, ::Colon) - parameter_values.((parameter_values_at_time(prob, i) for i in eachindex(parameter_timeseries(prob))), (p,)) + function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon}) + parameter_values.(parameter_values_at_time.((prob,), (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), p) end - end - return let _getter = _getter - function getter(prob, args...) - return _getter(is_timeseries(prob), prob, args...) + function _getter(::Timeseries, prob, i) + parameter_values.(parameter_values_at_time.((prob,), i), (p,)) + end + getter = let _getter = _getter + function getter(prob, args...) + return _getter(is_timeseries(prob), prob, args...) + end end + getter end end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return getp(sys, idx) + return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, sys, NotSymbolic(), NotSymbolic(), idx) + return _getp(sys, NotSymbolic(), NotSymbolic(), idx) end for (t1, t2) in [ @@ -139,43 +144,54 @@ for (t1, t2) in [ @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) getters = getp.((sys,), p) - _getter = return let getters = getters + return let getters = getters function _getter(::NotTimeseries, prob) map(g -> g(prob), getters) end function _getter(::Timeseries, prob) map(g -> g(prob), getters) end - function _getter(::Timeseries, prob, i) + function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex}) map(g -> g(prob, i), getters) end - function _getter(::Timeseries, prob, ::Colon) - [map(g -> g(prob, i), getters) for i in eachindex(parameter_timeseries(prob))] - end - function _getter(buffer, ::NotTimeseries, prob) - map!(g -> g(prob), buffer, getters) + function _getter(::Timeseries, prob, i) + [map(g -> g(prob, j), getters) for j in only(to_indices(parameter_timeseries(prob), (i,)))] end - function _getter(buffer, ::Timeseries, prob) - map!(g -> g(prob), buffer, getters) + function _getter!(buffer, ::NotTimeseries, prob) + for (g, bufi) in zip(getters, eachindex(buffer)) + buffer[bufi] = g(prob) + end + buffer end - function _getter(buffer, ::Timeseries, prob, i) - map!(g -> g(prob, i), buffer, getters) + function _getter!(buffer, ::Timeseries, prob) + for (g, bufi) in zip(getters, eachindex(buffer)) + buffer[bufi] = g(prob) + end + buffer end - function _getter(buffer, ::Timeseries, prob, ::Colon) - for (bufi, tsi) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob))) - map!(g -> g(prob, tsi), buffer[bufi], getters) + function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex}) + for (g, bufi) in zip(getters, eachindex(buffer)) + buffer[bufi] = g(prob, i) end buffer end - _getter - end - - return let _getter = _getter - function getter(prob, i...) - return _getter(is_timeseries(prob), prob, i...) + function _getter!(buffer, ::Timeseries, prob, i) + for (bufi, tsi) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,)))) + for (g, bufj) in zip(getters, eachindex(buffer[bufi])) + buffer[bufi][bufj] = g(prob, tsi) + end + end + buffer end - function getter(buffer::AbstractArray, prob, i...) - return _getter(buffer, is_timeseries(prob), prob, i...) + _getter, _getter! + getter = let _getter = _getter, _getter! = _getter! + function getter(prob, i...) + return _getter(is_timeseries(prob), prob, i...) + end + function getter(buffer::AbstractArray, prob, i...) + return _getter!(buffer, is_timeseries(prob), prob, i...) + end + getter end getter end diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index 0b92587..ab365ed 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -10,8 +10,8 @@ struct ParameterIndexingProxy{T} wrapped::T end -function Base.getindex(p::ParameterIndexingProxy, idx) - return getp(p.wrapped, idx)(p.wrapped) +function Base.getindex(p::ParameterIndexingProxy, idx, args...) + getp(p.wrapped, idx)(p.wrapped, args...) end function Base.setindex!(p::ParameterIndexingProxy, val, idx) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index adac52f..dc6188a 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -6,6 +6,9 @@ struct FakeIntegrator{S, P} p::P end +function Base.getproperty(fi::FakeIntegrator, s::Symbol) + s === :ps ? ParameterIndexingProxy(fi) : getfield(fi, s) +end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p @@ -13,6 +16,7 @@ sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) p = [1.0, 2.0, 3.0] fi = FakeIntegrator(sys, copy(p)) new_p = [4.0, 5.0, 6.0] +@test parameter_timeseries(fi) == [0] for (sym, oldval, newval, check_inference) in [ (:a, p[1], new_p[1], true), (1, p[1], new_p[1], true), @@ -33,6 +37,7 @@ for (sym, oldval, newval, check_inference) in [ if check_inference @inferred get(fi) end + @test get(fi) == fi.ps[sym] @test get(fi) == oldval if check_inference @inferred set!(fi, newval) @@ -43,6 +48,11 @@ for (sym, oldval, newval, check_inference) in [ set!(fi, oldval) @test get(fi) == oldval + fi.ps[sym] = newval + @test get(fi) == newval + fi.ps[sym] = oldval + @test get(fi) == oldval + if check_inference @inferred get(p) end @@ -68,3 +78,71 @@ for (sym, val) in [ @inferred get(buffer, fi) @test buffer == val end + +struct FakeSolution + sys::SymbolCache + u::Vector{Vector{Float64}} + t::Vector{Float64} + p::Vector{Vector{Float64}} + pt::Vector{Float64} +end + +function Base.getproperty(fs::FakeSolution, s::Symbol) + s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s) +end +SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys +SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end] +SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i] +function SymbolicIndexingInterface.parameter_values_at_time(fs::FakeSolution, t) + fs.p[t] +end +function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSolution, t) + ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=) + fs.p[ptind - 1] +end +SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution) = fs.pt +SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries() +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +fs = FakeSolution( + sys, + [i * ones(3) for i in 1:5], + [0.2i for i in 1:5], + [2i * ones(3) for i in 1:10], + [0.1i for i in 1:10], +) +ps = fs.p +p = fs.p[end] +avals = getindex.(ps, 1) +bvals = getindex.(ps, 2) +cvals = getindex.(ps, 3) +@test parameter_timeseries(fs) == fs.pt +for (sym, val, arrval, check_inference) in [ + (:a, p[1], avals, true), + (1, p[1], avals, true), + ([:a, :b], p[1:2], vcat.(avals, bvals), true), + (1:2, p[1:2], vcat.(avals, bvals), true), + ((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), + ((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true), + ([1, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), + ((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true) +] + get = getp(sys, sym) + if check_inference + @inferred get(fs) + end + @test get(fs) == fs.ps[sym] + @test get(fs) == val + + for sub_inds in [:, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))] + if check_inference + @inferred get(fs, sub_inds) + end + @test get(fs, sub_inds) == fs.ps[sym, sub_inds] + @test get(fs, sub_inds) == arrval[sub_inds] + end +end From d4a5513f5b30a6ef5cd277154d1987593d15fd42 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 7 Mar 2024 16:51:30 +0530 Subject: [PATCH 3/5] refactor: format --- src/SymbolicIndexingInterface.jl | 3 ++- src/parameter_indexing.jl | 19 ++++++++++++++----- src/state_indexing.jl | 4 +++- src/trait.jl | 1 - test/parameter_indexing_test.jl | 11 +++++++---- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 34d57fd..a18e080 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -15,7 +15,8 @@ include("interface.jl") export SymbolCache include("symbol_cache.jl") -export parameter_values, set_parameter!, parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp, setp +export parameter_values, set_parameter!, parameter_values_at_time, + parameter_values_at_state_time, parameter_timeseries, getp, setp include("parameter_indexing.jl") export state_values, set_state!, current_time, getu, setu diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 7dc160c..3161781 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -113,10 +113,16 @@ function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) parameter_values(prob, p) end function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex}) - parameter_values(parameter_values_at_time(prob, only(to_indices(parameter_timeseries(prob), (i,)))), p) + parameter_values( + parameter_values_at_time( + prob, only(to_indices(parameter_timeseries(prob), (i,)))), + p) end function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon}) - parameter_values.(parameter_values_at_time.((prob,), (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), p) + parameter_values.( + parameter_values_at_time.((prob,), + (j for j in only(to_indices(parameter_timeseries(prob), (i,))))), + p) end function _getter(::Timeseries, prob, i) parameter_values.(parameter_values_at_time.((prob,), i), (p,)) @@ -132,7 +138,8 @@ end function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, sys, NotSymbolic(), NotSymbolic(), idx) + return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any}, + sys, NotSymbolic(), NotSymbolic(), idx) return _getp(sys, NotSymbolic(), NotSymbolic(), idx) end @@ -155,7 +162,8 @@ for (t1, t2) in [ map(g -> g(prob, i), getters) end function _getter(::Timeseries, prob, i) - [map(g -> g(prob, j), getters) for j in only(to_indices(parameter_timeseries(prob), (i,)))] + [map(g -> g(prob, j), getters) + for j in only(to_indices(parameter_timeseries(prob), (i,)))] end function _getter!(buffer, ::NotTimeseries, prob) for (g, bufi) in zip(getters, eachindex(buffer)) @@ -176,7 +184,8 @@ for (t1, t2) in [ buffer end function _getter!(buffer, ::Timeseries, prob, i) - for (bufi, tsi) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,)))) + for (bufi, tsi) in zip( + eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,)))) for (g, bufj) in zip(getters, eachindex(buffer[bufi])) buffer[bufi][bufj] = g(prob, tsi) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 54eef94..7f884e6 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -185,7 +185,9 @@ for (t1, t2) in [ end function _getter2a(::Timeseries, prob) curtime = current_time(prob) - obs.(state_values(prob), (parameter_values_at_state_time(prob, i) for i in eachindex(curtime)), + obs.(state_values(prob), + (parameter_values_at_state_time(prob, i) + for i in eachindex(curtime)), curtime) end function _getter2a(::Timeseries, prob, i) diff --git a/src/trait.jl b/src/trait.jl index fbd0f9d..29c6673 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -101,4 +101,3 @@ function is_timeseries end is_timeseries(x) = is_timeseries(typeof(x)) is_timeseries(::Type) = NotTimeseries() - diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index dc6188a..77a7592 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -108,7 +108,7 @@ fs = FakeSolution( [i * ones(3) for i in 1:5], [0.2i for i in 1:5], [2i * ones(3) for i in 1:10], - [0.1i for i in 1:10], + [0.1i for i in 1:10] ) ps = fs.p p = fs.p[end] @@ -122,11 +122,13 @@ for (sym, val, arrval, check_inference) in [ ([:a, :b], p[1:2], vcat.(avals, bvals), true), (1:2, p[1:2], vcat.(avals, bvals), true), ((1, 2), Tuple(p[1:2]), tuple.(avals, bvals), true), - ([:a, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), + ([:a, [:b, :c]], [p[1], p[2:3]], + [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), ([:a, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), ((:a, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), ((:a, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true), - ([1, [:b, :c]], [p[1], p[2:3]], [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), + ([1, [:b, :c]], [p[1], p[2:3]], + [[i, [j, k]] for (i, j, k) in zip(avals, bvals, cvals)], false), ([1, (:b, :c)], [p[1], (p[2], p[3])], vcat.(avals, tuple.(bvals, cvals)), false), ((1, [:b, :c]), (p[1], p[2:3]), tuple.(avals, vcat.(bvals, cvals)), true), ((1, (:b, :c)), (p[1], (p[2], p[3])), tuple.(avals, tuple.(bvals, cvals)), true) @@ -138,7 +140,8 @@ for (sym, val, arrval, check_inference) in [ @test get(fs) == fs.ps[sym] @test get(fs) == val - for sub_inds in [:, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))] + for sub_inds in [ + :, 3:5, rand(Bool, length(ps)), rand(eachindex(ps)), rand(CartesianIndices(ps))] if check_inference @inferred get(fs, sub_inds) end From 728b89ada28939b551aeba55e05b721ed7a5a04f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 7 Mar 2024 17:21:38 +0530 Subject: [PATCH 4/5] docs: update documentation for new parameter timeseries indexing --- docs/src/api.md | 12 ++++++ docs/src/complete_sii.md | 78 +++++++++++++++++++++++++++++++++++++++ src/parameter_indexing.jl | 7 ++++ 3 files changed, 97 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 3c049c7..47ec075 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -52,6 +52,18 @@ getu setu ``` +### Parameter timeseries + +If a solution object saves a timeseries of parameter values that are updated during the +simulation (such as by callbacks), it must implement the following methods to ensure +correct functioning of [`getu`](@ref) and [`getp`](@ref). + +```@docs +parameter_timeseries +parameter_values_at_time +parameter_values_at_state_time +``` + # Symbolic Trait ```@docs diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index e5d592a..fc62d14 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -338,3 +338,81 @@ end [`hasname`](@ref) is not required to always be `true` for symbolic types. For example, `Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression. + +## Parameter Timeseries + +If a solution object saves modified parameter values (such as through callbacks) during the +simulation, it must implement [`parameter_timeseries`](@ref), +[`parameter_values_at_time`](@ref) and [`parameter_values_at_state_time`](@ref) for correct +functioning of [`getu`](@ref) and [`getp`](@ref). The following mockup gives an example +of correct implementation of these functions and the indexing syntax they enable. + +```@example param_timeseries +using SymbolicIndexingInterface + +struct ExampleSolution2 + sys::SymbolCache + u::Vector{Vector{Float64}} + t::Vector{Float64} + p::Vector{Vector{Float64}} + pt::Vector{Float64} +end + +# Add the `:ps` property to automatically wrap in `ParameterIndexingProxy` +function Base.getproperty(fs::ExampleSolution2, s::Symbol) + s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s) +end +# Use the contained `SymbolCache` for indexing +SymbolicIndexingInterface.symbolic_container(fs::ExampleSolution2) = fs.sys +# By default, `parameter_values` refers to the last value +SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2) = fs.p[end] +SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2, i) = fs.p[end][i] +# Index into the parameter timeseries vector +function SymbolicIndexingInterface.parameter_values_at_time(fs::ExampleSolution2, t) + fs.p[t] +end +# Find the first index in the parameter timeseries vector with a time smaller +# than the time from the state timeseries, and use that to index the parameter +# timeseries +function SymbolicIndexingInterface.parameter_values_at_state_time(fs::ExampleSolution2, t) + ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=) + fs.p[ptind - 1] +end +SymbolicIndexingInterface.parameter_timeseries(fs::ExampleSolution2) = fs.pt +# Mark the object as a `Timeseries` object +SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution2}) = Timeseries() + +``` + +Now we can create an example object and observe the new functionality. Note that +`sol.ps[sym, args...]` is identical to `getp(sol, sym)(sol, args...)`. + +```@example param_timeseries +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +sol = ExampleSolution2( + sys, + [i * ones(3) for i in 1:5], + [0.2i for i in 1:5], + [2i * ones(3) for i in 1:10], + [0.1i for i in 1:10] +) +sol.ps[:a] # returns the value at the last timestep +``` + +```@example param_timeseries +sol.ps[:a, :] # use Colon to fetch the entire parameter timeseries +``` + +```@example param_timeseries +sol.ps[:a, 3] # index at a specific index in the parameter timeseries +``` + +```@example param_timeseries +sol.ps[:a, [3, 6, 8]] # index using arrays +``` + +```@example param_timeseries +idxs = @show rand(Bool, 10) # boolean mask for indexing +sol.ps[:a, idxs] +``` + diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 3161781..248aa1e 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -97,6 +97,13 @@ solution from which the values are obtained. Requires that the integrator or solution implement [`parameter_values`](@ref). This function typically does not need to be implemented, and has a default implementation relying on [`parameter_values`](@ref). + +If the returned function is used on a timeseries object which saves parameter timeseries, it +can be used to index said timeseries. The timeseries object must implement +[`parameter_timeseries`](@ref), [`parameter_values_at_time`](@ref) and +[`parameter_values_at_state_time`](@ref). The function returned from `getp` will can be passed +`Colon()` (`:`) as the last argument to return the entire parameter timeseries for `p`, or +any index into the parameter timeseries for a subset of values. """ function getp(sys, p) symtype = symbolic_type(p) From 258355e21a9f9e56a0013125f614bc0a1c74fc77 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 7 Mar 2024 20:28:59 +0530 Subject: [PATCH 5/5] docs: update doc example to MTKv9 --- docs/src/usage.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/usage.md b/docs/src/usage.md index 4abbeff..bfebaff 100644 --- a/docs/src/usage.md +++ b/docs/src/usage.md @@ -168,11 +168,11 @@ sol2 = solve(prob, Tsit5()) σ_ρ_getter(sol) ``` -To set the entire parameter vector at once, [`parameter_values`](@ref) can be used -(note the usage of broadcasted assignment). +To set the entire parameter vector at once, [`setp`](@ref) can be used +(note that the order of symbols passed to `setp` must match the order of values in the array). ```@example Usage -parameter_values(prob) .= [29.0, 11.0, 2.5] +setp(prob, parameter_symbols(prob))(prob, [29.0, 11.0, 2.5]) parameter_values(prob) ```