diff --git a/Project.toml b/Project.toml index 5f8a183..3d23f4f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" -version = "0.1.5" +version = "0.2.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index 9e091c5..0e2122d 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -33,7 +33,7 @@ end isstochastic, event_times, - connection_index + ForeachConnectedSubsystem ) export @@ -232,6 +232,7 @@ add methods to this function if a subsystem or connection type has a discrete ev event_times(::Any) = () abstract type ConnectionRule end +Base.zero(::T) where {T <: ConnectionRule} = zero(T) struct NotConnected <: ConnectionRule end (::NotConnected)(l, r) = zero(promote_type(eltype(l), eltype(r))) struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected, AbstractMatrix{CR}}}}} @@ -245,23 +246,6 @@ Base.getindex(m::ConnectionMatrices, i) = m.matrices[i] Base.length(m::ConnectionMatrices) = length(m.matrices) Base.size(m::ConnectionMatrix{N}) where {N} = (N, N) -""" - connection_index(ConnType, M::ConnectionMatrices) - -give the first index `n` such that `M[n]` is a `ConnectionMatrix{N, ConnType} where {N}`, or throw an error if no such index exists. -""" -connection_index(::Type{ConnType}, M::ConnectionMatrices) where {ConnType} = _conn_index(ConnType, M.matrices, 1) -function _conn_index(::Type{ConnType}, tup::Tuple, i) where {ConnType} - if first(tup) isa ConnectionMatrix{N, ConnType} where {N} - return i - else - _conn_index(ConnType, Base.tail(tup), i+1) - end -end -@noinline _conn_index(::Type{ConnType}, ::Tuple{}, _) where {ConnType} = - error("ConnectionMatrices did not contain a ConnectionMatrix with connection type ", ConnType) - - abstract type GraphSystem end @kwdef struct ODEGraphSystem{CM <: ConnectionMatrices, S, P, EVT, CDEP, CCEP, Ns, SNM, PNM} <: GraphSystem diff --git a/src/graph_solve.jl b/src/graph_solve.jl index 4a7cf5a..232729a 100644 --- a/src/graph_solve.jl +++ b/src/graph_solve.jl @@ -257,11 +257,12 @@ function _continuous_affect!(integrator, sview = @view states_partitioned[i][j] pview = @view params_partitioned[i][j] sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) if continuous_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices) - apply_continuous_event!(integrator, sview, pview, sys, input) + apply_continuous_event!(integrator, sview, pview, sys, F, input) else - apply_continuous_event!(integrator, sview, pview, sys) + apply_continuous_event!(integrator, sview, pview, sys, F) end end offset += N @@ -326,34 +327,37 @@ end t) where {Len, NConn} quote @nexprs $Len i -> begin + # First we apply events to the states if has_discrete_events(eltype(states_partitioned[i])) - for j ∈ eachindex(states_partitioned[i]) - sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) - sview_dst = @view states_partitioned[i][j] - pview_dst = @view params_partitioned[i][j] - if discrete_event_condition(sys_dst, t) - if discrete_events_require_inputs(sys_dst) + @inbounds for j ∈ eachindex(states_partitioned[i]) + sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + sview = @view states_partitioned[i][j] + pview = @view params_partitioned[i][j] + if discrete_event_condition(sys, t) + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) + if discrete_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices) - apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst, input) + apply_discrete_event!(integrator, sview, pview, sys, F, input) else - apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst) + apply_discrete_event!(integrator, sview, pview, sys, F) end end end end + # Then we do the connection events @nexprs $NConn nc -> begin @nexprs $Len k -> begin f = _discrete_connection_affect!(Val(i), Val(k), Val(nc), t, states_partitioned, params_partitioned, connection_matrices, integrator) foreach(f, eachindex(states_partitioned[i])) - end end end end end + function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, states_partitioned::NTuple{Len, Any}, params_partitioned::NTuple{Len, Any}, @@ -397,3 +401,80 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t, end end end + + +#----------------------------------------------------------------------- + +""" + ForeachConnectedSubsystem + +This is a callable struct which takes in a function, and then calls that function on each subsystem which has a connection leading to it +from some previously specified subsystem. + +That is, writing +```julia +F = ForeachConnectedSubsystem{k}(l, states_partitioned, params_partitioned, connection_matrices) + +F() do conn, sys_dst, states_view_dst, params_view_dst + [...] +end +``` +is like a type stable version of writing +``` +for i in eachindex(states_partitioned) + for nc in eachindex(connection_matrices) + M = connection_matrices[nc][i, k] + for j in eachindex(states_partitioned[k]) + conn = M[l, j] + if !iszero(conn) + states_view_dst = @view states_partitioned[i][j] + params_view_dst = @view params_partitioned[i][j] + sys_dst = Subsystem(states_view_dst[], params_view_dst[]) + [...] # <------- User code here + ends + end + end +end +``` +""" +struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs} + l::Int + states_partitioned::S + params_partitioned::P + connection_matrices::CMs + function ForeachConnectedSubsystem{k}(l, + states_partitioned::NTuple{Len, Any}, + params_partitioned::NTuple{Len, Any}, + connection_matrices::ConnectionMatrices{NConn}) where {k, Len, NConn} + S = typeof(states_partitioned) + P = typeof(params_partitioned) + CMs = typeof(connection_matrices) + new{k, Len, NConn, S, P, CMs}(l, states_partitioned, params_partitioned, connection_matrices) + end +end + +@generated function ((;l, + states_partitioned, + params_partitioned, + connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F} + quote + @nexprs $Len i -> begin + @nexprs $NConn nc -> begin + M = connection_matrices[nc][k, i] + if M isa NotConnected + nothing + else + for j ∈ eachindex(states_partitioned[i]) + @inbounds conn = M[l, j] + if !iszero(conn) + @inbounds states_view_dst = @view states_partitioned[i][j] + @inbounds params_view_dst = @view params_partitioned[i][j] + sys_dst = Subsystem(states_view_dst[], params_view_dst[]) + f(conn, sys_dst, states_view_dst, params_view_dst) + end + end + end + end + end + end +end diff --git a/src/subsystems.jl b/src/subsystems.jl index 2a67466..d5dabed 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -198,12 +198,15 @@ Base.eltype(::Type{<:Subsystem{<:Any, T}}) where {T} = T #------------------------------------------------------------------------- +@generated function to_vec_o_states(state_data::NTuple{Len, Any}, ::Val{StateTypes}) where {Len, StateTypes} + state_types = StateTypes.parameters + Expr(:tuple, (:(VectorOfSubsystemStates{$(state_types[i])}(state_data[$i])) for i ∈ 1:Len)...) +end + struct VectorOfSubsystemStates{States, Mat <: AbstractMatrix} <: AbstractVector{States} data::Mat end -function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}( - v::AbstractMatrix{U} - ) where {Name, T, U, snames, Tup} +function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(v::AbstractMatrix{U}) where {Name, T, U, snames, Tup} V = promote_type(T,U) States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}} VectorOfSubsystemStates{States, typeof(v)}(v) @@ -217,8 +220,8 @@ Base.size(v::VectorOfSubsystemStates{States}) where {States} = (size(v.data, 2), @inbounds States(view(v.data, 1:l, idx)) end -@noinline function sym_not_found_error(::Type{SubsystemStates{Name, T, NamedTuple{names}}}, s::Symbol) where {Name, T, names} - error("SubsystemStates{$Name} does not have a field $s, valid fields are $names") +@noinline function sym_not_found_error(::Type{S}, s::Symbol) where {S<:SubsystemStates} + error("$S does not have a field $s") end @propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, s::Symbol, idx::Integer) where {States <: SubsystemStates} @@ -247,7 +250,43 @@ end v.data[i, idx] = val end -@generated function to_vec_o_states(state_data::NTuple{Len, Any}, ::Val{StateTypes}) where {Len, StateTypes} - state_types = StateTypes.parameters - Expr(:tuple, (:(VectorOfSubsystemStates{$(state_types[i])}(state_data[$i])) for i ∈ 1:Len)...) + + +#------------------------------------------------------------------------- +struct SubsystemStatesView{States, Mat <: AbstractMatrix} <: AbstractArray{States, 0} + data::Mat + idx::Int +end +@propagate_inbounds function Base.view(v::VectorOfSubsystemStates{States, Mat}, idx::Int) where {States, Mat} + l = length(States) + @boundscheck checkbounds(v.data, 1:l, idx) + SubsystemStatesView{States, Mat}(v.data, idx) +end +Base.size(::SubsystemStatesView) = () +function Base.getindex(v::SubsystemStatesView{States}) where {States <: SubsystemStates} + l = length(States) + @inbounds States(view(v.data, 1:l, v.idx)) +end +@propagate_inbounds function Base.getindex(v::SubsystemStatesView{States}, s::Symbol) where {States <: SubsystemStates} + i = state_ind(States, s) + if isnothing(i) + sym_not_found_error(States, s) + end + @inbounds v.data[i, v.idx] +end + +@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, state::States) where {States <: SubsystemStates} + l = length(States) + idx = v.idx + @inbounds v.data[1:l, idx] .= Tuple(state) + v +end + +@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, val, s::Symbol) where {States <: SubsystemStates} + i = state_ind(States, s) + if isnothing(i) + sym_not_found_error(States, s) + end + @inbounds v.data[i, v.idx] = val + v end diff --git a/src/utils.jl b/src/utils.jl index 4b79c6c..2d295fa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ valueof(x) = x # this just makes it so that I can easily replace all uses of `@inbounds ex` with just `ex`. macro inbounds(ex) - # ex + #esc(ex) esc(:($Base.@inbounds $ex)) end