Skip to content

Commit

Permalink
use getu/setu directly on numeric containers
Browse files Browse the repository at this point in the history
  • Loading branch information
Datseris committed Jan 16, 2024
1 parent 39e8fe6 commit 3ade5d1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 28 deletions.
38 changes: 15 additions & 23 deletions src/core/dynamicalsystem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,8 @@ state of the system is the full state space state, not the projected one
(this makes the most sense for allowing MTK-based indexing).
This function does not work with [`PoincareMap`](@ref) and [`ParallelDynamicalSystem`](@ref).
"""
function observe_state(ds::DynamicalSystem, index)
u = current_state(ds)
function observe_state(ds::DynamicalSystem, index, u = current_state(ds))
prob, sys, integ = referrenced_sciml_sys(ds)
observe_state(u, prob, sys, integ, index)
end
function observe_state(u, prob, sys, integ, index)
if !has_referrenced_sys(sys)
T = eltype(u)
if index isa Int
Expand All @@ -200,7 +196,7 @@ function observe_state(u, prob, sys, integ, index)
end
else
ugetter = SymbolicIndexingInterface.getu(sys, index)
return ugetter(integ)
return ugetter(u)
end
end

Expand All @@ -226,19 +222,17 @@ current_parameters(ds::DynamicalSystem) = ds.p
Return the specific parameter corresponding to `index`,
which can be anything given to [`set_parameter!`](@ref).
"""
current_parameter(ds::DynamicalSystem, index) = current_paramter(ds, current_parameters(ds), index)
function current_parameter(ds::DynamicalSystem, p, index)
function current_parameter(ds::DynamicalSystem, index, p = current_parameters(ds))
prob, sys, integ = referrenced_sciml_sys(ds)
if !has_referrenced_sys(sys)
return _get_parameter(p, index)
else # symbolic dispatch # TODO: Use the `observed` function directly on `p`.
else # symbolic dispatch
i = SymbolicIndexingInterface.getp(sys, index)
return i(prob)
return i(p)
end
end

# Dispatch for composite types as parameter containers
_get_parameter(p::Union{AbstractArray, AbstractDict}, index) = getindex(p, index)
# Dispatch for composite types as parameter containers
_get_parameter(p, index) = getproperty(p, index)

"""
Expand All @@ -248,10 +242,6 @@ Return the initial parameter container of `ds`.
This is never mutated and is set when initializing `ds`.
"""
initial_parameters(ds::DynamicalSystem) = ds.p0
function initial_parameters(ds::DynamicalSystem, index)
i = parameter_index(ds, index)
return _get_parameter(initial_parameters(ds), i)
end

"""
isdeterministic(ds::DynamicalSystem) → true/false
Expand Down Expand Up @@ -336,23 +326,25 @@ and the `value` to set it to. This function works for any type of parameter cont
(array/dictionary/composite types) provided the `index` is appropriate type.
The `index` can be a traditional Julia index (integer for arrays, key for dictionaries,
or symbol for composite types). However, it can also be a symbolic variable.
or symbol for composite types). It can also be a symbolic variable.
This is valid only for dynamical systems referring a ModelingToolkit.jl model
which also has `index` as one of its parameters.
"""
set_parameter!(ds::DynamicalSystem, args...) = _set_parameter!(ds, current_parameters(ds), args...)

function _set_parameter!(ds::DynamicalSystem, p, index, value)
function set_parameter!(ds::DynamicalSystem, index, value, p = current_parameters(ds))
# internal function is necessary so that we are able to call `u_modified!` for ODEs.
_set_parameter!(ds::DynamicalSystem, index, value, p)
end
function _set_parameter!(ds::DynamicalSystem, index, value, p = current_parameters(ds))
prob, sys = referrenced_sciml_sys(ds)
if !has_referrenced_sys(prob)
if !has_referrenced_sys(sys)
if p isa Union{AbstractArray, AbstractDict}
setindex!(p, value, index)
else
setproperty!(p, index, value)
end
else
i = SymbolicIndexingInterface.setp(sys, index)
i(prob, value)
set! = SymbolicIndexingInterface.setp(sys, index)
set!(p, value)
end
return
end
Expand Down
2 changes: 1 addition & 1 deletion src/core_systems/continuous_time_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ end

# This is here to ensure that `u_modified!` is called
function set_parameter!(ds::CoupledODEs, args...)
_set_parameter!(ds, current_parameters(ds), args...)
_set_parameter!(ds, args...)
u_modified!(ds.integ, true)
return
end
Expand Down
4 changes: 0 additions & 4 deletions test/mtk_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ connected = compose(ODESystem(connections, name = :connected), fol_1, fol_2)

connected_simp = structural_simplify(connected)

using DynamicalSystemsBase.SymbolicIndexingInterface
# SymbolicIndexingInterface.observed(connected_simp, fol_1.τ)
# SymbolicIndexingInterface.state_values(arr::AbstractArray) = arr

u0 = [fol_1.x => -0.5,
fol_2.x => 1.0]

Expand Down

0 comments on commit 3ade5d1

Please sign in to comment.