Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drop scalarizing #1052

Merged
merged 14 commits into from
Sep 21, 2024
29 changes: 19 additions & 10 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,12 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
sexprs = get_sexpr(species_extracted, options; iv_symbols = ivs)
vexprs = get_sexpr(vars_extracted, options, :variables; iv_symbols = ivs)
pexprs = get_pexpr(parameters_extracted, options)
ps, pssym = scalarize_macro(!isempty(parameters), pexprs, "ps")
vars, varssym = scalarize_macro(!isempty(variables), vexprs, "vars")
sps, spssym = scalarize_macro(!isempty(species), sexprs, "specs")
comps, compssym = scalarize_macro(!isempty(compound_species), compound_expr, "comps")
ps, pssym = assign_expr_to_var(!isempty(parameters), pexprs, "ps")
vars, varssym = assign_expr_to_var(!isempty(variables), vexprs, "vars";
scalarize = true)
sps, spssym = assign_expr_to_var(!isempty(species), sexprs, "specs"; scalarize = true)
comps, compssym = assign_expr_to_var(!isempty(compound_species), compound_expr,
"comps"; scalarize = true)
rxexprs = :(CatalystEqType[])
for reaction in reactions
push!(rxexprs.args, get_rxexprs(reaction))
Expand Down Expand Up @@ -591,14 +593,21 @@ function get_rxexprs(rxstruct)
end

# takes a ModelingToolkit declaration macro like @parameters and returns an expression
# that calls the macro and then scalarizes all the symbols created into a vector of Nums
function scalarize_macro(nonempty, ex, name)
# that calls the macro and saves it in a variable given by namesym based on name.
# scalarizes if desired
function assign_expr_to_var(nonempty, ex, name; scalarize = false)
namesym = gensym(name)
if nonempty
symvec = gensym()
ex = quote
$symvec = $ex
$namesym = reduce(vcat, Symbolics.scalarize($symvec))
if scalarize
symvec = gensym()
ex = quote
$symvec = $ex
$namesym = reduce(vcat, Symbolics.scalarize($symvec))
end
else
ex = quote
$namesym = $ex
end
end
else
ex = :($namesym = Num[])
Expand Down
4 changes: 2 additions & 2 deletions src/network_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,8 @@ function cache_conservationlaw_eqs!(rn::ReactionSystem, N::AbstractMatrix, col_o
indepspecs = sts[indepidxs]
depidxs = col_order[(r + 1):end]
depspecs = sts[depidxs]
constants = MT.unwrap.(MT.scalarize(only(
@parameters $(CONSERVED_CONSTANT_SYMBOL)[1:nullity] [conserved = true])))
constants = MT.unwrap(only(
@parameters $(CONSERVED_CONSTANT_SYMBOL)[1:nullity] [conserved = true]))

conservedeqs = Equation[]
constantdefs = Equation[]
Expand Down
34 changes: 25 additions & 9 deletions src/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,11 @@ function ReactionSystem(eqs, iv, unknowns, ps;
sivs′ = if spatial_ivs === nothing
Vector{typeof(iv′)}()
else
value.(MT.scalarize(spatial_ivs))
value.(spatial_ivs)
end
unknowns′ = sort!(value.(MT.scalarize(unknowns)), by = !isspecies)
unknowns′ = sort!(value.(unknowns), by = !isspecies)
spcs = filter(isspecies, unknowns′)
ps′ = value.(MT.scalarize(ps))
ps′ = value.(ps)

# Checks that no (by Catalyst) forbidden symbols are used.
allsyms = Iterators.flatten((ps′, unknowns′))
Expand Down Expand Up @@ -467,7 +467,7 @@ end
# Two-argument constructor (reactions/equations and time variable).
# Calls the `make_ReactionSystem_internal`, which in turn calls the four-argument constructor.
function ReactionSystem(rxs::Vector, iv = Catalyst.DEFAULT_IV; kwargs...)
make_ReactionSystem_internal(rxs, iv, Vector{Num}(), Vector{Num}(); kwargs...)
make_ReactionSystem_internal(rxs, iv, [], []; kwargs...)
end

# One-argument constructor. Creates an emtoy `ReactionSystem` from a time independent variable only.
Expand All @@ -485,16 +485,17 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
spatial_ivs = nothing, continuous_events = [], discrete_events = [],
observed = [], kwargs...)

# Filters away any potential observables from `states` and `spcs`.
obs_vars = [obs_eq.lhs for obs_eq in observed]
us_in = filter(u -> !any(isequal(u, obs_var) for obs_var in obs_vars), us_in)
Copy link
Member Author

@isaacsas isaacsas Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TorkelE I modified this as I think it makes more sense to disallow including observable variables in the unknowns, and to instead give users an error (rather than filtering them out). This also seems less likely to lead to user code bugs where a user thinks something is an unknown but it is implicitly an observable.

# Error if any observables have been declared a species or variable
obs_vars = Set(obs_eq.lhs for obs_eq in observed)
any(in(obs_vars), us_in) &&
error("Found an observable in the list of unknowns. This is not allowed.")

# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
# independent variables can be excluded when encountered quantities are added to `us` and `ps`).
t = value(iv)
ivs = Set([t])
if (spatial_ivs !== nothing)
for siv in (MT.scalarize(spatial_ivs))
for siv in (spatial_ivs)
push!(ivs, value(siv))
end
end
Expand Down Expand Up @@ -548,7 +549,22 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;

# Converts the found unknowns and parameters to vectors.
usv = collect(us)
psv = collect(ps)

new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)
else
push!(new_ps, p)
end
else
push!(new_ps, p)
end
end
psv = collect(new_ps)

# Passes the processed input into the next `ReactionSystem` call.
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events,
Expand Down
9 changes: 3 additions & 6 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ let
k, 0 --> X1 + X2
end
@test isequal(observed(rn1)[1].rhs, observed(rn2)[1].rhs)
@test isequal(observed(rn1)[1].lhs.metadata, observed(rn2)[1].lhs.metadata)
@test_broken isequal(observed(rn1)[1].lhs.metadata, observed(rn2)[1].lhs.metadata)
@test isequal(unknowns(rn1), unknowns(rn2))

# Case with metadata.
Expand All @@ -618,7 +618,7 @@ let
k, 0 --> X1 + X2
end
@test isequal(observed(rn3)[1].rhs, observed(rn4)[1].rhs)
@test isequal(observed(rn3)[1].lhs.metadata, observed(rn4)[1].lhs.metadata)
@test_broken isequal(observed(rn3)[1].lhs.metadata, observed(rn4)[1].lhs.metadata)
@test isequal(unknowns(rn3), unknowns(rn4))
end

Expand Down Expand Up @@ -822,10 +822,7 @@ let
@variables X(t)
@equations 2X ~ $c - X
end)

u0 = [rn.X => 0.0]
ps = []
oprob = ODEProblem(rn, u0, (0.0, 100.0), ps; structural_simplify=true)
oprob = ODEProblem(rn, [], (0.0, 100.0); structural_simplify=true)
sol = solve(oprob, Tsit5(); abstol=1e-9, reltol=1e-9)
@test sol[rn.X][end] ≈ 2.0
end
Expand Down
33 changes: 15 additions & 18 deletions test/reactionsystem_core/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,18 @@ rxs = [Reaction(k[1], nothing, [A]), # 0 -> A
Reaction(k[19] * t, [A], [B]), # A -> B with non constant rate.
Reaction(k[20] * t * A, [B, C], [D], [2, 1], [2]), # 2A +B -> 2C with non constant rate.
]
@named rs = ReactionSystem(rxs, t, [A, B, C, D], k)
@named rs = ReactionSystem(rxs, t, [A, B, C, D], [k])
rs = complete(rs)
odesys = complete(convert(ODESystem, rs))
sdesys = complete(convert(SDESystem, rs))

# Hard coded ODE rhs.
function oderhs(u, k, t)
function oderhs(u, kv, t)
A = u[1]
B = u[2]
C = u[3]
D = u[4]
k = kv[1]
du = zeros(eltype(u), 4)
du[1] = k[1] - k[3] * A + k[4] * C + 2 * k[5] * C - k[6] * A * B + k[7] * B^2 / 2 -
k[9] * A * B - k[10] * A^2 - k[11] * A^2 / 2 - k[12] * A * B^3 * C^4 / 144 -
Expand All @@ -68,11 +69,12 @@ function oderhs(u, k, t)
end

# SDE noise coefs.
function sdenoise(u, k, t)
function sdenoise(u, kv, t)
A = u[1]
B = u[2]
C = u[3]
D = u[4]
k = kv[1]
G = zeros(eltype(u), length(k), length(u))
z = zero(eltype(u))

Expand Down Expand Up @@ -109,11 +111,12 @@ end

# Defaults test.
let
def_p = [ki => float(i) for (i, ki) in enumerate(k)]
kvals = Float64.(1:length(k))
def_p = [k => kvals]
def_u0 = [A => 0.5, B => 1.0, C => 1.5, D => 2.0]
defs = merge(Dict(def_p), Dict(def_u0))

@named rs = ReactionSystem(rxs, t, [A, B, C, D], k; defaults = defs)
@named rs = ReactionSystem(rxs, t, [A, B, C, D], [k]; defaults = defs)
rs = complete(rs)
odesys = complete(convert(ODESystem, rs))
sdesys = complete(convert(SDESystem, rs))
Expand All @@ -126,15 +129,11 @@ let
defs

u0map = [A => 5.0]
pmap = [k[1] => 5.0]
kvals[1] = 5.0
pmap = [k => kvals]
prob = ODEProblem(rs, u0map, (0, 10.0), pmap)
@test prob.ps[k[1]] == 5.0
@test prob.u0[1] == 5.0
u0 = [10.0, 11.0, 12.0, 13.0]
ps = [float(x) for x in 100:119]
prob = ODEProblem(rs, u0, (0, 10.0), ps)
@test [prob.ps[k[i]] for i in 1:20] == ps
@test prob.u0 == u0
Comment on lines -133 to -137
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not appropriate inputs since they aren't mappings, hence I removed them.

end

### Check ODE, SDE, and Jump Functions ###
Expand Down Expand Up @@ -181,7 +180,7 @@ let
Reaction(k[19] * t, [D], [E]), # D -> E with non constant rate.
Reaction(k[20] * t * A, [D, E], [F], [2, 1], [2]), # 2D + E -> 2F with non constant rate.
]
@named rs = ReactionSystem(rxs, t, [A, B, C, D, E, F], k)
@named rs = ReactionSystem(rxs, t, [A, B, C, D, E, F], [k])
rs = complete(rs)
js = complete(convert(JumpSystem, rs))

Expand All @@ -193,7 +192,7 @@ let
@test all(map(i -> typeof(equations(js)[i]) <: JumpProcesses.VariableRateJump, vidxs))

p = rand(rng, length(k))
pmap = parameters(js) .=> p
pmap = [k => p]
u0 = rand(rng, 2:10, 6)
u0map = unknowns(js) .=> u0
ttt = rand(rng)
Expand Down Expand Up @@ -868,11 +867,9 @@ end
let
@species (A(t))[1:20]
using ModelingToolkit: value
@test isspecies(value(A))
@test isspecies(value(A[2]))
Av = value.(ModelingToolkit.scalarize(A))
@test isspecies(Av[2])
@test isequal(value(Av[2]), value(A[2]))
Av = value(A)
@test isspecies(Av)
@test all(i -> isspecies(Av[i]), 1:length(Av))
end

# Test mixed models are formulated correctly.
Expand Down
8 changes: 4 additions & 4 deletions test/test_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ using Random, Test

# Generates a random initial condition (in the form of a map). Each value is a Float64.
function rnd_u0(sys, rng; factor = 1.0, min = 0.0)
return [u => min + factor * rand(rng) for u in unknowns(sys)]
return [u => (min .+ factor .* rand(rng, size(u)...)) for u in unknowns(sys)]
end

# Generates a random initial condition (in the form of a map). Each value is a Int64.
function rnd_u0_Int64(sys, rng; n = 5, min = 0)
return [u => min + rand(rng, 1:n) for u in unknowns(sys)]
return [u => (min .+ rand(rng, 1:n, size(u)...)) for u in unknowns(sys)]
end

# Generates a random parameter set (in the form of a map). Each value is a Float64.
function rnd_ps(sys, rng; factor = 1.0, min = 0.0)
return [p => min + factor * rand(rng) for p in parameters(sys)]
return [p => ( min .+ factor .* rand(rng, size(p)...)) for p in parameters(sys)]
end

# Generates a random parameter set (in the form of a map). Each value is a Float64.
function rnd_ps_Int64(sys, rng; n = 5, min = 0)
return [p => min + rand(rng, 1:n) for p in parameters(sys)]
return [p => (min .+ rand(rng, 1:n, size(p)...)) for p in parameters(sys)]
end

# Used to convert a generated initial condition/parameter set to a vector that can be used for normal
Expand Down
Loading