Skip to content

Commit

Permalink
feat: add setsym_oop
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 29, 2024
1 parent 3344f9f commit e8934c4
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 30 deletions.
34 changes: 4 additions & 30 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -726,52 +726,26 @@ function setp_oop(indp, sym)
return _setp_oop(indp, symtype, elsymtype, sym)
end

struct OOPSetter{I, D}
indp::I
idxs::D
end

function (os::OOPSetter)(valp, val)
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
end

function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
if os.idxs isa Union{Tuple, AbstractArray}
return remake_buffer(os.indp, parameter_values(valp), os.idxs, val)
else
return remake_buffer(os.indp, parameter_values(valp), (os.idxs,), (val,))
end
end

function _root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) != indp
return _root_indp(sc)
else
return indp
end
end

function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
return OOPSetter(_root_indp(indp), sym)
return OOPSetter(_root_indp(indp), sym, false)
end

function _setp_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
end

for (t1, t2) in [
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setp_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym))
return OOPSetter(_root_indp(indp), parameter_index.((indp,), sym), false)
end
end

function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
if is_parameter(indp, sym)
return OOPSetter(_root_indp(indp), parameter_index(indp, sym))
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
end
error("$sym is not a valid parameter")
end
94 changes: 94 additions & 0 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,97 @@ end

const getu = getsym
const setu = setsym

"""
setsym_oop(indp, sym)
Return a function which takes a value provider `valp` and a value `val`, and returns
`state_values(valp), parameter_values(valp)` with the states/parameters in `sym` set to the
corresponding values in `val`. This allows changing the types of values stored, and leverages
[`remake_buffer`](@ref). Note that `sym` can be an index, a symbolic variable, or an
array/tuple of the aforementioned. All entries `s` in `sym` must satisfy `is_variable(indp, s)`
or `is_parameter(indp, s)`.
Requires that the value provider implement `state_values`, `parameter_values` and `remake_buffer`.
"""
function setsym_oop(indp, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))
return _setsym_oop(indp, symtype, elsymtype, sym)
end

struct FullSetter{S, P, I, J}
state_setter::S
param_setter::P
state_split::I
param_split::J
end

FullSetter(ssetter, psetter) = FullSetter(ssetter, psetter, nothing, nothing)

function (fs::FullSetter)(valp, val)
return fs.state_setter(valp, val[fs.state_split]),
fs.param_setter(valp, val[fs.param_split])
end

function (fs::FullSetter{Nothing})(valp, val)
return state_values(valp), fs.param_setter(valp, val)
end

function (fs::(FullSetter{S, Nothing} where {S}))(valp, val)
return fs.state_setter(valp, val), parameter_values(valp)
end

function (fs::(FullSetter{Nothing, Nothing}))(valp, val)
return state_values(valp), parameter_values(valp)
end

function _setsym_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
return FullSetter(OOPSetter(_root_indp(indp), sym, true), nothing)
end

function _setsym_oop(indp, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
if (idx = variable_index(indp, sym)) !== nothing
return FullSetter(OOPSetter(_root_indp(indp), idx, true), nothing)
elseif (idx = parameter_index(indp, sym)) !== nothing
return FullSetter(nothing, OOPSetter(_root_indp(indp), idx, false))
end
throw(NotVariableOrParameter("setsym_oop", sym))
end

for (t1, t2) in [
(ScalarSymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _setsym_oop(indp, ::NotSymbolic, ::$t1, sym::$t2)
vars = []
state_split = eltype(eachindex(sym))[]
pars = []
param_split = eltype(eachindex(sym))[]
for (i, s) in enumerate(sym)
if (idx = variable_index(indp, s)) !== nothing
push!(vars, idx)
push!(state_split, i)
elseif (idx = parameter_index(indp, s)) !== nothing
push!(pars, idx)
push!(param_split, i)
else
throw(NotVariableOrParameter("setsym_oop", s))
end
end
indp = _root_indp(indp)
return FullSetter(isempty(vars) ? nothing : OOPSetter(indp, identity.(vars), true),
isempty(pars) ? nothing : OOPSetter(indp, identity.(pars), false),
state_split, param_split)
end
end

function _setsym_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
if (idx = variable_index(indp, sym)) !== nothing
return setsym_oop(indp, idx)
elseif (idx = parameter_index(indp, sym)) !== nothing
return FullSetter(
nothing, OOPSetter(indp, idx isa AbstractArray ? idx : (idx,), false))
end
return setsym_oop(indp, collect(sym))
end
42 changes: 42 additions & 0 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,35 @@ function (fn::Fix1Multiple)(args...)
fn.f(fn.arg, args...)
end

struct OOPSetter{I, D}
indp::I
idxs::D
is_state::Bool
end

function (os::OOPSetter)(valp, val)
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
return remake_buffer(os.indp, buffer, (os.idxs,), (val,))
end

function (os::OOPSetter)(valp, val::Union{Tuple, AbstractArray})
buffer = os.is_state ? state_values(valp) : parameter_values(valp)
if os.idxs isa Union{Tuple, AbstractArray}
return remake_buffer(os.indp, buffer, os.idxs, val)
else
return remake_buffer(os.indp, buffer, (os.idxs,), (val,))
end
end

function _root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) != indp
return _root_indp(sc)
else
return indp
end
end

###########
# Errors
###########
Expand Down Expand Up @@ -296,3 +325,16 @@ function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError)
indexes $(err.ts_idxs).
""")
end

struct NotVariableOrParameter <: Exception
fn::Any
sym::Any
end

function Base.showerror(io::IO, err::NotVariableOrParameter)
print(
io, """
`$(err.fn)` requires that the symbolic variable(s) passed to it satisfy `is_variable`
or `is_parameter`. Got `$(err.sym)` which is neither.
""")
end
24 changes: 24 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using SymbolicIndexingInterface
using SymbolicIndexingInterface: NotVariableOrParameter

struct FakeIntegrator{S, U, P, T}
sys::S
Expand Down Expand Up @@ -62,6 +63,9 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
set!(fi, newval)
end
@test get(fi) == newval

new_states = copy(state_values(fi))

set!(fi, val)
@test get(fi) == val

Expand All @@ -77,6 +81,15 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
@test get(u) == newval
set!(u, val)
@test get(u) == val

if sym isa Union{Vector, Tuple} && any(x -> x isa Union{AbstractArray, Tuple}, sym)
continue
end

setter = setsym_oop(sys, sym)
svals, pvals = setter(fi, newval)
@test svals new_states
@test pvals == parameter_values(fi)
end

for (sym, val, check_inference) in [
Expand Down Expand Up @@ -123,8 +136,17 @@ for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
set!(fi, newval)
end
@test get(fi) == newval

newu = copy(state_values(fi))
newp = copy(parameter_values(fi))

set!(fi, oldval)
@test get(fi) == oldval

oop_setter = setsym_oop(sys, sym)
uvals, pvals = oop_setter(fi, newval)
@test uvals newu
@test pvals newp
end

for (sym, val, check_inference) in [
Expand All @@ -137,6 +159,8 @@ for (sym, val, check_inference) in [
@inferred get(fi)
end
@test get(fi) == val

@test_throws NotVariableOrParameter setsym_oop(fi, sym)
end

struct FakeSolution{S, U, P, T}
Expand Down

0 comments on commit e8934c4

Please sign in to comment.