From 6cc222b18c8b28fc3c779e4bacd71b71c0c9153d Mon Sep 17 00:00:00 2001 From: pogudingleb Date: Sun, 9 Jun 2024 23:53:16 +0200 Subject: [PATCH 1/2] adding known_ic to the MTK interface --- ext/ModelingToolkitSIExt.jl | 79 ++++++++++++++---- src/StructuralIdentifiability.jl | 6 +- src/identifiable_functions.jl | 4 +- src/known_ic.jl | 4 +- test/extensions/modelingtoolkit.jl | 128 +++++++++++++++++++++++++++++ 5 files changed, 202 insertions(+), 19 deletions(-) diff --git a/ext/ModelingToolkitSIExt.jl b/ext/ModelingToolkitSIExt.jl index 01f19da2..8ee7bcd3 100644 --- a/ext/ModelingToolkitSIExt.jl +++ b/ext/ModelingToolkitSIExt.jl @@ -340,22 +340,25 @@ end # ------------------------------------------------------------------------------ """ - assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], prob_threshold = 0.99, loglevel=Logging.Info) + assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info) Input: - `ode` - the ModelingToolkit.ODESystem object that defines the model - `measured_quantities` - the output functions of the model - `funcs_to_check` - functions of parameters for which to check the identifiability +- `known_ic` - functions, for which initial conditions are assumed to be known - `prob_threshold` - probability of correctness. - `loglevel` - the minimal level of log messages to display (`Logging.Info` by default) Assesses identifiability (both local and global) of a given ODE model (parameters detected automatically). The result is guaranteed to be correct with the probability at least `prob_threshold`. +If known initial conditions are provided, the identifiability results for the states will also hold at `t = 0` """ function StructuralIdentifiability.assess_identifiability( ode::ModelingToolkit.ODESystem; measured_quantities = Array{ModelingToolkit.Equation}[], funcs_to_check = [], + known_ic = [], prob_threshold = 0.99, loglevel = Logging.Info, ) @@ -365,6 +368,7 @@ function StructuralIdentifiability.assess_identifiability( ode, measured_quantities = measured_quantities, funcs_to_check = funcs_to_check, + known_ic = known_ic, prob_threshold = prob_threshold, ) end @@ -374,6 +378,7 @@ function _assess_identifiability( ode::ModelingToolkit.ODESystem; measured_quantities = Array{ModelingToolkit.Equation}[], funcs_to_check = [], + known_ic = [], prob_threshold = 0.99, ) if isempty(measured_quantities) @@ -387,13 +392,32 @@ function _assess_identifiability( end funcs_to_check_ = [eval_at_nemo(each, conversion) for each in funcs_to_check] - result = StructuralIdentifiability._assess_identifiability( - ode, - funcs_to_check = funcs_to_check_, - prob_threshold = prob_threshold, - ) + known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic] + + nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check) + result = nothing + if isempty(known_ic) + result = StructuralIdentifiability._assess_identifiability( + ode, + funcs_to_check = funcs_to_check_, + prob_threshold = prob_threshold, + ) + return OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_) + else + result = StructuralIdentifiability._assess_identifiability_kic( + ode, + known_ic_, + funcs_to_check = funcs_to_check_, + prob_threshold = prob_threshold, + ) + end nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check) out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_) + if length(known_ic) > 0 + @warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !" + t = SymbolicUtils.Sym{Real}(:t) + out_dict = OrderedDict(substitute(k, Dict(t => 0)) => v for (k, v) in out_dict) + end return out_dict end @@ -510,7 +534,7 @@ end # ------------------------------------------------------------------------------ """ - find_identifiable_functions(ode::ModelingToolkit.ODESystem; measured_quantities=[], options...) + find_identifiable_functions(ode::ModelingToolkit.ODESystem; measured_quantities=[], known_ic=[], options...) Finds all functions of parameters/states that are identifiable in the given ODE system. @@ -519,6 +543,9 @@ system. This functions takes the following optional arguments: - `measured_quantities` - the output functions of the model. +- `known_ic` - a list of functions whose initial conditions are assumed to be known, + then the returned identifiable functions will be functions of parameters and + initial conditions, not states (this is an experimental functionality). - `loglevel` - the verbosity of the logging (can be Logging.Error, Logging.Warn, Logging.Info, Logging.Debug) @@ -549,6 +576,7 @@ find_identifiable_functions(de, measured_quantities = [y1 ~ x0]) function StructuralIdentifiability.find_identifiable_functions( ode::ModelingToolkit.ODESystem; measured_quantities = Array{ModelingToolkit.Equation}[], + known_ic = [], prob_threshold::Float64 = 0.99, seed = 42, with_states = false, @@ -562,6 +590,7 @@ function StructuralIdentifiability.find_identifiable_functions( return _find_identifiable_functions( ode, measured_quantities = measured_quantities, + known_ic = known_ic, prob_threshold = prob_threshold, seed = seed, with_states = with_states, @@ -574,6 +603,7 @@ end function _find_identifiable_functions( ode::ModelingToolkit.ODESystem; measured_quantities = Array{ModelingToolkit.Equation}[], + known_ic = [], prob_threshold::Float64 = 0.99, seed = 42, with_states = false, @@ -585,17 +615,36 @@ function _find_identifiable_functions( measured_quantities = get_measured_quantities(ode) end ode, conversion = mtk_to_si(ode, measured_quantities) - result = StructuralIdentifiability._find_identifiable_functions( - ode, - simplify = simplify, - prob_threshold = prob_threshold, - seed = seed, - with_states = with_states, - rational_interpolator = rational_interpolator, - ) + known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic] + result = nothing + if isempty(known_ic) + result = StructuralIdentifiability._find_identifiable_functions( + ode, + simplify = simplify, + prob_threshold = prob_threshold, + seed = seed, + with_states = with_states, + rational_interpolator = rational_interpolator, + ) + else + result = StructuralIdentifiability._find_identifiable_functions_kic( + ode, + known_ic_, + simplify = simplify, + prob_threshold = prob_threshold, + seed = seed, + rational_interpolator = rational_interpolator, + ) + end result = [parent_ring_change(f, ode.poly_ring) for f in result] nemo2mtk = Dict(v => Num(k) for (k, v) in conversion) out_funcs = [eval_at_dict(func, nemo2mtk) for func in result] + if length(known_ic) > 0 + @warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !" + t = SymbolicUtils.Sym{Real}(:t) + out_funcs = [substitute(f, Dict(t => 0)) for f in out_funcs] + end + return out_funcs end diff --git a/src/StructuralIdentifiability.jl b/src/StructuralIdentifiability.jl index 3c6c0338..8cff59bc 100644 --- a/src/StructuralIdentifiability.jl +++ b/src/StructuralIdentifiability.jl @@ -118,12 +118,16 @@ function assess_identifiability( prob_threshold = prob_threshold, ) else - return _assess_identifiability_kic( + res = _assess_identifiability_kic( ode, known_ic, funcs_to_check = funcs_to_check, prob_threshold = prob_threshold, ) + funcs_to_check_ic = replace_with_ic(ode, funcs_to_check) + return OrderedDict( + f_ic => res[f] for (f, f_ic) in zip(funcs_to_check, funcs_to_check_ic) + ) end end end diff --git a/src/identifiable_functions.jl b/src/identifiable_functions.jl index fedcaf1f..583b1392 100644 --- a/src/identifiable_functions.jl +++ b/src/identifiable_functions.jl @@ -69,7 +69,7 @@ function find_identifiable_functions( rational_interpolator = rational_interpolator, ) else - return _find_identifiable_functions_kic( + id_funcs = _find_identifiable_functions_kic( ode, known_ic, prob_threshold = prob_threshold, @@ -77,6 +77,8 @@ function find_identifiable_functions( simplify = simplify, rational_interpolator = rational_interpolator, ) + # renaming variables from `x(t)` to `x(0)` + return replace_with_ic(ode, id_funcs) end end end diff --git a/src/known_ic.jl b/src/known_ic.jl index 0983e777..a61fcc79 100644 --- a/src/known_ic.jl +++ b/src/known_ic.jl @@ -59,7 +59,7 @@ function _find_identifiable_functions_kic( @info "The search for identifiable functions with known initial conditions concluded in $((time_ns() - runtime_start) / 1e9) seconds" - return replace_with_ic(ode, id_funcs) + return id_funcs end """ @@ -90,7 +90,7 @@ function _assess_identifiability_kic( end half_p = 0.5 + prob_threshold / 2 id_funcs = _find_identifiable_functions_kic(ode, known_ic, prob_threshold = half_p) - funcs_to_check = replace_with_ic(ode, funcs_to_check) + #funcs_to_check = replace_with_ic(ode, funcs_to_check) result = OrderedDict(f => :globally for f in funcs_to_check) half_p = 0.5 + half_p / 2 diff --git a/test/extensions/modelingtoolkit.jl b/test/extensions/modelingtoolkit.jl index 5d28de27..e9ea9944 100644 --- a/test/extensions/modelingtoolkit.jl +++ b/test/extensions/modelingtoolkit.jl @@ -717,4 +717,132 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt" @test issetequal(sym_dict(local_id_1), sym_dict(local_id_2)) @test length(ifs_1) == length(ifs_2) end + + @testset "Identifiability of MTK models with known generic initial conditions" begin + cases = [] + + @parameters a, b, c, d + @variables t, x1(t), x2(t) + D = Differential(t) + x1_0 = substitute(x1, Dict(t => 0)) + x2_0 = substitute(x2, Dict(t => 0)) + eqs = [D(x1) ~ a * x1 - b * x1 * x2, D(x2) ~ -c * x2 + d * x1 * x2] + ode_mtk = ODESystem(eqs, t, name = :lv) + push!( + cases, + Dict( + :ode => ode_mtk, + :measured => [x1], + :known => [x2], + :to_check => [], + :correct_funcs => [a, b, c, d, x1_0, x2_0], + :correct_ident => + OrderedDict(x => :globally for x in [x1_0, x2_0, a, b, c, d]), + ), + ) + + @parameters c + @variables x3(t) + x3_0 = substitute(x3, Dict(t => 0)) + eqs = [D(x1) ~ a + x2 + x3, D(x2) ~ b^2 + c, D(x3) ~ -c] + ode_mtk = ODESystem(eqs, t, name = :ex2) + + push!( + cases, + Dict( + :ode => ode_mtk, + :measured => [x1], + :known => [x2, x3], + :to_check => [], + :correct_funcs => [a, b^2, x1_0, x2_0, x3_0], + :correct_ident => OrderedDict( + x1_0 => :globally, + x2_0 => :globally, + x3_0 => :globally, + a => :globally, + b => :locally, + c => :nonidentifiable, + ), + ), + ) + + push!( + cases, + Dict( + :ode => ode_mtk, + :measured => [x1], + :known => [x2, x3], + :to_check => [b^2, x2 * c], + :correct_funcs => [a, b^2, x1_0, x2_0, x3_0], + :correct_ident => + OrderedDict(b^2 => :globally, x2_0 * c => :nonidentifiable), + ), + ) + + eqs = [D(x1) ~ a * x1, D(x2) ~ x1 + 1 / x1] + ode_mtk = ODESystem(eqs, t, name = :ex3) + push!( + cases, + Dict( + :ode => ode_mtk, + :measured => [x2], + :known => [x1], + :to_check => [], + :correct_funcs => [a, x1_0, x2_0], + :correct_ident => + OrderedDict(x1_0 => :globally, x2_0 => :globally, a => :globally), + ), + ) + + @parameters alpha, beta, gama, delta, sigma + @variables x4(t) + x4_0 = substitute(x4, Dict(t => 0)) + eqs = [ + D(x1) ~ -b * x1 + 1 / (c + x4), + D(x2) ~ alpha * x1 - beta * x2, + D(x3) ~ gama * x2 - delta * x3, + D(x4) ~ sigma * x4 * (gama * x2 - delta * x3) / x3, + ] + ode_mtk = ODESystem(eqs, t, name = :goodwin) + push!( + cases, + Dict( + :ode => ode_mtk, + :measured => [x1], + :known => [x2, x3], + :to_check => [alpha, alpha * gama], + :correct_funcs => [ + sigma, + c, + b, + x4_0, + x3_0, + x2_0, + x1_0, + beta * delta, + alpha * gama, + beta + delta, + -delta * x3_0 + gama * x2_0, + ], + :correct_ident => OrderedDict(alpha => :locally, alpha * gama => :globally), + ), + ) + + for case in cases + ode = case[:ode] + y = case[:measured] + known = case[:known] + result_funcs = + find_identifiable_functions(ode, known_ic = known, measured_quantities = y) + correct_funcs = @test Set(result_funcs) == Set(case[:correct_funcs]) + + result_ident = assess_identifiability( + ode, + known_ic = known, + measured_quantities = y, + funcs_to_check = case[:to_check], + ) + @test case[:correct_ident] == result_ident + end + end end From e60361cfff446485f6596db2210f248312fd3712 Mon Sep 17 00:00:00 2001 From: pogudingleb Date: Mon, 10 Jun 2024 00:26:27 +0200 Subject: [PATCH 2/2] fixing the empty output issue --- ext/ModelingToolkitSIExt.jl | 2 +- src/StructuralIdentifiability.jl | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ext/ModelingToolkitSIExt.jl b/ext/ModelingToolkitSIExt.jl index 8ee7bcd3..94655ec5 100644 --- a/ext/ModelingToolkitSIExt.jl +++ b/ext/ModelingToolkitSIExt.jl @@ -603,7 +603,7 @@ end function _find_identifiable_functions( ode::ModelingToolkit.ODESystem; measured_quantities = Array{ModelingToolkit.Equation}[], - known_ic = [], + known_ic = Array{Symbolics.Num}[], prob_threshold::Float64 = 0.99, seed = 42, with_states = false, diff --git a/src/StructuralIdentifiability.jl b/src/StructuralIdentifiability.jl index 8cff59bc..6ff7826a 100644 --- a/src/StructuralIdentifiability.jl +++ b/src/StructuralIdentifiability.jl @@ -124,10 +124,9 @@ function assess_identifiability( funcs_to_check = funcs_to_check, prob_threshold = prob_threshold, ) - funcs_to_check_ic = replace_with_ic(ode, funcs_to_check) - return OrderedDict( - f_ic => res[f] for (f, f_ic) in zip(funcs_to_check, funcs_to_check_ic) - ) + funcs = keys(res) + funcs_ic = replace_with_ic(ode, funcs) + return OrderedDict(f_ic => res[f] for (f, f_ic) in zip(funcs, funcs_ic)) end end end