Skip to content

Commit

Permalink
fixup! refactor: rework discrete indexing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 21, 2024
1 parent aff02b3 commit 5a8d9a5
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 91 deletions.
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ allvariables
```@docs
observed
parameter_observed
ParameterObservedFunction
```

#### Parameter timeseries
Expand All @@ -46,6 +45,8 @@ may change at different times.
is_timeseries_parameter
timeseries_parameter_index
ParameterTimeseriesIndex
get_all_timeseries_indexes
ContinuousTimeseries
```

## Value provider interface
Expand Down
6 changes: 2 additions & 4 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,8 @@ end
In case a type does not support such observed quantities, `is_observed` must be
defined to always return `false`, and `observed` does not need to be implemented.

The same process can be followed for [`parameter_observed`](@ref), with the exception
that the returned function must not have `u` in its signature, and must be wrapped in a
[`ParameterObservedFunction`](@ref). In-place versions can also be implemented for
`parameter_observed`.
The same process can be followed for [`parameter_observed`](@ref). In-place versions
can also be implemented for `parameter_observed`.

#### Note about constant structure

Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include("trait.jl")
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex,
parameter_symbols, is_independent_variable, independent_variable_symbols,
is_observed, observed, parameter_observed, ParameterObservedFunction,
is_observed, observed, parameter_observed,
ContinuousTimeseries, get_all_timeseries_indexes,
is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values,
Expand Down
38 changes: 5 additions & 33 deletions src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,42 +98,14 @@ function timeseries_parameter_index(indp, sym)
end
end

"""
struct ParameterObservedFunction
function ParameterObservedFunction(timeseries_idx, observed_fn::Function)
function ParameterObservedFunction(observed_fn::Function)
A struct which stores the parameter observed function and optional timeseries index for
a particular symbol. The timeseries index is optional and may be omitted. Specifying the
timeseries index allows [`getp`](@ref) to return the appropriate timeseries for a
timeseries parameter.
For time-dependent index providers (where `is_time_dependent(indp)`) `observed_fn` must
have the signature `(p, t) -> [values...]`. For non-time-dependent index providers
(where `!is_time_dependent(indp)`) `observed_fn` must have the signature
`(p) -> [values...]`. To support in-place `getp` methods, `observed_fn` must also have an
additional method which takes `buffer::AbstractArray` as its first argument. The required
values must be written to the buffer in the appropriate order.
"""
struct ParameterObservedFunction{I, F <: Function}
timeseries_idx::I
observed_fn::F
end

function ParameterObservedFunction(ts_idx, f)
ParameterObservedFunction{typeof(ts_idx), typeof(f)}(ts_idx, f)
end
ParameterObservedFunction(f) = ParameterObservedFunction(nothing, f)

"""
parameter_observed(indp, sym)
Return the observed function of `sym` in `indp` as a [`ParameterObservedFunction`](@ref).
If `sym` only involves non-timeseries parameters, the timeseries index in the
`ParameterObservedFunction` should be `nothing`. If `sym` involves timeseries parameters,
the timeseries index should be specified in the `ParameterObservedFunction`. If `sym`
involves timeseries parameters from multiple different timeseries, all the relevant
timeseries indexes should be specified in the `ParameterObservedFunction` as a `Vector`.
Return the observed function of `sym` in `indp`. This functions similarly to
[`observed`](@ref) except that `u` is not an argument of the returned function. For time-
dependent systems, the returned function must have the signature `(p, t) -> [values...]`.
For time-independent systems, the returned function must have the signature
`(p) -> [values...]`.
By default, this function returns `nothing`, indicating that the index provider does not
support generating parameter observed functions.
Expand Down
16 changes: 10 additions & 6 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,10 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
throw(ArgumentError("Index provider does not support `parameter_observed`; cannot use generate function for $p"))
end
if !is_time_dependent(sys)
return GetParameterObservedNoTime(pofn.observed_fn)
return GetParameterObservedNoTime(pofn)
end
return GetParameterObserved{false}(pofn.timeseries_idx, pofn.observed_fn)
ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p))
return GetParameterObserved{false}(ts_idxs, pofn)
end
error("Invalid symbol $p for `getp`")
end
Expand Down Expand Up @@ -572,18 +573,20 @@ for (t1, t2) in [
# `getp` errors on older MTK that doesn't support `parameter_observed`.
getters = getp.((sys,), p)
num_observed = count(is_observed_getter, getters)
p_arr = p isa Tuple ? collect(p) : p

if num_observed == 0
return MultipleParametersGetter(getters)
else
pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p)
pofn = parameter_observed(sys, p_arr)
if pofn === nothing
return MultipleParametersGetter.(getters)
end
if is_time_dependent(sys)
getter = GetParameterObserved{true}(pofn.timeseries_idx, pofn.observed_fn)
ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p_arr))
getter = GetParameterObserved{true}(ts_idxs, pofn)
else
getter = GetParameterObservedNoTime(pofn.observed_fn)
getter = GetParameterObservedNoTime(pofn)
end
return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter
end
Expand Down Expand Up @@ -677,7 +680,8 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return setp(sys, idx; run_hook = false)
elseif is_observed(sys, p) && (pobsfn = parameter_observed(sys, p)) !== nothing
return GetParameterObserved{false}(pobsfn.timeseries_idx, pobsfn.observed_fn)
ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p))
return GetParameterObserved{false}(ts_idxs, pobsfn.observed_fn)
end
return setp(sys, collect(p); run_hook = false)
end
5 changes: 3 additions & 2 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ for (t1, t2) in [
if isempty(sym)
return MultipleGetters(sym)
end
ts_idxs = get_all_timeseries_indexes(sys, sym isa Tuple ? collect(sym) : sym)
sym_arr = sym isa Tuple ? collect(sym) : sym
ts_idxs = get_all_timeseries_indexes(sys, sym_arr)
if ContinuousTimeseries() in ts_idxs && length(ts_idxs) > 1
throw(MixedContinuousParameterTimeseriesError(sym))
end
Expand All @@ -187,7 +188,7 @@ for (t1, t2) in [
getters = getu.((sys,), sym)
return MultipleGetters(getters)
else
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
obs = observed(sys, sym_arr)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction(obs)
else
Expand Down
34 changes: 4 additions & 30 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,26 +245,13 @@ function parameter_observed(sc::SymbolCache, expr::Expr)
if is_time_dependent(sc)
exs = ExpressionSearcher()
exs(sc, expr)
ts_idxs = Set()
for p in exs.parameters
is_timeseries_parameter(sc, p) || continue
push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx)
end
f = let fn = observed(sc, expr)
return let fn = observed(sc, expr)
f1(p, t) = fn(nothing, p, t)
end
if isempty(ts_idxs)
return ParameterObservedFunction(f)
elseif length(ts_idxs) == 1
return ParameterObservedFunction(only(ts_idxs), f)
else
return ParameterObservedFunction(collect(ts_idxs), f)
end
else
f = let fn = observed(sc, expr)
return let fn = observed(sc, expr)
f2(p) = fn(nothing, p)
end
return ParameterObservedFunction(nothing, f)
end
end

Expand All @@ -277,29 +264,16 @@ function parameter_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple})
if is_time_dependent(sc)
exs = ExpressionSearcher()
exs(sc, to_expr(exprs))
ts_idxs = Set()
for p in exs.parameters
is_timeseries_parameter(sc, p) || continue
push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx)
end

f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs)
return let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs)
f1(p, t) = oop(nothing, p, t)
f1(buffer, p, t) = iip(buffer, nothing, p, t)
end
if isempty(ts_idxs)
return ParameterObservedFunction(f)
elseif length(ts_idxs) == 1
return ParameterObservedFunction(only(ts_idxs), f)
else
return ParameterObservedFunction(collect(ts_idxs), f)
end
else
f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs)
return let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs)
f2(p) = oop(nothing, p)
f2(buffer, p) = iip(buffer, nothing, p)
end
return ParameterObservedFunction(nothing, f)
end
end

Expand Down
11 changes: 9 additions & 2 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,15 @@ as_timeseries_indexer(::IndexerNotTimeseries, x) = x
as_not_timeseries_indexer(x) = as_not_timeseries_indexer(is_indexer_timeseries(x), x)
as_not_timeseries_indexer(::IndexerNotTimeseries, x) = x

struct MultipleTimeseriesIndexes{I}
idxs::Vector{I}
function _postprocess_tsidxs(ts_idxs)
delete!(ts_idxs, ContinuousTimeseries())
if isempty(ts_idxs)
return nothing
elseif length(ts_idxs) == 1
return only(ts_idxs)
else
return collect(ts_idxs)
end
end

struct CallWith{A}
Expand Down
2 changes: 1 addition & 1 deletion test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ for (sym, val, buffer, check_inference) in [
([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true),
((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true)
]
getter = getp(sys, sym)
getter = getp(fs, sym)
@test is_indexer_timeseries(getter) isa IndexerNotTimeseries
test_inplace = buffer !== nothing
is_observed = sym isa Expr ||
Expand Down
19 changes: 8 additions & 11 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ obsfn5 = observed(sc, (:(x + a), :(y + b)))
@test_throws TypeError observed(sc, [:(x + a), 2])
@test_throws TypeError observed(sc, (:(x + a), 2))

pobsfn1 = parameter_observed(sc, :(a + b + t)).observed_fn
pobsfn1 = parameter_observed(sc, :(a + b + t))
@test pobsfn1(2ones(2), 3.0) == 7.0
pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)]).observed_fn
pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)])
@test pobsfn2(2ones(2), 3.0) == [7.0, 5.0]
buffer = zeros(2)
pobsfn2(buffer, 2ones(2), 3.0)
@test buffer == [7.0, 5.0]
pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t))).observed_fn
pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t)))
@test pobsfn3(2ones(2), 3.0) == (7.0, 5.0)
buffer = zeros(2)
pobsfn3(buffer, 2ones(2), 3.0)
Expand All @@ -67,14 +67,11 @@ pobsfn3(buffer, 2ones(2), 3.0)
sc = SymbolCache([:x, :y], [:a, :b, :c], :t;
timeseries_parameters = Dict(
:b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1)))
@test parameter_observed(sc, :(a + c)).timeseries_idx == 2
@test parameter_observed(sc, [:a, :c]).timeseries_idx == 2
@test parameter_observed(sc, (:a, :c)).timeseries_idx == 2
@test parameter_observed(sc, :(2a)).timeseries_idx === nothing
@test parameter_observed(sc, [:(2a), :(3a)]).timeseries_idx === nothing
@test parameter_observed(sc, (:(2a), :(3a))).timeseries_idx === nothing
@test sort(parameter_observed(sc, [:b, :c]).timeseries_idx) == [1, 2]
@test sort(parameter_observed(sc, (:b, :c)).timeseries_idx) == [1, 2]
@test only(get_all_timeseries_indexes(sc, :(a + c))) == 2
@test only(get_all_timeseries_indexes(sc, [:a, :c])) == 2
@test isempty(get_all_timeseries_indexes(sc, :(2a)))
@test isempty(get_all_timeseries_indexes(sc, [:(2a), :(3a)]))
@test sort(collect(get_all_timeseries_indexes(sc, [:b, :c]))) == [1, 2]

@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t;
timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1)))
Expand Down

0 comments on commit 5a8d9a5

Please sign in to comment.