Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @no_infer flag for turning off species/variable/parameter inferring #1122

Merged
merged 18 commits into from
Nov 20, 2024
61 changes: 44 additions & 17 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ const pure_rate_arrows = Set{Symbol}([:(=>), :(<=), :⇐, :⟽, :⇒, :⟾, :⇔
# Declares the keys used for various options.
const option_keys = (:species, :parameters, :variables, :ivs, :compounds, :observables,
:default_noise_scaling, :differentials, :equations,
:continuous_events, :discrete_events, :combinatoric_ratelaws)
:continuous_events, :discrete_events, :combinatoric_ratelaws, :require_declaration)

### `@species` Macro ###

Expand Down Expand Up @@ -220,13 +220,14 @@ struct ReactionStruct
products::Vector{ReactantStruct}
rate::ExprValues
metadata::Expr
rxexpr::Expr

function ReactionStruct(sub_line::ExprValues, prod_line::ExprValues, rate::ExprValues,
metadata_line::ExprValues)
metadata_line::ExprValues, rx_line::Expr)
sub = recursive_find_reactants!(sub_line, 1, Vector{ReactantStruct}(undef, 0))
prod = recursive_find_reactants!(prod_line, 1, Vector{ReactantStruct}(undef, 0))
metadata = extract_metadata(metadata_line)
new(sub, prod, rate, metadata)
new(sub, prod, rate, metadata, rx_line)
end
end

Expand Down Expand Up @@ -283,6 +284,17 @@ function extract_metadata(metadata_line::Expr)
return metadata
end



struct UndeclaredSymbolicError <: Exception
msg::String
end

function Base.showerror(io::IO, err::UndeclaredSymbolicError)
print(io, "UndeclaredSymbolicError: ")
print(io, err.msg)
end
isaacsas marked this conversation as resolved.
Show resolved Hide resolved

### DSL Internal Master Function ###

# Function for creating a ReactionSystem structure (used by the @reaction_network macro).
Expand All @@ -308,6 +320,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
compound_expr, compound_species = read_compound_options(options)
continuous_events_expr = read_events_option(options, :continuous_events)
discrete_events_expr = read_events_option(options, :discrete_events)
requiredec = haskey(options, :require_declaration)

# Parses reactions, species, and parameters.
reactions = get_reactions(reaction_lines)
Expand All @@ -317,7 +330,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Reads equations.
vars_extracted, add_default_diff, equations = read_equations_options(
options, variables_declared)
options, variables_declared; requiredec)
variables = vcat(variables_declared, vars_extracted)

# Handle independent variables
Expand All @@ -341,13 +354,13 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs)
options, [species_declared; variables], all_ivs; requiredec)

# Collect species and parameters, including ones inferred from the reactions.
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
variables)))
species_extracted, parameters_extracted = extract_species_and_parameters!(
reactions, declared_syms)
reactions, declared_syms; requiredec)

species = vcat(species_declared, species_extracted)
parameters = vcat(parameters_declared, parameters_extracted)
Expand Down Expand Up @@ -425,15 +438,15 @@ function get_reactions(exprs::Vector{Expr}, reactions = Vector{ReactionStruct}(u
error("Error: Must provide a tuple of reaction rates when declaring a bi-directional reaction.")
end
push_reactions!(reactions, reaction.args[2], reaction.args[3],
rate.args[1], metadata.args[1], arrow)
rate.args[1], metadata.args[1], arrow, line)
push_reactions!(reactions, reaction.args[3], reaction.args[2],
rate.args[2], metadata.args[2], arrow)
rate.args[2], metadata.args[2], arrow, line)
elseif in(arrow, fwd_arrows)
push_reactions!(reactions, reaction.args[2], reaction.args[3],
rate, metadata, arrow)
rate, metadata, arrow, line)
elseif in(arrow, bwd_arrows)
push_reactions!(reactions, reaction.args[3], reaction.args[2],
rate, metadata, arrow)
rate, metadata, arrow, line)
else
throw("Malformed reaction, invalid arrow type used in: $(MacroTools.striplines(line))")
end
Expand Down Expand Up @@ -467,7 +480,7 @@ end
# Takes a reaction line and creates reaction(s) from it and pushes those to the reaction array.
# Used to create multiple reactions from, for instance, `k, (X,Y) --> 0`.
function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues,
prod_line::ExprValues, rate::ExprValues, metadata::ExprValues, arrow::Symbol)
prod_line::ExprValues, rate::ExprValues, metadata::ExprValues, arrow::Symbol, line::Expr)
# The rates, substrates, products, and metadata may be in a tupple form (e.g. `k, (X,Y) --> 0`).
# This finds the length of these tuples (or 1 if not in tuple forms). Errors if lengs inconsistent.
lengs = (tup_leng(sub_line), tup_leng(prod_line), tup_leng(rate), tup_leng(metadata))
Expand All @@ -490,7 +503,7 @@ function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues

push!(reactions,
ReactionStruct(get_tup_arg(sub_line, i),
get_tup_arg(prod_line, i), get_tup_arg(rate, i), metadata_i))
get_tup_arg(prod_line, i), get_tup_arg(rate, i), metadata_i, line))
end
end

Expand All @@ -511,20 +524,26 @@ end

# Function looping through all reactions, to find undeclared symbols (species or
# parameters), and assign them to the right category.
function extract_species_and_parameters!(reactions, excluded_syms)
function extract_species_and_parameters!(reactions, excluded_syms; requiredec = false)
species = OrderedSet{Union{Symbol, Expr}}()
for reaction in reactions
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
add_syms_from_expr!(species, reactant.reactant, excluded_syms)
(!isempty(species) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized variables $(join(species, ", ")) detected in reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all species must be explicitly declared with the @species macro."))
end
end

foreach(s -> push!(excluded_syms, s), species)
parameters = OrderedSet{Union{Symbol, Expr}}()
for reaction in reactions
add_syms_from_expr!(parameters, reaction.rate, excluded_syms)
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized parameter $(join(parameters, ", ")) detected in rate expression: $(reaction.rate) for the following reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
add_syms_from_expr!(parameters, reactant.stoichiometry, excluded_syms)
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized parameters $(join(parameters, ", ")) detected in the stoichiometry for reactant $(reactant.reactant) in the following reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
end
end

Expand Down Expand Up @@ -682,7 +701,7 @@ end
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
# `equations`: a vector with the equations provided.
function read_equations_options(options, variables_declared)
function read_equations_options(options, variables_declared; requiredec = false)
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
Expand Down Expand Up @@ -711,9 +730,13 @@ function read_equations_options(options, variables_declared)
diff_var = lhs.args[2]
if in(diff_var, forbidden_symbols_error)
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
elseif (!in(diff_var, variables_declared)) && requiredec
throw(UndeclaredSymbolicError(
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
else
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
end

Expand Down Expand Up @@ -752,7 +775,7 @@ function create_differential_expr(options, add_default_diff, used_syms, tiv)
end

# Reads the observables options. Outputs an expression ofr creating the observable variables, and a vector of observable equations.
function read_observed_options(options, species_n_vars_declared, ivs_sorted)
function read_observed_options(options, species_n_vars_declared, ivs_sorted; requiredec = false)
if haskey(options, :observables)
# Gets list of observable equations and prepares variable declaration expression.
# (`options[:observables]` includes `@observables`, `.args[3]` removes this part)
Expand All @@ -763,6 +786,10 @@ function read_observed_options(options, species_n_vars_declared, ivs_sorted)
for (idx, obs_eq) in enumerate(observed_eqs.args)
# Extract the observable, checks errors, and continues the loop if the observable has been declared.
obs_name, ivs, defaults, metadata = find_varinfo_in_declaration(obs_eq.args[2])
if (requiredec && !in(obs_name, species_n_vars_declared))
throw(UndeclaredSymbolicError(
"An undeclared variable ($obs_name) was declared as an observable in the following observable equation: \"$obs_eq\". Since the flag @require_declaration is set, all variables must be declared with the @species, @parameters, or @variables macros."))
end
isempty(ivs) ||
error("An observable ($obs_name) was given independent variable(s). These should not be given, as they are inferred automatically.")
isnothing(defaults) ||
Expand Down
70 changes: 70 additions & 0 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1022,3 +1022,73 @@ let
@parameters v n
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
end

### test that @no_infer properly throws errors when undeclared variables are written

import Catalyst: UndeclaredSymbolicError
let
# Test error when species are inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@parameters k
k, A --> B
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@species A(t) B(t)
@parameters k
k, A --> B
end

# Test error when a parameter in rate is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@species A(t) B(t)
@parameters k
k*n, A --> B
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@parameters n k
@species A(t) B(t)
k*n, A --> B
end

# Test error when a parameter in stoichiometry is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@parameters k
@species A(t) B(t)
k, n*A --> B
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@parameters k n
@species A(t) B(t)
k, n*A --> B
end

# Test error when a variable in an equation is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@equations D(V) ~ V^2
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@variables V(t)
@equations D(V) ~ V^2
end

# Test error when a variable in an observable is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@variables X1(t)
@observables X2 ~ X1
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@variables X1(t) X2(t)
@observables X2 ~ X1
end
end

Loading