Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @no_infer flag for turning off species/variable/parameter inferring #1122

Merged
merged 18 commits into from
Nov 20, 2024
2 changes: 1 addition & 1 deletion src/Catalyst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export symmap_to_varmap
# reaction_network macro
include("expression_utils.jl")
include("dsl.jl")
export @reaction_network, @network_component, @reaction, @species
export @reaction_network, @network_component, @reaction, @species, UndeclaredSymbolicError
vyudu marked this conversation as resolved.
Show resolved Hide resolved

# Network analysis functionality.
include("network_analysis.jl")
Expand Down
47 changes: 38 additions & 9 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ const pure_rate_arrows = Set{Symbol}([:(=>), :(<=), :⇐, :⟽, :⇒, :⟾, :⇔
# Declares the keys used for various options.
const option_keys = (:species, :parameters, :variables, :ivs, :compounds, :observables,
:default_noise_scaling, :differentials, :equations,
:continuous_events, :discrete_events, :combinatoric_ratelaws)
:continuous_events, :discrete_events, :combinatoric_ratelaws, :require_declaration)

### `@species` Macro ###

Expand Down Expand Up @@ -230,6 +230,9 @@ struct ReactionStruct
end
end

#function Base.show(io::IO, rx::ReactionStruct) #
#end

# Recursive function that loops through the reaction line and finds the reactants and their
# stoichiometry. Recursion makes it able to handle weird cases like 2(X+Y+3(Z+XY)).
function recursive_find_reactants!(ex::ExprValues, mult::ExprValues,
Expand Down Expand Up @@ -283,6 +286,17 @@ function extract_metadata(metadata_line::Expr)
return metadata
end



struct UndeclaredSymbolicError <: Exception
msg::String
end

function Base.showerror(io::IO, err::UndeclaredSymbolicError)
print(io, "UndeclaredSymbolicError: ")
print(io, err.msg)
end
isaacsas marked this conversation as resolved.
Show resolved Hide resolved

### DSL Internal Master Function ###

# Function for creating a ReactionSystem structure (used by the @reaction_network macro).
Expand All @@ -308,6 +322,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
compound_expr, compound_species = read_compound_options(options)
continuous_events_expr = read_events_option(options, :continuous_events)
discrete_events_expr = read_events_option(options, :discrete_events)
requiredec = haskey(options, :require_declaration)

# Parses reactions, species, and parameters.
reactions = get_reactions(reaction_lines)
Expand All @@ -317,7 +332,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Reads equations.
vars_extracted, add_default_diff, equations = read_equations_options(
options, variables_declared)
options, variables_declared; requiredec)
variables = vcat(variables_declared, vars_extracted)

# Handle independent variables
Expand All @@ -341,13 +356,13 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs)
options, [species_declared; variables], all_ivs; requiredec)

# Collect species and parameters, including ones inferred from the reactions.
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
variables)))
species_extracted, parameters_extracted = extract_species_and_parameters!(
reactions, declared_syms)
reactions, declared_syms; requiredec)

species = vcat(species_declared, species_extracted)
parameters = vcat(parameters_declared, parameters_extracted)
Expand Down Expand Up @@ -511,20 +526,26 @@ end

# Function looping through all reactions, to find undeclared symbols (species or
# parameters), and assign them to the right category.
function extract_species_and_parameters!(reactions, excluded_syms)
function extract_species_and_parameters!(reactions, excluded_syms; requiredec = false)
species = OrderedSet{Union{Symbol, Expr}}()
for reaction in reactions
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
add_syms_from_expr!(species, reactant.reactant, excluded_syms)
(!isempty(species) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized variables $(species[1]) detected in reaction expression. Since the flag @require_declaration is declared, all species must be explicitly declared with the @species macro."))
end
end

foreach(s -> push!(excluded_syms, s), species)
parameters = OrderedSet{Union{Symbol, Expr}}()
for reaction in reactions
add_syms_from_expr!(parameters, reaction.rate, excluded_syms)
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized parameter $(parameters[1]) detected in rate expression $(reaction.rate). Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
add_syms_from_expr!(parameters, reactant.stoichiometry, excluded_syms)
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized parameters $(parameters[1]) detected in the stoichiometry for reactant $(reactant.reactant). Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
end
end

Expand Down Expand Up @@ -682,7 +703,7 @@ end
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
# `equations`: a vector with the equations provided.
function read_equations_options(options, variables_declared)
function read_equations_options(options, variables_declared; requiredec = false)
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
Expand Down Expand Up @@ -711,9 +732,13 @@ function read_equations_options(options, variables_declared)
diff_var = lhs.args[2]
if in(diff_var, forbidden_symbols_error)
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
elseif (!in(diff_var, variables_declared)) && requiredec
throw(UndeclaredSymbolicError(
"Unrecognized symbol $(diff_var) was used as a variable in an equation. Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
else
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
end

Expand Down Expand Up @@ -752,7 +777,7 @@ function create_differential_expr(options, add_default_diff, used_syms, tiv)
end

# Reads the observables options. Outputs an expression ofr creating the observable variables, and a vector of observable equations.
function read_observed_options(options, species_n_vars_declared, ivs_sorted)
function read_observed_options(options, species_n_vars_declared, ivs_sorted; requiredec = false)
if haskey(options, :observables)
# Gets list of observable equations and prepares variable declaration expression.
# (`options[:observables]` includes `@observables`, `.args[3]` removes this part)
Expand All @@ -763,6 +788,10 @@ function read_observed_options(options, species_n_vars_declared, ivs_sorted)
for (idx, obs_eq) in enumerate(observed_eqs.args)
# Extract the observable, checks errors, and continues the loop if the observable has been declared.
obs_name, ivs, defaults, metadata = find_varinfo_in_declaration(obs_eq.args[2])
if (requiredec && !in(obs_name, species_n_vars_declared))
throw(UndeclaredSymbolicError(
"An undeclared variable ($obs_name) was declared as an observable. Since the flag @require_declaration is set, all variables must be declared with the @species, @parameters, or @variables macros."))
end
isempty(ivs) ||
error("An observable ($obs_name) was given independent variable(s). These should not be given, as they are inferred automatically.")
isnothing(defaults) ||
Expand Down
69 changes: 69 additions & 0 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1022,3 +1022,72 @@ let
@parameters v n
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
end

### test that @no_infer properly throws errors when undeclared variables are written

let
# Test error when species are inferred
@test_throws UndeclaredSymbolicError @macroexpand rn = @reaction_network begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, rn = is not needed in these cases

@require_declaration
@parameters k
k, A --> B
end
@test_nowarn @macroexpand rn = @reaction_network begin
@require_declaration
@species A(t) B(t)
@parameters k
k, A --> B
end

# Test error when a parameter in rate is inferred
@test_throws UndeclaredSymbolicError @macroexpand rn = @reaction_network begin
@require_declaration
@species A(t) B(t)
@parameters k
k*n, A --> B
end
@test_nowarn @macroexpand rn = @reaction_network begin
@require_declaration
@parameters n k
@species A(t) B(t)
k*n, A --> B
end

# Test error when a parameter in stoichiometry is inferred
@test_throws UndeclaredSymbolicError @macroexpand rn = @reaction_network begin
@require_declaration
@parameters k
@species A(t) B(t)
k, n*A --> B
end
@test_nowarn @macroexpand rn = @reaction_network begin
@require_declaration
@parameters k n
@species A(t) B(t)
k, n*A --> B
end

# Test error when a variable in an equation is inferred
@test_throws UndeclaredSymbolicError @macroexpand rn = @reaction_network begin
@require_declaration
@equations D(V) ~ V^2
end
@test_nowarn @macroexpand rn = @reaction_network begin
@require_declaration
@variables V(t)
@equations D(V) ~ V^2
end

# Test error when a variable in an observable is inferred
@test_throws UndeclaredSymbolicError @macroexpand rn = @reaction_network begin
@require_declaration
@variables X1(t)
@observables X2 ~ X1
end
@test_nowarn @macroexpand rn = @reaction_network begin
@require_declaration
@variables X1(t) X2(t)
@observables X2 ~ X1
end
end

Loading