Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ForeachConnectedSubsystem #9

Merged
merged 5 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GraphDynamics"
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
version = "0.1.5"
version = "0.2.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
20 changes: 2 additions & 18 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
isstochastic,

event_times,
connection_index
ForeachConnectedSubsystem
)

export
Expand Down Expand Up @@ -232,6 +232,7 @@ add methods to this function if a subsystem or connection type has a discrete ev
event_times(::Any) = ()

abstract type ConnectionRule end
Base.zero(::T) where {T <: ConnectionRule} = zero(T)
struct NotConnected <: ConnectionRule end
(::NotConnected)(l, r) = zero(promote_type(eltype(l), eltype(r)))
struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected, AbstractMatrix{CR}}}}}
Expand All @@ -245,23 +246,6 @@ Base.getindex(m::ConnectionMatrices, i) = m.matrices[i]
Base.length(m::ConnectionMatrices) = length(m.matrices)
Base.size(m::ConnectionMatrix{N}) where {N} = (N, N)

"""
connection_index(ConnType, M::ConnectionMatrices)

give the first index `n` such that `M[n]` is a `ConnectionMatrix{N, ConnType} where {N}`, or throw an error if no such index exists.
"""
connection_index(::Type{ConnType}, M::ConnectionMatrices) where {ConnType} = _conn_index(ConnType, M.matrices, 1)
function _conn_index(::Type{ConnType}, tup::Tuple, i) where {ConnType}
if first(tup) isa ConnectionMatrix{N, ConnType} where {N}
return i
else
_conn_index(ConnType, Base.tail(tup), i+1)
end
end
@noinline _conn_index(::Type{ConnType}, ::Tuple{}, _) where {ConnType} =
error("ConnectionMatrices did not contain a ConnectionMatrix with connection type ", ConnType)


abstract type GraphSystem end

@kwdef struct ODEGraphSystem{CM <: ConnectionMatrices, S, P, EVT, CDEP, CCEP, Ns, SNM, PNM} <: GraphSystem
Expand Down
103 changes: 92 additions & 11 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ function _continuous_affect!(integrator,
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if continuous_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_continuous_event!(integrator, sview, pview, sys, input)
apply_continuous_event!(integrator, sview, pview, sys, F, input)
else
apply_continuous_event!(integrator, sview, pview, sys)
apply_continuous_event!(integrator, sview, pview, sys, F)
end
end
offset += N
Expand Down Expand Up @@ -326,34 +327,37 @@ end
t) where {Len, NConn}
quote
@nexprs $Len i -> begin
# First we apply events to the states
if has_discrete_events(eltype(states_partitioned[i]))
for j ∈ eachindex(states_partitioned[i])
sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview_dst = @view states_partitioned[i][j]
pview_dst = @view params_partitioned[i][j]
if discrete_event_condition(sys_dst, t)
if discrete_events_require_inputs(sys_dst)
@inbounds for j ∈ eachindex(states_partitioned[i])
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
if discrete_event_condition(sys, t)
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if discrete_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst, input)
apply_discrete_event!(integrator, sview, pview, sys, F, input)
else
apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst)
apply_discrete_event!(integrator, sview, pview, sys, F)
end
end
end
end
# Then we do the connection events
@nexprs $NConn nc -> begin
@nexprs $Len k -> begin
f = _discrete_connection_affect!(Val(i), Val(k), Val(nc), t,
states_partitioned, params_partitioned, connection_matrices,
integrator)
foreach(f, eachindex(states_partitioned[i]))

end
end
end
end
end


function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
Expand Down Expand Up @@ -397,3 +401,80 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
end
end
end


#-----------------------------------------------------------------------

"""
ForeachConnectedSubsystem

This is a callable struct which takes in a function, and then calls that function on each subsystem which has a connection leading to it
from some previously specified subsystem.

That is, writing
```julia
F = ForeachConnectedSubsystem{k}(l, states_partitioned, params_partitioned, connection_matrices)

F() do conn, sys_dst, states_view_dst, params_view_dst
[...]
end
```
is like a type stable version of writing
```
for i in eachindex(states_partitioned)
for nc in eachindex(connection_matrices)
M = connection_matrices[nc][i, k]
for j in eachindex(states_partitioned[k])
conn = M[l, j]
if !iszero(conn)
states_view_dst = @view states_partitioned[i][j]
params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
[...] # <------- User code here
ends
end
end
end
```
"""
struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs}
l::Int
states_partitioned::S
params_partitioned::P
connection_matrices::CMs
function ForeachConnectedSubsystem{k}(l,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
connection_matrices::ConnectionMatrices{NConn}) where {k, Len, NConn}
S = typeof(states_partitioned)
P = typeof(params_partitioned)
CMs = typeof(connection_matrices)
new{k, Len, NConn, S, P, CMs}(l, states_partitioned, params_partitioned, connection_matrices)
end
end

@generated function ((;l,
states_partitioned,
params_partitioned,
connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F}
quote
@nexprs $Len i -> begin
@nexprs $NConn nc -> begin
M = connection_matrices[nc][k, i]
if M isa NotConnected
nothing
else
for j ∈ eachindex(states_partitioned[i])
@inbounds conn = M[l, j]
if !iszero(conn)
@inbounds states_view_dst = @view states_partitioned[i][j]
@inbounds params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
f(conn, sys_dst, states_view_dst, params_view_dst)
end
end
end
end
end
end
end
55 changes: 47 additions & 8 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,15 @@ Base.eltype(::Type{<:Subsystem{<:Any, T}}) where {T} = T

#-------------------------------------------------------------------------

@generated function to_vec_o_states(state_data::NTuple{Len, Any}, ::Val{StateTypes}) where {Len, StateTypes}
state_types = StateTypes.parameters
Expr(:tuple, (:(VectorOfSubsystemStates{$(state_types[i])}(state_data[$i])) for i ∈ 1:Len)...)
end

struct VectorOfSubsystemStates{States, Mat <: AbstractMatrix} <: AbstractVector{States}
data::Mat
end
function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(
v::AbstractMatrix{U}
) where {Name, T, U, snames, Tup}
function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(v::AbstractMatrix{U}) where {Name, T, U, snames, Tup}
V = promote_type(T,U)
States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}}
VectorOfSubsystemStates{States, typeof(v)}(v)
Expand All @@ -217,8 +220,8 @@ Base.size(v::VectorOfSubsystemStates{States}) where {States} = (size(v.data, 2),
@inbounds States(view(v.data, 1:l, idx))
end

@noinline function sym_not_found_error(::Type{SubsystemStates{Name, T, NamedTuple{names}}}, s::Symbol) where {Name, T, names}
error("SubsystemStates{$Name} does not have a field $s, valid fields are $names")
@noinline function sym_not_found_error(::Type{S}, s::Symbol) where {S<:SubsystemStates}
error("$S does not have a field $s")
end

@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, s::Symbol, idx::Integer) where {States <: SubsystemStates}
Expand Down Expand Up @@ -247,7 +250,43 @@ end
v.data[i, idx] = val
end

@generated function to_vec_o_states(state_data::NTuple{Len, Any}, ::Val{StateTypes}) where {Len, StateTypes}
state_types = StateTypes.parameters
Expr(:tuple, (:(VectorOfSubsystemStates{$(state_types[i])}(state_data[$i])) for i ∈ 1:Len)...)


#-------------------------------------------------------------------------
struct SubsystemStatesView{States, Mat <: AbstractMatrix} <: AbstractArray{States, 0}
data::Mat
idx::Int
end
@propagate_inbounds function Base.view(v::VectorOfSubsystemStates{States, Mat}, idx::Int) where {States, Mat}
l = length(States)
@boundscheck checkbounds(v.data, 1:l, idx)
SubsystemStatesView{States, Mat}(v.data, idx)
end
Base.size(::SubsystemStatesView) = ()
function Base.getindex(v::SubsystemStatesView{States}) where {States <: SubsystemStates}
l = length(States)
@inbounds States(view(v.data, 1:l, v.idx))
end
@propagate_inbounds function Base.getindex(v::SubsystemStatesView{States}, s::Symbol) where {States <: SubsystemStates}
i = state_ind(States, s)
if isnothing(i)
sym_not_found_error(States, s)
end
@inbounds v.data[i, v.idx]
end

@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, state::States) where {States <: SubsystemStates}
l = length(States)
idx = v.idx
@inbounds v.data[1:l, idx] .= Tuple(state)
v
end

@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, val, s::Symbol) where {States <: SubsystemStates}
i = state_ind(States, s)
if isnothing(i)
sym_not_found_error(States, s)
end
@inbounds v.data[i, v.idx] = val
v
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ valueof(x) = x

# this just makes it so that I can easily replace all uses of `@inbounds ex` with just `ex`.
macro inbounds(ex)
# ex
#esc(ex)
esc(:($Base.@inbounds $ex))
end

Expand Down
Loading