Skip to content

Commit

Permalink
Merge pull request #247 from SciML/ordering_output
Browse files Browse the repository at this point in the history
Ordering the elements in the output
  • Loading branch information
pogudingleb authored Nov 14, 2023
2 parents 2ff1a9d + f0a111d commit c5738cd
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 80 deletions.
2 changes: 2 additions & 0 deletions src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ Here,
macro ODEmodel(ex::Expr...)
equations = [ex...]
x_vars, y_vars, u_vars, all_symb = macrohelper_extract_vars(equations)
# ensures that the parameters will be ordered
all_symb = sort(all_symb)

# creating the polynomial ring
vars_list = :([$(all_symb...)])
Expand Down
10 changes: 5 additions & 5 deletions src/StructuralIdentifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Input:
Assesses identifiability of a given ODE model. The result is guaranteed to be correct with the probability
at least `p`.
The function returns a dictionary from the functions to check to their identifiability properties
The function returns an (ordered) dictionary from the functions to check to their identifiability properties
(one of `:nonidentifiable`, `:locally`, `:globally`).
"""
function assess_identifiability(
Expand All @@ -118,7 +118,7 @@ function _assess_identifiability(
p_loc = 1 - (1 - p) * 0.1

if isempty(funcs_to_check)
funcs_to_check = vcat(ode.parameters, ode.x_vars)
funcs_to_check = vcat(ode.x_vars, ode.parameters)
end

@info "Assessing local identifiability"
Expand Down Expand Up @@ -148,7 +148,7 @@ function _assess_identifiability(
@info "Global identifiability assessed in $runtime seconds"
_runtime_logger[:glob_time] = runtime

result = Dict{Any, Symbol}()
result = OrderedDict{Any, Symbol}()
glob_ind = 1
for i in 1:length(funcs_to_check)
if !local_result[funcs_to_check[i]]
Expand Down Expand Up @@ -210,13 +210,13 @@ function _assess_identifiability(
ode, conversion = mtk_to_si(ode, measured_quantities)
conversion_back = Dict(v => k for (k, v) in conversion)
if isempty(funcs_to_check)
funcs_to_check = [conversion_back[x] for x in [ode.parameters..., ode.x_vars...]]
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
end
funcs_to_check_ = [eval_at_nemo(each, conversion) for each in funcs_to_check]

result = _assess_identifiability(ode, funcs_to_check = funcs_to_check_, p = p)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return out_dict
end

Expand Down
10 changes: 5 additions & 5 deletions src/discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ function _assess_local_identifiability_discrete(

@debug "Computing the result"
base_rank = LinearAlgebra.rank(Jac)
result = Dict{Any, Bool}()
result = OrderedDict{Any, Bool}()
for i in 1:length(funcs_to_check)
for (k, p) in enumerate(dds_ext.parameters)
Jac[k, 1] =
Expand All @@ -280,7 +280,7 @@ function _assess_local_identifiability_discrete(
result[funcs_to_check[i]] = LinearAlgebra.rank(Jac) == base_rank
end

return Dict(result)
return result
end

# ------------------------------------------------------------------------------
Expand All @@ -301,7 +301,7 @@ Input:
- `p` - probability of correctness
Output:
- the result is a dictionary from each function to to boolean;
- the result is an (ordered) dictionary from each function to to boolean;
The result is correct with probability at least `p`.
"""
Expand Down Expand Up @@ -331,16 +331,16 @@ function assess_local_identifiability(
dds_aux, conversion = mtk_to_si(dds, measured_quantities)
if length(funcs_to_check) == 0
funcs_to_check = vcat(
parameters(dds),
[x for x in states(dds) if conversion[x] in dds_aux.x_vars],
parameters(dds),
)
end
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
known_ic_ = [eval_at_nemo(x, conversion) for x in known_ic]

result = _assess_local_identifiability_discrete(dds_aux, funcs_to_check_, known_ic_, p)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
if length(known_ic) > 0
@warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !"
end
Expand Down
18 changes: 10 additions & 8 deletions src/local_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Input:
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)
Output:
- for `type=:SE`, the result is a dictionary from each parameter to boolean;
- for `type=:SE`, the result is an (ordered) dictionary from each parameter to boolean;
- for `type=:ME`, the result is a tuple with the dictionary as in `:SE` case and array of number of experiments.
The function determines local identifiability of parameters in `funcs_to_check` or all possible parameters if `funcs_to_check` is empty
Expand Down Expand Up @@ -208,8 +208,8 @@ end
end
if length(funcs_to_check) == 0
funcs_to_check = vcat(
ModelingToolkit.parameters(ode),
[e for e in ModelingToolkit.states(ode) if !ModelingToolkit.isoutput(e)],
ModelingToolkit.parameters(ode),
)
end
ode, conversion = mtk_to_si(ode, measured_quantities)
Expand All @@ -223,7 +223,8 @@ end
type = type,
)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
out_dict =
OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return out_dict
elseif isequal(type, :ME)
result, bd = _assess_local_identifiability(
Expand All @@ -233,7 +234,8 @@ end
type = type,
)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
out_dict =
OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return (out_dict, bd)
end
end
Expand Down Expand Up @@ -280,7 +282,7 @@ function _assess_local_identifiability(
if isempty(funcs_to_check)
funcs_to_check = ode.parameters
if type == :SE
funcs_to_check = vcat(funcs_to_check, ode.x_vars)
funcs_to_check = vcat(ode.x_vars, ode.parameters)
end
end

Expand Down Expand Up @@ -440,7 +442,7 @@ function _assess_local_identifiability(

@debug "Computing the result"
base_rank = LinearAlgebra.rank(Jac)
result = Dict{Any, Bool}()
result = OrderedDict{Any, Bool}()
for i in 1:length(funcs_to_check)
for (k, p) in enumerate(ode_red.parameters)
Jac[k, 1] =
Expand All @@ -459,9 +461,9 @@ function _assess_local_identifiability(
# NB: the Jac contains now the derivatives of the last from `funcs_to_check`

if type == :SE
return Dict(result)
return result
end
return (Dict(result), num_exp)
return (result, num_exp)
end

# ------------------------------------------------------------------------------
22 changes: 11 additions & 11 deletions test/identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -38,13 +38,13 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

# Also test when `funcs_to_test` is empty!
funcs_to_test = Vector{typeof(x1)}()
correct = Dict(x1 => :nonidentifiable, x2 => :nonidentifiable)
correct = OrderedDict(x1 => :nonidentifiable, x2 => :nonidentifiable)
push!(test_cases, Dict(:ode => ode, :funcs => funcs_to_test, :correct => correct))

#--------------------------------------------------------------------------
Expand All @@ -69,7 +69,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -89,7 +89,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand Down Expand Up @@ -118,7 +118,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -137,7 +137,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -155,7 +155,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -173,7 +173,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -192,7 +192,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand Down Expand Up @@ -222,7 +222,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand Down
12 changes: 6 additions & 6 deletions test/local_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -42,7 +42,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -62,7 +62,7 @@
Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => Dict(funcs_to_test .=> correct),
:correct => OrderedDict(funcs_to_test .=> correct),
),
)

Expand All @@ -76,7 +76,7 @@
y(t) = x1(t)
)
funcs_to_test = [b, c, alpha, beta, delta, gama, beta + delta, beta * delta]
correct = Dict([
correct = OrderedDict([
b => true,
c => true,
alpha => false,
Expand All @@ -96,7 +96,7 @@
y(t) = x1(t) * x2(t)
)
funcs_to_test = [a1, a2, a3, a4, a5, a6, a7, a8]
correct = Dict(a => false for a in funcs_to_test)
correct = OrderedDict(a => false for a in funcs_to_test)
push!(test_cases, Dict(:ode => ode, :funcs => funcs_to_test, :correct => correct))

#--------------------------------------------------------------------------
Expand All @@ -109,7 +109,7 @@
y2(t) = x3(t)
)
funcs_to_test = [x1, x2, x1 + x2]
correct = Dict(x1 => false, x2 => false, x1 + x2 => true)
correct = OrderedDict(x1 => false, x2 => false, x1 + x2 => true)
push!(test_cases, Dict(:ode => ode, :funcs => funcs_to_test, :correct => correct))

#--------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit c5738cd

Please sign in to comment.