From 320753eecab8dadf6ffe5e25cc10a5d9e2f06a64 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 May 2024 14:01:15 +0530 Subject: [PATCH] refactor: use callable structs for `setu` and `setp` --- src/parameter_indexing.jl | 57 +++++++++++++++++++-------------- src/state_indexing.jl | 36 +++++++++++---------- src/value_provider_interface.jl | 7 ++-- 3 files changed, 57 insertions(+), 43 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index faa4551..a2b2a5f 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -45,7 +45,7 @@ function getp(sys, p) _getp(sys, symtype, elsymtype, p) end -struct GetParameterIndex{I} <: AbstractIndexer +struct GetParameterIndex{I} <: AbstractGetIndexer idx::I end @@ -78,7 +78,7 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) sys, NotSymbolic(), NotSymbolic(), idx) end -struct MultipleParameterGetters{G} +struct MultipleParameterGetters{G} <: AbstractGetIndexer getters::G end @@ -148,6 +148,17 @@ function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p) return getp(sys, collect(p)) end +struct ParameterHookWrapper{S, O} <: AbstractSetIndexer + setter::S + original_index::O +end + +function (phw::ParameterHookWrapper)(prob, args...) + res = phw.setter(prob, args...) + finalize_parameters_hook!(prob, phw.original_index) + res +end + """ setp(sys, p) @@ -165,33 +176,35 @@ function setp(sys, p; run_hook = true) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) return if run_hook - let _setter! = _setp(sys, symtype, elsymtype, p), p = p - function setter!(prob, args...) - res = _setter!(prob, args...) - finalize_parameters_hook!(prob, p) - res - end - end + return ParameterHookWrapper(_setp(sys, symtype, elsymtype, p), p) else _setp(sys, symtype, elsymtype, p) end end +struct SetParameterIndex{I} <: AbstractSetIndexer + idx::I +end + +function (spi::SetParameterIndex)(prob, val) + set_parameter!(prob, val, spi.idx) +end + function _setp(sys, ::NotSymbolic, ::NotSymbolic, p) - return let p = p - function setter!(sol, val) - set_parameter!(sol, val, p) - end - end + return SetParameterIndex(p) end function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) - return let idx = idx - function setter!(sol, val) - set_parameter!(sol, val, idx) - end - end + return SetParameterIndex(idx) +end + +struct MultipleSetters{S} <: AbstractSetIndexer + setters::S +end + +function (ms::MultipleSetters)(prob, val) + map((s!, v) -> s!(prob, v), ms.setters, val) end for (t1, t2) in [ @@ -201,11 +214,7 @@ for (t1, t2) in [ ] @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) setters = setp.((sys,), p; run_hook = false) - return let setters = setters - function setter!(sol, val) - map((s!, v) -> s!(sol, v), setters, val) - end - end + return MultipleSetters(setters) end end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 2fc50bf..d92d55d 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -33,7 +33,7 @@ function getu(sys, sym) _getu(sys, symtype, elsymtype, sym) end -struct GetStateIndex{I} <: AbstractIndexer +struct GetStateIndex{I} <: AbstractGetIndexer idx::I end function (gsi::GetStateIndex)(::Timeseries, prob) @@ -50,7 +50,7 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) return GetStateIndex(sym) end -struct GetpAtStateTime{G} <: AbstractIndexer +struct GetpAtStateTime{G} <: AbstractGetIndexer getter::G end @@ -65,12 +65,12 @@ function (g::GetpAtStateTime)(::NotTimeseries, prob) g.getter(prob) end -struct GetIndepvar <: AbstractIndexer end +struct GetIndepvar <: AbstractGetIndexer end (::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob) (::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i) -struct TimeDependentObservedFunction{F} <: AbstractIndexer +struct TimeDependentObservedFunction{F} <: AbstractGetIndexer obsfn::F end @@ -89,7 +89,7 @@ function (o::TimeDependentObservedFunction)(::NotTimeseries, prob) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end -struct TimeIndependentObservedFunction{F} <: AbstractIndexer +struct TimeIndependentObservedFunction{F} <: AbstractGetIndexer obsfn::F end @@ -116,7 +116,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) error("Invalid symbol $sym for `getu`") end -struct MultipleGetters{G} <: AbstractIndexer +struct MultipleGetters{G} <: AbstractGetIndexer getters::G end @@ -131,7 +131,7 @@ function (mg::MultipleGetters)(::NotTimeseries, prob) return map(g -> g(prob), mg.getters) end -struct AsTupleWrapper{G} <: AbstractIndexer +struct AsTupleWrapper{G} <: AbstractGetIndexer getter::G end @@ -201,18 +201,22 @@ function setu(sys, sym) _setu(sys, symtype, elsymtype, sym) end +struct SetStateIndex{I} <: AbstractSetIndexer + idx::I +end + +function (ssi::SetStateIndex)(prob, val) + set_state!(prob, val, ssi.idx) +end + function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) - return function setter!(prob, val) - set_state!(prob, val, sym) - end + return SetStateIndex(sym) end function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return function setter!(prob, val) - set_state!(prob, val, idx) - end + return SetStateIndex(idx) elseif is_parameter(sys, sym) return setp(sys, sym) end @@ -226,16 +230,14 @@ for (t1, t2) in [ ] @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) setters = setu.((sys,), sym) - return function setter!(prob, val) - map((s!, v) -> s!(prob, v), setters, val) - end + return MultipleSetters(setters) end end function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return setu(sys, idx) + return MultipleSetters(SetStateIndex.(idx)) elseif is_parameter(sys, sym) return setp(sys, sym) end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 5214841..8f7ee6b 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -139,5 +139,8 @@ function current_time end abstract type AbstractIndexer end -(ai::AbstractIndexer)(prob) = ai(is_timeseries(prob), prob) -(ai::AbstractIndexer)(prob, i) = ai(is_timeseries(prob), prob, i) +abstract type AbstractGetIndexer <: AbstractIndexer end +abstract type AbstractSetIndexer <: AbstractIndexer end + +(ai::AbstractGetIndexer)(prob) = ai(is_timeseries(prob), prob) +(ai::AbstractGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)