Skip to content

Commit

Permalink
rebase update
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Dec 4, 2023
1 parent d3ab89a commit 40274e6
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 31 deletions.
2 changes: 1 addition & 1 deletion ext/CatalystStructuralIdentifiabilityExtension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module CatalystStructuralIdentifiabilityExtension

# Fetch packages.
using Catalyst
import StructuralIdentifiability
import StructuralIdentifiability as SI

# Creates and exports hc_steady_states function.
include("CatalystStructuralIdentifiabilityExtension/structural_identifiability_extension.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,83 +18,127 @@ rs = @reaction_network begin
(p,d), 0 <--> X
end
si_ode(rs; measured_quantities = [:X], known_p = [:p])
Notes:
This function is part of the StructuralIdentifiability.jl extension. StructuralIdentifiability.jl must be imported to access it.
```
"""
function Catalyst.make_si_ode(rs::ReactionSystem; measured_quantities = [], known_p = [], ignore_no_measured_warn=false, remove_conserved = true)
function Catalyst.make_si_ode(rs::ReactionSystem; measured_quantities = [], known_p = [],
ignore_no_measured_warn=false, remove_conserved = true)
# Creates a MTK ODESystem, and a list of measured quantities (there are equations).
# Gives these to SI to create an SI ode model of its preferred form.
osys, conseqs, _ = make_osys(rs; remove_conserved)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p, conseqs; ignore_no_measured_warn)
return StructuralIdentifiability.preprocess_ode(osys, measured_quantities)[1]
return SI.preprocess_ode(osys, measured_quantities)[1]
end

### Structural Identifiability Wrappers ###

# Local identifiability.
function StructuralIdentifiability.assess_local_identifiability(rs::ReactionSystem, args...; measured_quantities = Num[], known_p = Num[], funcs_to_check = Vector(), remove_conserved = true, ignore_no_measured_warn=false, kwargs...)
# Creates dispatch for SI's local identifiability analysis function.
function SI.assess_local_identifiability(rs::ReactionSystem, args...; measured_quantities = Num[],
known_p = Num[], funcs_to_check = Vector(), remove_conserved = true,
ignore_no_measured_warn=false, kwargs...)
# Creates a ODESystem, list of measured quantities, and functions to check, of SI's preferred form.
osys, conseqs, vars = make_osys(rs; remove_conserved)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p, conseqs; ignore_no_measured_warn)
funcs_to_check = make_ftc(funcs_to_check, conseqs, vars)
out = StructuralIdentifiability.assess_local_identifiability(osys, args...; measured_quantities, funcs_to_check, kwargs...)

# Computes identifiability and converts it to a easy to read form.
out = SI.assess_local_identifiability(osys, args...; measured_quantities, funcs_to_check, kwargs...)
return make_output(out, funcs_to_check, reverse.(conseqs))
end

# Global identifiability.
function StructuralIdentifiability.assess_identifiability(rs::ReactionSystem, args...; measured_quantities = Num[], known_p = Num[], funcs_to_check = Vector(), remove_conserved = true, ignore_no_measured_warn=false, kwargs...)
# Creates dispatch for SI's global identifiability analysis function.
function SI.assess_identifiability(rs::ReactionSystem, args...; measured_quantities = Num[], known_p = Num[],
funcs_to_check = Vector(), remove_conserved = true, ignore_no_measured_warn=false,
kwargs...)
# Creates a ODESystem, list of measured quantities, and functions to check, of SI's preferred form.
osys, conseqs, vars = make_osys(rs; remove_conserved)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p, conseqs; ignore_no_measured_warn)
funcs_to_check = make_ftc(funcs_to_check, conseqs, vars)
out = StructuralIdentifiability.assess_identifiability(osys, args...; measured_quantities, funcs_to_check, kwargs...)

# Computes identifiability and converts it to a easy to read form.
out = SI.assess_identifiability(osys, args...; measured_quantities, funcs_to_check, kwargs...)
return make_output(out, funcs_to_check, reverse.(conseqs))
end

# Identifiable functions.
function StructuralIdentifiability.find_identifiable_functions(rs::ReactionSystem, args...; measured_quantities = Num[], known_p = Num[], remove_conserved = true, ignore_no_measured_warn=false, kwargs...)
osys, conseqs, vars = make_osys(rs; remove_conserved)
# Creates dispatch for SI's function to find all identifiable functions.
function SI.find_identifiable_functions(rs::ReactionSystem, args...; measured_quantities = Num[],
known_p = Num[], remove_conserved = true, ignore_no_measured_warn=false,
kwargs...)
# Creates a ODESystem, and list of measured quantities, of SI's preferred form.
osys, conseqs = make_osys(rs; remove_conserved)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p, conseqs; ignore_no_measured_warn)
out = StructuralIdentifiability.find_identifiable_functions(osys, args...; measured_quantities, kwargs...)

# Computes identifiable functions and converts it to a easy to read form.
out = SI.find_identifiable_functions(osys, args...; measured_quantities, kwargs...)
return vector_subs(out, reverse.(conseqs))
end

### Helper Functions ###

# From a reaction system, creates the corresponding ODESystem for SI application (and also compute the, later needed, conservation law equations and list of system symbols).
# From a reaction system, creates the corresponding MTK-style ODESystem for SI application
# Also compute the, later needed, conservation law equations and list of system symbols (states and parameters).
function make_osys(rs::ReactionSystem; remove_conserved=true)
rs = Catalyst.expand_registered_functions(rs)
# Creates the ODESystem corresponding to the ReactionSystem (expanding functions and flattening it).
# Creates a list of the systems all symbols (states and parameters).
rs = Catalyst.expand_registered_functions(flatten(rs))
osys = convert(ODESystem, rs; remove_conserved)
vars = [states(rs); parameters(rs)]

# Fixes conservation law equations. These cannot be computed for hierarchical systems (and hence this is skipped). If none is found, still have to put on the right form.
if !isempty(Catalyst.get_systems(rs)) || !remove_conserved
# Computes equations for system conservation laws.
# These cannot be computed for hierarchical systems (and hence this is skipped).
# If there are no conserved equations, the `conseqs` variable must still have the `Vector{Pair{Any, Any}}` type.
if !remove_conserved
conseqs = Vector{Pair{Any, Any}}[]
else
conseqs = [ceq.lhs => ceq.rhs for ceq in conservedequations(rs)]
isempty(conseqs) && (conseqs = Vector{Pair{Any, Any}}[])
end

return osys, conseqs, vars
end

# For input measured quantities, if this is not a vector of equations, convert it to a proper form.
function make_measured_quantities(rs::ReactionSystem, measured_quantities::Vector{T}, known_p::Vector{S}, conseqs; ignore_no_measured_warn=false) where {T,S}
ignore_no_measured_warn || isempty(measured_quantities) && @warn "No measured quantity provided to the `measured_quantities` argument, any further identifiability analysis will likely fail. You can disable this warning by setting `ignore_no_measured_warn=true`."
all_quantities = [measured_quantities; known_p]
all_quantities = [(quant isa Symbol) ? Catalyst._symbol_to_var(rs, quant) : quant for quant in all_quantities]
all_quantities = vector_subs(all_quantities, conseqs)
@variables t (___internal_observables(t))[1:length(all_quantities)]
return Equation[(all_quantities[i] isa Equation) ? all_quantities[i] : (___internal_observables[i] ~ all_quantities[i]) for i in 1:length(all_quantities)]
# Creates a list of measured quantities of a form that SI can read.
# Each measured quantity must have a form like:
# `obs_var ~ X` # (Here, `obs_var` is a variable, and X is whatever we can measure).
function make_measured_quantities(rs::ReactionSystem, measured_quantities::Vector{T}, known_p::Vector{S},
conseqs; ignore_no_measured_warn=false) where {T,S}
# Warning if the user didn't give any measured quantities.
if ignore_no_measured_warn || isempty(measured_quantities)
@warn "No measured quantity provided to the `measured_quantities` argument, any further identifiability analysis will likely fail. You can disable this warning by setting `ignore_no_measured_warn=true`."
end

# Appends the known parameters to the measured_quantities vector. Converts any Symbols to symbolics.
measured_quantities = [measured_quantities; known_p]
measured_quantities = [(q isa Symbol) ? Catalyst._symbol_to_var(rs, q) : q for q in measured_quantities]
measured_quantities = vector_subs(measured_quantities, conseqs)

# Creates one internal observation variable for each measured quantity (`___internal_observables`).
# Creates a vector of equations, setting each measured quantity equal to one observation variable.
@variables t (___internal_observables(t))[1:length(measured_quantities)]
return Equation[(q isa Equation) ? q : (___internal_observables[i] ~ q) for (i,q) in enumerate(measured_quantities)]
end

# Creates the functions that we wish to check for identifiability (if none give, by default, a list of parameters and species). Also replaces conservation law equations in.
# Creates the functions that we wish to check for identifiability.
# If no `funcs_to_check` are given, defaults to checking identifiability for all states and parameters.
# Also, for conserved equations, replaces these in (creating a system without conserved quantities).
# E.g. for `X1 <--> X2`, replaces `X2` with `Γ[1] - X2`.
# Removing conserved quantities makes SI's algorithms much more performant.
function make_ftc(funcs_to_check, conseqs, vars)
isempty(funcs_to_check) && (funcs_to_check = vars)
return vector_subs(funcs_to_check, conseqs)
end

# Replaces conservation law equations back in the output, and also sorts it according to their input order (defaults to [states; parameters] order).
# Processes the outputs to a better form.
# Replaces conservation law equations back in the output (so that e.g. Γ are not displayed).
# Sorts the output according to their input order (defaults to the `[states; parameters]` order).
function make_output(out, funcs_to_check, conseqs)
funcs_to_check = vector_subs(funcs_to_check, conseqs)
out = Dict(zip(vector_subs(keys(out), conseqs), values(out)))
out = sort(out; by = x -> findfirst(isequal(x, ftc) for ftc in funcs_to_check))
return out
end

# For a vector of expressions and a conservation law, replaces the law in.
# For a vector of expressions and a conservation law, substitutes the law into every equation.
vector_subs(eqs, subs) = [substitute(eq, subs) for eq in eqs]
56 changes: 54 additions & 2 deletions test/extensions/structural_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ let
end

# Tests on a made-up reaction network with mix of identifiable and non-identifiable components.
# Tests for system with conserved quantity.
# Tests for symbolics known_p
# Tests using an equation for measured quantity.
let
Expand Down Expand Up @@ -184,6 +185,58 @@ let
make_si_ode(gw_osc_complt; measured_quantities=[gw_osc_complt.M*gw_osc_complt.E])
end

# Tests for hierarchical model with conservation laws at both top and internal levels.
let
# Identifiability analysis for Catalyst model.
rs1 = @reaction_network rn1 begin
(k1, k2), X1 <--> X2
end
rs2 = @reaction_network rn2 begin
(k3, k4), X3 <--> X4
end
@named rs_catalyst = flatten(compose(rs1, [rs2]))
@unpack X1, X2, k1, k2 = rn1
gi_1 = assess_identifiability(rs_catalyst; measured_quantities=[X1, X2, rs2.X3], known_p=[k1])
li_1 = assess_local_identifiability(rs_catalyst; measured_quantities=[X1, X2, rs2.X3], known_p=[k1])
ifs_1 = find_identifiable_functions(rs_catalyst; measured_quantities=[X1, X2, rs2.X3], known_p=[k1])

# Identifiability analysis for Catalyst converted to StructuralIdentifiability.jl model.
rs_ode = make_si_ode(rs_catalyst; measured_quantities=[X1, X2, rs2.X3], known_p=[k1])
gi_2 = assess_identifiability(rs_ode)
li_2 = assess_local_identifiability(rs_ode)
ifs_2 = find_identifiable_functions(rs_ode)

# Identifiability analysis for StructuralIdentifiability.jl model (declare this overwrites e.g. X2 variable etc.).
rs_si = @ODEmodel(
X1'(t) = -k1*X1(t) + k2*X2(t),
X2'(t) = k1*X1(t) - k2*X2(t),
rn2₊X3'(t) = -rn2₊k3*rn2₊X3(t) + rn2₊k4*rn2₊X4(t),
rn2₊X4'(t) = rn2₊k3*rn2₊X3(t) - rn2₊k4*rn2₊X4(t),
y1(t) = X1,
y2(t) = X2,
y3(t) = rn2₊X3,
y4(t) = k1
)
gi_3 = assess_identifiability(rs_si)
li_3 = assess_local_identifiability(rs_si)
ifs_3 = find_identifiable_functions(rs_si)

# Check outputs.
@test sym_dict(gi_1) == sym_dict(gi_3)
@test sym_dict(li_1) == sym_dict(li_3)
@test length(ifs_1)-2 == length(ifs_2)-2 == length(ifs_3) # In the first case, the conservation law parameter is also identifiable.

# Checks output for the SI converted version of the catalyst model.
# For nested systems with conservation laws, conserved quantities like Γ[1], cannot be replaced back.
# Hence, here you display identifiability for `Γ[1]` instead of X2.
gi_1_no_cq = filter(x -> !occursin("X2",String(x[1])) && !occursin("X4",String(x[1])), sym_dict(gi_1))
gi_2_no_cq = filter(x -> !occursin("Γ",String(x[1])), sym_dict(gi_2))
li_1_no_cq = filter(x -> !occursin("X2",String(x[1])) && !occursin("X4",String(x[1])), sym_dict(li_1))
li_2_no_cq = filter(x -> !occursin("Γ",String(x[1])), sym_dict(li_2))
@test gi_1_no_cq == gi_2_no_cq
@test li_1_no_cq == li_2_no_cq
end

# Tests directly on reaction systems with known identifiability structures.
# Test provided by Alexander Demin.
let
Expand Down Expand Up @@ -236,6 +289,5 @@ let
:x2 => :globally,
:x3 => :globally,
)
# Will probably be fixed in the 0.5 release of SI.jl
@test_broken find_identifiable_functions(rs, measured_quantities = [:x3])
@test length(find_identifiable_functions(rs, measured_quantities = [:x3])) == 1
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,6 @@ using SafeTestsets
### Tests extensions. ###
@time @safetestset "BifurcationKit Extension" begin include("extensions/bifurcation_kit.jl") end
@time @safetestset "HomotopyContinuation Extension" begin include("extensions/homotopy_continuation.jl") end
@time @safetestset "Structural Identifiability Extension" begin include("extensions/structural_identifiability.jl") end
@time @safetestset "Structural Identifiability Extension" begin include("extensions/structural_identifiability.jl") end

end # @time

0 comments on commit 40274e6

Please sign in to comment.