Skip to content

Commit

Permalink
refactor: use callable structs for setu and setp
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 2, 2024
1 parent 635595a commit 320753e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 43 deletions.
57 changes: 33 additions & 24 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function getp(sys, p)
_getp(sys, symtype, elsymtype, p)
end

struct GetParameterIndex{I} <: AbstractIndexer
struct GetParameterIndex{I} <: AbstractGetIndexer
idx::I
end

Expand Down Expand Up @@ -78,7 +78,7 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
sys, NotSymbolic(), NotSymbolic(), idx)
end

struct MultipleParameterGetters{G}
struct MultipleParameterGetters{G} <: AbstractGetIndexer
getters::G
end

Expand Down Expand Up @@ -148,6 +148,17 @@ function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p)
return getp(sys, collect(p))
end

struct ParameterHookWrapper{S, O} <: AbstractSetIndexer
setter::S
original_index::O
end

function (phw::ParameterHookWrapper)(prob, args...)
res = phw.setter(prob, args...)
finalize_parameters_hook!(prob, phw.original_index)
res
end

"""
setp(sys, p)
Expand All @@ -165,33 +176,35 @@ function setp(sys, p; run_hook = true)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
return if run_hook
let _setter! = _setp(sys, symtype, elsymtype, p), p = p
function setter!(prob, args...)
res = _setter!(prob, args...)
finalize_parameters_hook!(prob, p)
res
end
end
return ParameterHookWrapper(_setp(sys, symtype, elsymtype, p), p)
else
_setp(sys, symtype, elsymtype, p)
end
end

struct SetParameterIndex{I} <: AbstractSetIndexer
idx::I
end

function (spi::SetParameterIndex)(prob, val)
set_parameter!(prob, val, spi.idx)
end

function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
return let p = p
function setter!(sol, val)
set_parameter!(sol, val, p)
end
end
return SetParameterIndex(p)
end

function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return let idx = idx
function setter!(sol, val)
set_parameter!(sol, val, idx)
end
end
return SetParameterIndex(idx)
end

struct MultipleSetters{S} <: AbstractSetIndexer
setters::S
end

function (ms::MultipleSetters)(prob, val)
map((s!, v) -> s!(prob, v), ms.setters, val)
end

for (t1, t2) in [
Expand All @@ -201,11 +214,7 @@ for (t1, t2) in [
]
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
setters = setp.((sys,), p; run_hook = false)
return let setters = setters
function setter!(sol, val)
map((s!, v) -> s!(sol, v), setters, val)
end
end
return MultipleSetters(setters)
end
end

Expand Down
36 changes: 19 additions & 17 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function getu(sys, sym)
_getu(sys, symtype, elsymtype, sym)
end

struct GetStateIndex{I} <: AbstractIndexer
struct GetStateIndex{I} <: AbstractGetIndexer
idx::I
end
function (gsi::GetStateIndex)(::Timeseries, prob)
Expand All @@ -50,7 +50,7 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
return GetStateIndex(sym)
end

struct GetpAtStateTime{G} <: AbstractIndexer
struct GetpAtStateTime{G} <: AbstractGetIndexer
getter::G
end

Expand All @@ -65,12 +65,12 @@ function (g::GetpAtStateTime)(::NotTimeseries, prob)
g.getter(prob)
end

struct GetIndepvar <: AbstractIndexer end
struct GetIndepvar <: AbstractGetIndexer end

(::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob)
(::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i)

struct TimeDependentObservedFunction{F} <: AbstractIndexer
struct TimeDependentObservedFunction{F} <: AbstractGetIndexer
obsfn::F
end

Expand All @@ -89,7 +89,7 @@ function (o::TimeDependentObservedFunction)(::NotTimeseries, prob)
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
end

struct TimeIndependentObservedFunction{F} <: AbstractIndexer
struct TimeIndependentObservedFunction{F} <: AbstractGetIndexer
obsfn::F
end

Expand All @@ -116,7 +116,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
error("Invalid symbol $sym for `getu`")
end

struct MultipleGetters{G} <: AbstractIndexer
struct MultipleGetters{G} <: AbstractGetIndexer
getters::G
end

Expand All @@ -131,7 +131,7 @@ function (mg::MultipleGetters)(::NotTimeseries, prob)
return map(g -> g(prob), mg.getters)
end

struct AsTupleWrapper{G} <: AbstractIndexer
struct AsTupleWrapper{G} <: AbstractGetIndexer
getter::G
end

Expand Down Expand Up @@ -201,18 +201,22 @@ function setu(sys, sym)
_setu(sys, symtype, elsymtype, sym)
end

struct SetStateIndex{I} <: AbstractSetIndexer
idx::I
end

function (ssi::SetStateIndex)(prob, val)
set_state!(prob, val, ssi.idx)
end

function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym)
return function setter!(prob, val)
set_state!(prob, val, sym)
end
return SetStateIndex(sym)
end

function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return function setter!(prob, val)
set_state!(prob, val, idx)
end
return SetStateIndex(idx)
elseif is_parameter(sys, sym)
return setp(sys, sym)
end
Expand All @@ -226,16 +230,14 @@ for (t1, t2) in [
]
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
setters = setu.((sys,), sym)
return function setter!(prob, val)
map((s!, v) -> s!(prob, v), setters, val)
end
return MultipleSetters(setters)
end
end

function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return setu(sys, idx)
return MultipleSetters(SetStateIndex.(idx))
elseif is_parameter(sys, sym)
return setp(sys, sym)
end
Expand Down
7 changes: 5 additions & 2 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,8 @@ function current_time end

abstract type AbstractIndexer end

(ai::AbstractIndexer)(prob) = ai(is_timeseries(prob), prob)
(ai::AbstractIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)
abstract type AbstractGetIndexer <: AbstractIndexer end
abstract type AbstractSetIndexer <: AbstractIndexer end

(ai::AbstractGetIndexer)(prob) = ai(is_timeseries(prob), prob)
(ai::AbstractGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)

0 comments on commit 320753e

Please sign in to comment.