Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed May 3, 2023
1 parent 630b72b commit 862ddfe
Showing 1 changed file with 48 additions and 35 deletions.
83 changes: 48 additions & 35 deletions src/spatial_reaction_network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ end
# Creates an ODEProblem from a LatticeReactionSystem.
function DiffEqBase.ODEProblem(lrs::LatticeReactionSystem, u0, tspan,
p = DiffEqBase.NullParameters(), args...;
sparse=true, kwargs...)
jac=true, sparse=true, kwargs...)
@unpack rs, spatial_reactions, lattice = lrs

spatial_params = unique(getfield.(spatial_reactions, :rate))
Expand All @@ -109,7 +109,7 @@ function DiffEqBase.ODEProblem(lrs::LatticeReactionSystem, u0, tspan,
pV = matrix_form(pV_in, nV, pV_idxes)
pE = matrix_form(pE_in, nE, pE_idxes)

ofun = build_odefunction(lrs, sparse, spatial_params)
ofun = build_odefunction(lrs, spatial_params, jac, sparse)
return ODEProblem(ofun, u0, tspan, (pV, pE), args...; kwargs...)
end

Expand All @@ -131,34 +131,34 @@ function matrix_form(input::Vector, n, index_dict)
end

# Builds an ODEFunction.
function build_odefunction(lrs::LatticeReactionSystem, sparse::Bool, spatial_params::Vector{Symbol})
function build_odefunction(lrs::LatticeReactionSystem, spatial_params::Vector{Symbol}, use_jac::Bool, sparse::Bool)
ofunc = ODEFunction(convert(ODESystem, lrs.rs); jac=true)
nS,nV = length.([states(lrs.rs), vertices(lrs.lattice)])
spatial_reactions = [SpatialReactionIndexed(sr, Symbol.(getfield.(states(lrs.rs), :f)), spatial_params) for sr in lrs.spatial_reactions]

f = build_f(ofunc, nS, nV, spatial_reactions, lrs.lattice.fadjlist)
jac = build_jac(ofunc,nS,nV,spatial_reactions, lrs.lattice.fadjlist)
jac_prototype = build_jac_prototype(nS,nV,spatial_reactions, lrs)
jac_prototype = build_jac_prototype(nS,nV,spatial_reactions, lrs.lattice.fadjlist)

return ODEFunction(f; jac=jac, jac_prototype=(sparse ? SparseArrays.sparse(jac_prototype) : jac_prototype))
return ODEFunction(f; jac=(use_jac ? jac : nothing), jac_prototype=(use_jac ? (sparse ? SparseArrays.sparse(jac_prototype) : jac_prototype) : nothing))
end

# Creates a function for simulating the spatial ODE with spatial reactions.
function build_f(ofunc::SciMLBase.AbstractODEFunction{true}, nS::Int64, nV::Int64, spatial_reactions::Vector{SpatialReactionIndexed}, adjlist::Vector{Vector{Int64}})
return function(du, u, p, t)
# Updates for non-spatial reactions.
for comp_i in 1:nV
for comp_i::Int64 in 1:nV
ofunc((@view du[get_indexes(comp_i,nS)]), (@view u[get_indexes(comp_i,nS)]), (@view p[1][:,comp_i]), t)
end

# Updates for spatial reactions.
for comp_i in 1:nV
for comp_j in adjlist[comp_i], sr in spatial_reactions
rate = get_rate(sr, p[2], (@view u[get_indexes(comp_i,nS)]), (@view u[get_indexes(comp_j,nS)]))
for stoich in sr.netstoich[1]
for comp_i::Int64 in 1:nV
for comp_j::Int64 in adjlist[comp_i], sr::SpatialReactionIndexed in spatial_reactions
rate::Float64 = get_rate(sr, p[2], (@view u[get_indexes(comp_i,nS)]), (@view u[get_indexes(comp_j,nS)]))
for stoich::Pair{Int64,Int64} in sr.netstoich[1]
du[get_index(comp_i,stoich[1],nS)] += rate * stoich[2]
end
for stoich in sr.netstoich[2]
for stoich::Pair{Int64,Int64} in sr.netstoich[2]
du[get_index(comp_j,stoich[1],nS)] += rate * stoich[2]
end
end
Expand All @@ -168,11 +168,11 @@ end

# Get the rate of a specific reaction.
function get_rate(sr::SpatialReactionIndexed, pE::Matrix{Float64}, u_src, u_dst)
product = pE[sr.rate]
!isempty(sr.substrates[1]) && for (sub,stoich) in zip(sr.substrates[1], sr.substoich[1])
product::Float64 = pE[sr.rate]
!isempty(sr.substrates[1]) && for (sub::Int64,stoich::Int64) in zip(sr.substrates[1], sr.substoich[1])
product *= u_src[sub]^stoich / factorial(stoich)
end
!isempty(sr.substrates[2]) && for (sub,stoich) in zip(sr.substrates[2], sr.substoich[2])
!isempty(sr.substrates[2]) && for (sub::Int64,stoich::Int64) in zip(sr.substrates[2], sr.substoich[2])
product *= u_dst[sub]^stoich / factorial(stoich)
end
return product
Expand All @@ -184,28 +184,28 @@ function build_jac(ofunc::SciMLBase.AbstractODEFunction{true}, nS::Int64, nV::In
J .= base_zero

# Updates for non-spatial reactions.
for comp_i in 1:nV
for comp_i::Int64 in 1:nV
ofunc.jac((@view J[get_indexes(comp_i,nS),get_indexes(comp_i,nS)]), (@view u[get_indexes(comp_i,nS)]), (@view p[1][:,comp_i]), t)
end

# Updates for spatial reactions.
for comp_i in 1:nV
for comp_j in adjlist[comp_i], sr in spatial_reactions
for sub in sr.substrates[1]
rate = get_rate_differential(sr, p[2], sub, (@view u[get_indexes(comp_i,nS)]), (@view u[get_indexes(comp_j,nS)]))
for stoich in sr.netstoich[1]
for comp_i::Int64 in 1:nV
for comp_j::Int64 in adjlist[comp_i], sr::SpatialReactionIndexed in spatial_reactions
for sub::Int64 in sr.substrates[1]
rate::Float64 = get_rate_differential(sr, p[2], sub, (@view u[get_indexes(comp_i,nS)]), (@view u[get_indexes(comp_j,nS)]))
for stoich::Pair{Int64,Int64} in sr.netstoich[1]
J[get_index(comp_i,stoich[1],nS),get_index(comp_i,sub,nS)] += rate * stoich[2]
end
for stoich in sr.netstoich[2]
for stoich::Pair{Int64,Int64} in sr.netstoich[2]
J[get_index(comp_j,stoich[1],nS),get_index(comp_i,sub,nS)] += rate * stoich[2]
end
end
for sub in sr.substrates[2]
rate = get_rate_differential(sr, p[2], sub, (@view u[get_indexes(comp_j,nS)]), (@view u[get_indexes(comp_i,nS)]))
for stoich in sr.netstoich[1]
for sub::Int64 in sr.substrates[2]
rate::Float64 = get_rate_differential(sr, p[2], sub, (@view u[get_indexes(comp_j,nS)]), (@view u[get_indexes(comp_i,nS)]))
for stoich::Pair{Int64,Int64} in sr.netstoich[1]
J[get_index(comp_i,stoich[1],nS),get_index(comp_j,sub,nS)] += rate * stoich[2]
end
for stoich in sr.netstoich[2]
for stoich::Pair{Int64,Int64} in sr.netstoich[2]
J[get_index(comp_j,stoich[1],nS),get_index(comp_j,sub,nS)] += rate * stoich[2]
end
end
Expand All @@ -215,25 +215,27 @@ function build_jac(ofunc::SciMLBase.AbstractODEFunction{true}, nS::Int64, nV::In
end

function get_rate_differential(sr::SpatialReactionIndexed, pE::Matrix{Float64}, diff_species::Int64, u_src, u_dst)
product = pE[sr.rate]
!isempty(sr.substrates[1]) && for (sub,stoich) in zip(sr.substrates[1], sr.substoich[1])
product::Float64 = pE[sr.rate]
!isempty(sr.substrates[1]) && for (sub::Int64,stoich::Int64) in zip(sr.substrates[1], sr.substoich[1])
if diff_species==sub
product *= stoich*u_src[sub]^(stoich-1) / factorial(stoich)
else
product *= u_src[sub]^stoich / factorial(stoich)
end
end
!isempty(sr.substrates[2]) && for (sub,stoich) in zip(sr.substrates[2], sr.substoich[2])
!isempty(sr.substrates[2]) && for (sub::Int64,stoich::Int64) in zip(sr.substrates[2], sr.substoich[2])
product *= u_dst[sub]^stoich / factorial(stoich)
end
return product
end

function build_jac_prototype(nS::Int64, nV::Int64, spatial_reactions::Vector{SpatialReactionIndexed}, adjlist::Vector{Vector{Int64}})
jac_prototype = zeros(nS*nV,nS*nV)
jac_prototype::Matrix{Float64} = zeros(nS*nV,nS*nV)

# Sets non-spatial reactions.
foreach(comp_i -> jac_prototype[get_indexes(comp_i,nS),get_indexes(comp_i,nS)] = ones(nS,nS), 1:nV)
for comp_i::Int64 in 1:nV
jac_prototype[get_indexes(comp_i,nS),get_indexes(comp_i,nS)] = ones(nS,nS)
end
# Tries to utilise sparsity within each comaprtment, currently not working, not sure if useful.
# for comp_i in 1:nV, reaction in reactions(lrs.rs)
# for substrate in reaction.substrates, ns in reaction.netstoich
Expand All @@ -244,11 +246,22 @@ function build_jac_prototype(nS::Int64, nV::Int64, spatial_reactions::Vector{Spa
# end

# Updates for spatial reactions.
for comp_i in 1:nV
for comp_j in adjlist[comp_i], sr in spatial_reactions
for (idx1, comp1) in [(1,comp_i),(2,comp_j)], sub in sr.substrates[idx1]
for (idx2, comp2) in [(1,comp_i),(2,comp_j)], stoich in sr.netstoich[idx2]
jac_prototype[get_index(comp2,stoich[1],nS),get_index(comp1,sub,nS)] = 1
for comp_i::Int64 in 1:nV
for comp_j::Int64 in adjlist[comp_i], sr::SpatialReactionIndexed in spatial_reactions
for sub::Int64 in sr.substrates[1]
for stoich::Pair{Int64,Int64} in sr.netstoich[1]
jac_prototype[get_index(comp_i,stoich[1],nS),get_index(comp_i,sub,nS)] = 1.0
end
for stoich::Pair{Int64,Int64} in sr.netstoich[2]
jac_prototype[get_index(comp_j,stoich[1],nS),get_index(comp_i,sub,nS)] = 1.0
end
end
for sub::Int64 in sr.substrates[2]
for stoich::Pair{Int64,Int64} in sr.netstoich[1]
jac_prototype[get_index(comp_i,stoich[1],nS),get_index(comp_j,sub,nS)] = 1.0
end
for stoich::Pair{Int64,Int64} in sr.netstoich[2]
jac_prototype[get_index(comp_j,stoich[1],nS),get_index(comp_j,sub,nS)] = 1.0
end
end
end
Expand Down

0 comments on commit 862ddfe

Please sign in to comment.