Skip to content

Commit

Permalink
Merge pull request #929 from SciML/fix_expand_funcs
Browse files Browse the repository at this point in the history
[v14 - Ready] Fix bug where `expand_registered_functions` mutated original reaction system
  • Loading branch information
TorkelE authored Jun 12, 2024
2 parents 6ccbc1f + 2f07b71 commit 8f8290c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 11 deletions.
69 changes: 59 additions & 10 deletions src/registered_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,37 @@ function Symbolics.derivative(::typeof(hillar), args::NTuple{5, Any}, ::Val{5})
(args[1]^args[5] + args[2]^args[5] + args[4]^args[5])^2
end

# Tuple storing all registered function (for use in various functionalities).
const registered_funcs = (mm, mmr, hill, hillr, hillar)

### Custom CRN FUnction-related Functions ###

"""
expand_registered_functions(expr)
expand_registered_functions(in)
Takes an expression, and expands registered function expressions. E.g. `mm(X,v,K)` is replaced with v*X/(X+K). Currently supported functions: `mm`, `mmr`, `hill`, `hillr`, and `hill`.
Takes an expression, and expands registered function expressions. E.g. `mm(X,v,K)` is replaced
with v*X/(X+K). Currently supported functions: `mm`, `mmr`, `hill`, `hillr`, and `hill`. Can
be applied to a reaction system, a reaction, an equation, or a symbolic expression. The input
is not modified, while an output with any functions expanded is returned. If applied to a
reaction system model, any cached network properties are reset.
"""
function expand_registered_functions(expr)
iscall(expr) || return expr
if hasnode(is_catalyst_function, expr)
expr = replacenode(expr, expand_catalyst_function)
end
return expr
end

# Checks whether an expression corresponds to a catalyst function call (e.g. `mm(X,v,K)`).
function is_catalyst_function(expr)
iscall(expr) || (return false)
return operation(expr) in registered_funcs
end

# If the input expression corresponds to a catalyst function call (e.g. `mm(X,v,K)`), returns
# it in its expanded form. If not, returns the input expression.
function expand_catalyst_function(expr)
is_catalyst_function(expr) || (return expr)
args = arguments(expr)
if operation(expr) == Catalyst.mm
return args[2] * args[1] / (args[1] + args[3])
Expand All @@ -131,23 +153,50 @@ function expand_registered_functions(expr)
return args[3] * (args[1]^args[5]) /
((args[1])^args[5] + (args[2])^args[5] + (args[4])^args[5])
end
for i in 1:length(args)
args[i] = expand_registered_functions(args[i])
end
return expr
end

# If applied to a Reaction, return a reaction with its rate modified.
function expand_registered_functions(rx::Reaction)
Reaction(expand_registered_functions(rx.rate), rx.substrates, rx.products,
rx.substoich, rx.prodstoich, rx.netstoich, rx.only_use_rate, rx.metadata)
end
# If applied to a Equation, returns it with it applied to lhs and rhs

# If applied to a Equation, returns it with it applied to lhs and rhs.
function expand_registered_functions(eq::Equation)
return expand_registered_functions(eq.lhs) ~ expand_registered_functions(eq.rhs)
end

# If applied to a continuous event, returns it applied to eqs and affect.
function expand_registered_functions(ce::ModelingToolkit.SymbolicContinuousCallback)
eqs = expand_registered_functions(ce.eqs)
affect = expand_registered_functions(ce.affect)
return ModelingToolkit.SymbolicContinuousCallback(eqs, affect)
end

# If applied to a discrete event, returns it applied to condition and affects.
function expand_registered_functions(de::ModelingToolkit.SymbolicDiscreteCallback)
condition = expand_registered_functions(de.condition)
affects = expand_registered_functions(de.affects)
return ModelingToolkit.SymbolicDiscreteCallback(condition, affects)
end

# If applied to a vector, applies it to every element in the vector.
function expand_registered_functions(vec::Vector)
return [Catalyst.expand_registered_functions(element) for element in vec]
end

# If applied to a ReactionSystem, applied function to all Reactions and other Equations, and return updated system.
# Currently, `ModelingToolkit.has_X_events` returns `true` even if event vector is empty (hence
# this function cannot be used).
function expand_registered_functions(rs::ReactionSystem)
@set! rs.eqs = [Catalyst.expand_registered_functions(eq) for eq in get_eqs(rs)]
@set! rs.rxs = [Catalyst.expand_registered_functions(rx) for rx in get_rxs(rs)]
@set! rs.eqs = Catalyst.expand_registered_functions(get_eqs(rs))
@set! rs.rxs = Catalyst.expand_registered_functions(get_rxs(rs))
if !isempty(ModelingToolkit.get_continuous_events(rs))
@set! rs.continuous_events = Catalyst.expand_registered_functions(ModelingToolkit.get_continuous_events(rs))
end
if !isempty(ModelingToolkit.get_discrete_events(rs))
@set! rs.discrete_events = Catalyst.expand_registered_functions(ModelingToolkit.get_discrete_events(rs))
end
reset_networkproperties!(rs)
return rs
end
62 changes: 61 additions & 1 deletion test/reactionsystem_core/custom_crn_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Fetch packages.
using Catalyst, Test
using ModelingToolkit: get_continuous_events, get_discrete_events
using Symbolics: derivative

# Sets stable rng number.
Expand Down Expand Up @@ -154,4 +155,63 @@ let
@test isequal(Catalyst.expand_registered_functions(eq3), 0 ~ V * (X^N) / (X^N + K^N))
@test isequal(Catalyst.expand_registered_functions(eq4), 0 ~ V * (K^N) / (X^N + K^N))
@test isequal(Catalyst.expand_registered_functions(eq5), 0 ~ V * (X^N) / (X^N + Y^N + K^N))
end
end

# Ensures that original system is not modified.
let
# Create model with a registered function.
@species X(t)
@variables V(t)
@parameters v K
eqs = [
Reaction(mm(X,v,K), [], [X]),
mm(V,v,K) ~ V + 1
]
@named rs = ReactionSystem(eqs, t)

# Check that `expand_registered_functions` does not mutate original model.
rs_expanded_funcs = Catalyst.expand_registered_functions(rs)
@test isequal(only(Catalyst.get_rxs(rs)).rate, Catalyst.mm(X,v,K))
@test isequal(only(Catalyst.get_rxs(rs_expanded_funcs)).rate, v*X/(X + K))
@test isequal(last(Catalyst.get_eqs(rs)).lhs, Catalyst.mm(V,v,K))
@test isequal(last(Catalyst.get_eqs(rs_expanded_funcs)).lhs, v*V/(V + K))
end

# Tests on model with events.
let
# Creates a model, saves it, and creates an expanded version.
rs = @reaction_network begin
@continuous_events begin
[mm(X,v,K) ~ 1.0] => [X ~ X]
end
@discrete_events begin
[1.0] => [X ~ mmr(X,v,K) + Y*(v + K)]
1.0 => [X ~ X]
(hill(X,v,K,n) > 1000.0) => [X ~ hillr(X,v,K,n) + 2]
end
v0 + hillar(X,Y,v,K,n), X --> Y
end
rs_saved = deepcopy(rs)
rs_expanded = Catalyst.expand_registered_functions(rs)

# Checks that the original model is unchanged (equality currently does not consider events).
@test rs == rs_saved
@test get_continuous_events(rs) == get_continuous_events(rs_saved)
@test get_discrete_events(rs) == get_discrete_events(rs_saved)

# Checks that the new system is expanded.
@unpack v0, X, Y, v, K, n = rs
continuous_events = [
[v*X/(X + K) ~ 1.0] => [X ~ X]
]
discrete_events = [
[1.0] => [X ~ v*K/(X + K) + Y*(v + K)]
1.0 => [X ~ X]
(v * (X^n) / (X^n + K^n) > 1000.0) => [X ~ v * (K^n) / (X^n + K^n) + 2]
]
continuous_events = ModelingToolkit.SymbolicContinuousCallback.(continuous_events)
discrete_events = ModelingToolkit.SymbolicDiscreteCallback.(discrete_events)
@test isequal(only(Catalyst.get_rxs(rs_expanded)).rate, v0 + v * (X^n) / (X^n + Y^n + K^n))
@test isequal(get_continuous_events(rs_expanded), continuous_events)
@test isequal(get_discrete_events(rs_expanded), discrete_events)
end

0 comments on commit 8f8290c

Please sign in to comment.