From 8ee49c8907286066c2688bf11b41a9c1f398e19c Mon Sep 17 00:00:00 2001 From: Torkel Date: Thu, 1 Feb 2024 15:50:20 -0500 Subject: [PATCH] add tests --- test/extensions/modelingtoolkit.jl | 39 +++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/test/extensions/modelingtoolkit.jl b/test/extensions/modelingtoolkit.jl index d09ce810..d06257e5 100644 --- a/test/extensions/modelingtoolkit.jl +++ b/test/extensions/modelingtoolkit.jl @@ -658,6 +658,43 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt" end @testset "Exporting ModelingToolkit Model to SI Model" begin - # Add test of `mtk_to_si` function here, as well as identifiability functions when applied to its output. + + # Creates MTK model and assesses its identifiability. + @parameters r1, r2, c1, c2, beta1, beta2, chi1, chi2 + @variables t, x1(t), x2(t), y(t), u(t) + D= Differential(t) + eqs = [ + D(x1) ~ r1 * x1 * (1 - c1 * x1) + beta1 * x1 * x2 / (chi1 + x2) + u, + D(x2) ~ r2 * x2 * (1 - c2 * x2) + beta2 * x1 * x2 / (chi2 + x1), + ] + measured_quantities = [y ~ x1] + ode_mtk = ODESystem(eqs, t, name = :mutualist) + + global_id_1 = assess_identifiability(ode_mtk, measured_quantities = measured_quantities) + local_id_1 = assess_local_identifiability(ode_mtk, measured_quantities = measured_quantities) + ifs_1 = find_identifiable_functions(ode_mtk, measured_quantities = measured_quantities) + + # Converts mtk model to si model, and assesses its identifiability. + si_model, _ = mtk_to_si(ode_mtk, measured_quantities) + global_id_2 = assess_identifiability(si_model) + local_id_2 = assess_local_identifiability(si_model) + ifs_2 = find_identifiable_functions(si_model) + + # Converts the output dicts from StructuralIdentifiability functions from "weird symbol => stuff" to "symbol => stuff" (the output have some strange meta data which prevents equality checks, this enables this). + # Structural identifiability also provides variables like x (rather than x(t)). This is a bug, but we have to convert to make it work (now just remove any (t) to make them all equal). + function sym_dict(dict_in) + dict_out = Dict{Symbol,Any}() + for key in keys(dict_in) + sym_key = Symbol(key) + sym_key = Symbol(replace(String(sym_key), "(t)" => "")) + dict_out[sym_key] = dict_in[key] + end + return dict_out + end + + # Checks that the two approaches yields the same result + @test issetequal(sym_dict(local_id_1), sym_dict(local_id_2)) + @test issetequal(sym_dict(local_id_1), sym_dict(local_id_2)) + @test length(ifs_1) == length(ifs_2) end end