Skip to content

Commit

Permalink
use generated rate law function
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Nov 21, 2023
1 parent a573f0f commit f623f24
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Expand Down
4 changes: 4 additions & 0 deletions src/Catalyst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ const MT = ModelingToolkit
using Unitful
@reexport using ModelingToolkit
using Symbolics

using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

import Symbolics: BasicSymbolic
import SymbolicUtils
using ModelingToolkit: Symbolic, value, istree, get_states, get_ps, get_iv, get_systems,
Expand Down
13 changes: 7 additions & 6 deletions src/spatial_reaction_systems/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,17 @@ end
# Else a vector with each value corresponding to the rate at one specific edge.
function compute_transport_rates(rate_law::Num,
p_val_dict::Dict{SymbolicUtils.BasicSymbolic{Real}, Vector{Float64}}, num_edges::Int64)
relevant_ps = Symbolics.get_variables(rate_law)

# If all these parameters are spatially uniform. `rates` becomes a vector with 1 value.
if all(length(p_val_dict[P]) == 1 for P in relevant_ps)
rates = [substitute(rate_law, Dict(p => p_val_dict[p][1] for p in relevant_ps))]
# Finds parameters involved in rate and create a function evaluating teh rate law.
relevant_ps = Symbolics.get_variables(rate_law)
rate_law_func = drop_expr(@RuntimeGeneratedFunction(build_function(rate_law, relevant_ps...)))

# If all these parameters are spatially uniform. `rates` becomes a vector with 1 value.
if all(length(p_val_dict[P]) == 1 for P in relevant_ps)
rates = [rate_law_func([p_val_dict[p][1] for p in relevant_ps]...)]
# If at least on parameter the rate depends on have a value varying across all edges,
# we have to compute one rate value for each edge.
else
rates = [substitute(rate_law, Dict(p => get_component_value(p_val_dict[p], idxE) for p in relevant_ps))
rates = [rate_law_func([get_component_value(p_val_dict[p], idxE) for p in relevant_ps]...)
for idxE in 1:num_edges]
end
return Symbolics.value.(rates)
Expand Down

0 comments on commit f623f24

Please sign in to comment.