diff --git a/src/Catalyst.jl b/src/Catalyst.jl index 2c5c6003a2..21489620da 100644 --- a/src/Catalyst.jl +++ b/src/Catalyst.jl @@ -31,7 +31,8 @@ using ModelingToolkit: Symbolic, value, get_unknowns, get_ps, get_iv, get_system import ModelingToolkit: get_variables, namespace_expr, namespace_equation, get_variables!, modified_unknowns!, validate, namespace_variables, namespace_parameters, rename, renamespace, getname, flatten, - is_alg_equation, is_diff_equation + is_alg_equation, is_diff_equation, collect_vars!, + eqtype_supports_collect_vars # internal but needed ModelingToolkit functions import ModelingToolkit: check_variables, diff --git a/src/reaction.jl b/src/reaction.jl index e6202a73ac..efc96563ac 100644 --- a/src/reaction.jl +++ b/src/reaction.jl @@ -348,6 +348,30 @@ end MT.is_diff_equation(rx::Reaction) = false MT.is_alg_equation(rx::Reaction) = false +# MTK functions for extracting variables within equation type object +MT.eqtype_supports_collect_vars(rx::Reaction) = true +function MT.collect_vars!(unknowns, parameters, rx::Reaction, iv; depth = 0, + op = MT.Differential) + MT.collect_vars!(unknowns, parameters, rx.rate, iv; depth, op) + for sub in rx.substrates + MT.collect_vars!(unknowns, parameters, sub, iv; depth, op) + end + for prod in rx.products + MT.collect_vars!(unknowns, parameters, prod, iv; depth, op) + end + for substoich in rx.substoich + MT.collect_vars!(unknowns, parameters, substoich, iv; depth, op) + end + for prodstoich in rx.prodstoich + MT.collect_vars!(unknowns, parameters, prodstoich, iv; depth, op) + end + if hasnoisescaling(rx) + ns = getnoisescaling(rx) + MT.collect_vars!(unknowns, parameters, ns, iv; depth, op) + end + return nothing +end + """ get_symbolics(set, rx::Reaction) diff --git a/test/reactionsystem_core/reaction.jl b/test/reactionsystem_core/reaction.jl index 3d3b5396f7..596a68a433 100644 --- a/test/reactionsystem_core/reaction.jl +++ b/test/reactionsystem_core/reaction.jl @@ -3,7 +3,7 @@ # Fetch packages. using Catalyst, Test using Catalyst: get_symbolics -using ModelingToolkit: value, get_variables! +using ModelingToolkit: value, get_variables!, collect_vars!, eqtype_supports_collect_vars # Sets the default `t` to use. t = default_t() @@ -235,4 +235,20 @@ let @test Catalyst.hasmisc(r2) @test_throws Exception Catalyst.getmisc(r1) @test isequal(Catalyst.getmisc(r2), ('M', :M)) +end + +# tests for collect_vars! +let + t = default_t() + @variables E(t) F(t) + @species A(t) B(t) C(t) D(t) + @parameters k1, k2, η + + rx = Reaction(k1*E, [A, B], [C], [k2*D, 3], [F], metadata = [:noise_scaling => η]) + us = Set() + ps = Set() + @test eqtype_supports_collect_vars(rx) == true + collect_vars!(us, ps, rx, t) + @test us == Set((A, B, C, D, E, F)) + @test ps == Set((k1, k2, η)) end \ No newline at end of file