Skip to content

Commit

Permalink
Merge pull request #331 from SciML/kic_mtk
Browse files Browse the repository at this point in the history
Adding known_ic (known generic initial conditions) to the MTK interface
  • Loading branch information
pogudingleb authored Jun 10, 2024
2 parents e0ef076 + e60361c commit 2adfdc3
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 19 deletions.
79 changes: 64 additions & 15 deletions ext/ModelingToolkitSIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,22 +340,25 @@ end
# ------------------------------------------------------------------------------

"""
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], prob_threshold = 0.99, loglevel=Logging.Info)
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
Input:
- `ode` - the ModelingToolkit.ODESystem object that defines the model
- `measured_quantities` - the output functions of the model
- `funcs_to_check` - functions of parameters for which to check the identifiability
- `known_ic` - functions, for which initial conditions are assumed to be known
- `prob_threshold` - probability of correctness.
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)
Assesses identifiability (both local and global) of a given ODE model (parameters detected automatically). The result is guaranteed to be correct with the probability
at least `prob_threshold`.
If known initial conditions are provided, the identifiability results for the states will also hold at `t = 0`
"""
function StructuralIdentifiability.assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = [],
known_ic = [],
prob_threshold = 0.99,
loglevel = Logging.Info,
)
Expand All @@ -365,6 +368,7 @@ function StructuralIdentifiability.assess_identifiability(
ode,
measured_quantities = measured_quantities,
funcs_to_check = funcs_to_check,
known_ic = known_ic,
prob_threshold = prob_threshold,
)
end
Expand All @@ -374,6 +378,7 @@ function _assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = [],
known_ic = [],
prob_threshold = 0.99,
)
if isempty(measured_quantities)
Expand All @@ -387,13 +392,32 @@ function _assess_identifiability(
end
funcs_to_check_ = [eval_at_nemo(each, conversion) for each in funcs_to_check]

result = StructuralIdentifiability._assess_identifiability(
ode,
funcs_to_check = funcs_to_check_,
prob_threshold = prob_threshold,
)
known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic]

nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
result = nothing
if isempty(known_ic)
result = StructuralIdentifiability._assess_identifiability(
ode,
funcs_to_check = funcs_to_check_,
prob_threshold = prob_threshold,
)
return OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
else
result = StructuralIdentifiability._assess_identifiability_kic(
ode,
known_ic_,
funcs_to_check = funcs_to_check_,
prob_threshold = prob_threshold,
)
end
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
if length(known_ic) > 0
@warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !"
t = SymbolicUtils.Sym{Real}(:t)
out_dict = OrderedDict(substitute(k, Dict(t => 0)) => v for (k, v) in out_dict)
end
return out_dict
end

Expand Down Expand Up @@ -510,7 +534,7 @@ end
# ------------------------------------------------------------------------------

"""
find_identifiable_functions(ode::ModelingToolkit.ODESystem; measured_quantities=[], options...)
find_identifiable_functions(ode::ModelingToolkit.ODESystem; measured_quantities=[], known_ic=[], options...)
Finds all functions of parameters/states that are identifiable in the given ODE
system.
Expand All @@ -519,6 +543,9 @@ system.
This functions takes the following optional arguments:
- `measured_quantities` - the output functions of the model.
- `known_ic` - a list of functions whose initial conditions are assumed to be known,
then the returned identifiable functions will be functions of parameters and
initial conditions, not states (this is an experimental functionality).
- `loglevel` - the verbosity of the logging
(can be Logging.Error, Logging.Warn, Logging.Info, Logging.Debug)
Expand Down Expand Up @@ -549,6 +576,7 @@ find_identifiable_functions(de, measured_quantities = [y1 ~ x0])
function StructuralIdentifiability.find_identifiable_functions(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
known_ic = [],
prob_threshold::Float64 = 0.99,
seed = 42,
with_states = false,
Expand All @@ -562,6 +590,7 @@ function StructuralIdentifiability.find_identifiable_functions(
return _find_identifiable_functions(
ode,
measured_quantities = measured_quantities,
known_ic = known_ic,
prob_threshold = prob_threshold,
seed = seed,
with_states = with_states,
Expand All @@ -574,6 +603,7 @@ end
function _find_identifiable_functions(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
known_ic = Array{Symbolics.Num}[],
prob_threshold::Float64 = 0.99,
seed = 42,
with_states = false,
Expand All @@ -585,17 +615,36 @@ function _find_identifiable_functions(
measured_quantities = get_measured_quantities(ode)
end
ode, conversion = mtk_to_si(ode, measured_quantities)
result = StructuralIdentifiability._find_identifiable_functions(
ode,
simplify = simplify,
prob_threshold = prob_threshold,
seed = seed,
with_states = with_states,
rational_interpolator = rational_interpolator,
)
known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic]
result = nothing
if isempty(known_ic)
result = StructuralIdentifiability._find_identifiable_functions(
ode,
simplify = simplify,
prob_threshold = prob_threshold,
seed = seed,
with_states = with_states,
rational_interpolator = rational_interpolator,
)
else
result = StructuralIdentifiability._find_identifiable_functions_kic(
ode,
known_ic_,
simplify = simplify,
prob_threshold = prob_threshold,
seed = seed,
rational_interpolator = rational_interpolator,
)
end
result = [parent_ring_change(f, ode.poly_ring) for f in result]
nemo2mtk = Dict(v => Num(k) for (k, v) in conversion)
out_funcs = [eval_at_dict(func, nemo2mtk) for func in result]
if length(known_ic) > 0
@warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !"
t = SymbolicUtils.Sym{Real}(:t)
out_funcs = [substitute(f, Dict(t => 0)) for f in out_funcs]
end

return out_funcs
end

Expand Down
5 changes: 4 additions & 1 deletion src/StructuralIdentifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ function assess_identifiability(
prob_threshold = prob_threshold,
)
else
return _assess_identifiability_kic(
res = _assess_identifiability_kic(
ode,
known_ic,
funcs_to_check = funcs_to_check,
prob_threshold = prob_threshold,
)
funcs = keys(res)
funcs_ic = replace_with_ic(ode, funcs)
return OrderedDict(f_ic => res[f] for (f, f_ic) in zip(funcs, funcs_ic))
end
end
end
Expand Down
4 changes: 3 additions & 1 deletion src/identifiable_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,16 @@ function find_identifiable_functions(
rational_interpolator = rational_interpolator,
)
else
return _find_identifiable_functions_kic(
id_funcs = _find_identifiable_functions_kic(
ode,
known_ic,
prob_threshold = prob_threshold,
seed = seed,
simplify = simplify,
rational_interpolator = rational_interpolator,
)
# renaming variables from `x(t)` to `x(0)`
return replace_with_ic(ode, id_funcs)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/known_ic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function _find_identifiable_functions_kic(

@info "The search for identifiable functions with known initial conditions concluded in $((time_ns() - runtime_start) / 1e9) seconds"

return replace_with_ic(ode, id_funcs)
return id_funcs
end

"""
Expand Down Expand Up @@ -90,7 +90,7 @@ function _assess_identifiability_kic(
end
half_p = 0.5 + prob_threshold / 2
id_funcs = _find_identifiable_functions_kic(ode, known_ic, prob_threshold = half_p)
funcs_to_check = replace_with_ic(ode, funcs_to_check)
#funcs_to_check = replace_with_ic(ode, funcs_to_check)
result = OrderedDict(f => :globally for f in funcs_to_check)

half_p = 0.5 + half_p / 2
Expand Down
128 changes: 128 additions & 0 deletions test/extensions/modelingtoolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,4 +717,132 @@ if GROUP == "All" || GROUP == "ModelingToolkitSIExt"
@test issetequal(sym_dict(local_id_1), sym_dict(local_id_2))
@test length(ifs_1) == length(ifs_2)
end

@testset "Identifiability of MTK models with known generic initial conditions" begin
cases = []

@parameters a, b, c, d
@variables t, x1(t), x2(t)
D = Differential(t)
x1_0 = substitute(x1, Dict(t => 0))
x2_0 = substitute(x2, Dict(t => 0))
eqs = [D(x1) ~ a * x1 - b * x1 * x2, D(x2) ~ -c * x2 + d * x1 * x2]
ode_mtk = ODESystem(eqs, t, name = :lv)
push!(
cases,
Dict(
:ode => ode_mtk,
:measured => [x1],
:known => [x2],
:to_check => [],
:correct_funcs => [a, b, c, d, x1_0, x2_0],
:correct_ident =>
OrderedDict(x => :globally for x in [x1_0, x2_0, a, b, c, d]),
),
)

@parameters c
@variables x3(t)
x3_0 = substitute(x3, Dict(t => 0))
eqs = [D(x1) ~ a + x2 + x3, D(x2) ~ b^2 + c, D(x3) ~ -c]
ode_mtk = ODESystem(eqs, t, name = :ex2)

push!(
cases,
Dict(
:ode => ode_mtk,
:measured => [x1],
:known => [x2, x3],
:to_check => [],
:correct_funcs => [a, b^2, x1_0, x2_0, x3_0],
:correct_ident => OrderedDict(
x1_0 => :globally,
x2_0 => :globally,
x3_0 => :globally,
a => :globally,
b => :locally,
c => :nonidentifiable,
),
),
)

push!(
cases,
Dict(
:ode => ode_mtk,
:measured => [x1],
:known => [x2, x3],
:to_check => [b^2, x2 * c],
:correct_funcs => [a, b^2, x1_0, x2_0, x3_0],
:correct_ident =>
OrderedDict(b^2 => :globally, x2_0 * c => :nonidentifiable),
),
)

eqs = [D(x1) ~ a * x1, D(x2) ~ x1 + 1 / x1]
ode_mtk = ODESystem(eqs, t, name = :ex3)
push!(
cases,
Dict(
:ode => ode_mtk,
:measured => [x2],
:known => [x1],
:to_check => [],
:correct_funcs => [a, x1_0, x2_0],
:correct_ident =>
OrderedDict(x1_0 => :globally, x2_0 => :globally, a => :globally),
),
)

@parameters alpha, beta, gama, delta, sigma
@variables x4(t)
x4_0 = substitute(x4, Dict(t => 0))
eqs = [
D(x1) ~ -b * x1 + 1 / (c + x4),
D(x2) ~ alpha * x1 - beta * x2,
D(x3) ~ gama * x2 - delta * x3,
D(x4) ~ sigma * x4 * (gama * x2 - delta * x3) / x3,
]
ode_mtk = ODESystem(eqs, t, name = :goodwin)
push!(
cases,
Dict(
:ode => ode_mtk,
:measured => [x1],
:known => [x2, x3],
:to_check => [alpha, alpha * gama],
:correct_funcs => [
sigma,
c,
b,
x4_0,
x3_0,
x2_0,
x1_0,
beta * delta,
alpha * gama,
beta + delta,
-delta * x3_0 + gama * x2_0,
],
:correct_ident => OrderedDict(alpha => :locally, alpha * gama => :globally),
),
)

for case in cases
ode = case[:ode]
y = case[:measured]
known = case[:known]
result_funcs =
find_identifiable_functions(ode, known_ic = known, measured_quantities = y)
correct_funcs = @test Set(result_funcs) == Set(case[:correct_funcs])

result_ident = assess_identifiability(
ode,
known_ic = known,
measured_quantities = y,
funcs_to_check = case[:to_check],
)
@test case[:correct_ident] == result_ident
end
end
end

0 comments on commit 2adfdc3

Please sign in to comment.