Skip to content

Commit

Permalink
Merge pull request #248 from SciML/t_dependency
Browse files Browse the repository at this point in the history
Making dependency on t explicit in the output
  • Loading branch information
pogudingleb authored Nov 17, 2023
2 parents c5738cd + bb350d0 commit e062848
Show file tree
Hide file tree
Showing 14 changed files with 340 additions and 144 deletions.
76 changes: 57 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ The returned value is a dictionary from the parameter of the model to one of the
For example, for the `ode` defined above, it will be

```
Dict{Any, Symbol} with 7 entries:
a12 => :locally
a21 => :globally
a01 => :locally
b => :nonidentifiable
x2 => :globally
x1 => :locally
x3 => :nonidentifiable
OrderedDict{Any, Symbol} with 7 entries:
x1(t) => :locally
x2(t) => :globally
x3(t) => :nonidentifiable
a01 => :locally
a12 => :locally
a21 => :globally
b => :nonidentifiable
```

If one is interested in the identifiability of particular functions of the parameter, one can pass a list of them as a second argument:
Expand All @@ -94,9 +94,47 @@ assess_identifiability(ode, funcs_to_check = [a01 + a12, a01 * a12])
This will return:

```
Dict{Any, Symbol} with 2 entries:
a12 + a01 => :globally
a12*a01 => :globally
OrderedDict{Any, Symbol} with 2 entries:
a01 + a12 => :globally
a01*a12 => :globally
```

### Finding identifiable functions

In the example above, we saw that, while some parameters may be not globally identifiable, appropriate functions of them (such as `a01 + a12` and `a01 * a12`)
can be still identifiable. However, it may be not so easy to guess these functions (even in this example). Good news is that this is not needed!
Function `find_identifiable_functions` can find generators of all identifiable functions of a given model. For instance:

```julia
find_identifiable_functions(ode)
```

will return

```
3-element Vector{AbstractAlgebra.Generic.Frac{Nemo.QQMPolyRingElem}}:
a21
a01*a12
a01 + a12
```

which are exactly the identifiable functions we have found before. Furthermore, by specifying `with_states = true`, one can compute the generating set for
all identifiable functions of parameters and states (in other words, all observable functions):

```julia
find_identifiable_functions(ode, with_states = true)
```

This will return

```
6-element Vector{AbstractAlgebra.Generic.Frac{Nemo.QQMPolyRingElem}}:
x2(t)
a21
a01*a12
a01 + a12
x3(t)//(a12*b + a21*b + b)
(-x1(t)*a21*b + x2(t)*a12*b + x2(t)*a21*b + x2(t)*b + x3(t))//(a21*b)
```

### Assessing local identifiability
Expand All @@ -111,14 +149,14 @@ assess_local_identifiability(ode)
The returned value is a dictionary from parameters and state variables to `1` (is locally identifiable/observable) and `0` (not identifiable/observable) values. In our example:

```
Dict{Nemo.fmpq_mpoly, Bool} with 7 entries:
a12 => 1
a21 => 1
x3 => 0
a01 => 1
x2 => 1
x1 => 1
b => 0
OrderedDict{Any, Bool} with 7 entries:
x1(t) => 1
x2(t) => 1
x3(t) => 0
a01 => 1
a12 => 1
a21 => 1
b => 0
```

As for `assess_identifiability`, one can assess local identifiability of arbitrary rational functions in the parameters (and also states) by providing a list of such functions as the second argument.
Expand Down
17 changes: 10 additions & 7 deletions docs/src/tutorials/discrete_time.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
Now we consider a discrete-time model in the state-space form

$\begin{cases}
\mathbf{x}(t + 1) = \mathbf{f}(\mathbf{x}(t), \mathbf{p}, \mathbf{u}(t)),\\
\Delta\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t), \mathbf{p}, \mathbf{u}(t)),\\
\mathbf{y}(t) = \mathbf{g}(\mathbf{x}(t), \mathbf{p}, \mathbf{u(t)}),
\end{cases}$
\end{cases} \quad \text{where} \quad \Delta \mathbf{x}(t) := \mathbf{x}(t + 1) - \mathbf{x}(t)$

where $\mathbf{x}(t), \mathbf{y}(t)$, and $\mathbf{u}(t)$ are time-dependent states, outputs, and inputs, respectively,
and $\mathbf{x}(t), \mathbf{y}(t)$, and $\mathbf{u}(t)$ are time-dependent states, outputs, and inputs, respectively,
and $\mathbf{p}$ are scalar parameters.
As in the ODE case, we will call that a parameter or a states (or a function of them) is **identifiable** if its value can be recovered from
time series for inputs and outputs (in the generic case, see Definition 3 in [^1] for details).
Expand All @@ -22,9 +22,9 @@ and below we will describe how this can be done.
As a running example, we will use the following discrete version of the [SIR](https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model) model:

$\begin{cases}
S(t + 1) = S(t) - \beta S(t) I(t),\\
I(t + 1) = I(t) + \beta S(t) I(t) - \alpha I(t),\\
R(t + 1) = R(t) + \alpha I(t),\\
\Delta S(t) = S(t) - \beta S(t) I(t),\\
\Delta I(t) = I(t) + \beta S(t) I(t) - \alpha I(t),\\
\Delta R(t) = R(t) + \alpha I(t),\\
y(t) = I(t),
\end{cases}$

Expand Down Expand Up @@ -58,7 +58,7 @@ In principle, it is not required to give a name to the observable, so one can wr
assess_local_identifiability(sir; measured_quantities = [I])
```

The `assess_local_identifiability` function has two important keyword arguments:
The `assess_local_identifiability` function has three important keyword arguments:

- `funcs_to_check` is a list of functions for which one want to assess identifiability, for example, the following code
will check if `β * S` is locally identifiable.
Expand All @@ -71,6 +71,9 @@ assess_local_identifiability(sir; measured_quantities = [I], funcs_to_check = [
principle it may produce incorrect result but the probability of correctness of the returned result is guaranteed to be at least `p`
(in fact, the employed bounds are quite conservative, so in practice incorrect result is almost never produced).

- `known_ic` is a list of the states for which initial conditions are known. In this case, the identifiability results will be valid not
at any time point `t` but only at `t = 0`.

As other main functions in the package, `assess_local_identifiability` accepts an optional parameter `loglevel` (default: `Logging.Info`)
to adjust the verbosity of logging.

Expand Down
38 changes: 16 additions & 22 deletions src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,22 @@ Here,
macro ODEmodel(ex::Expr...)
equations = [ex...]
x_vars, y_vars, u_vars, all_symb = macrohelper_extract_vars(equations)
# ensures that the parameters will be ordered
all_symb = sort(all_symb)
time_dependent = vcat(x_vars, y_vars, u_vars)
params = sort([s for s in all_symb if !(s in time_dependent)])
all_symb_no_t = vcat(time_dependent, params)
all_symb_with_t = vcat([:($s(t)) for s in time_dependent], params)

# creating the polynomial ring
vars_list = :([$(all_symb...)])
vars_list = :([$(all_symb_with_t...)])
R = gensym()
vars_aux = gensym()
exp_ring = :(
($R, $vars_aux) = StructuralIdentifiability.Nemo.PolynomialRing(
StructuralIdentifiability.Nemo.QQ,
map(string, $all_symb),
map(string, $all_symb_with_t),
)
)
assignments = [:($(all_symb[i]) = $vars_aux[$i]) for i in 1:length(all_symb)]
assignments = [:($(all_symb_no_t[i]) = $vars_aux[$i]) for i in 1:length(all_symb_no_t)]

# setting x_vars and y_vars in the right order
vx = gensym()
Expand Down Expand Up @@ -446,12 +448,7 @@ macro ODEmodel(ex::Expr...)
end
end

params = setdiff(all_symb, union(x_vars, y_vars, u_vars))
allnames = map(
string,
vcat(collect(x_vars), collect(params), collect(u_vars), collect(y_vars)),
)
for n in allnames
for n in all_symb_no_t
if !Base.isidentifier(n)
throw(
ArgumentError(
Expand Down Expand Up @@ -501,21 +498,18 @@ end
#------------------------------------------------------------------------------

function Base.show(io::IO, ode::ODE)
varstr =
Dict(x => var_to_str(x) * "(t)" for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars))
merge!(varstr, Dict(p => var_to_str(p) for p in ode.parameters))
R_print, vars_print = Nemo.PolynomialRing(
base_ring(ode.poly_ring),
[varstr[v] for v in gens(ode.poly_ring)],
)
for x in ode.x_vars
print(io, var_to_str(x) * "'(t) = ")
print(io, evaluate(ode.x_equations[x], vars_print))
if endswith(var_to_str(x), "(t)")
print(io, var_to_str(x)[1:(end - 3)] * "'(t) = ")
else
print(io, var_to_str(x) * " = ")
end
print(io, ode.x_equations[x])
print(io, "\n")
end
for y in ode.y_vars
print(io, var_to_str(y) * "(t) = ")
print(io, evaluate(ode.y_equations[y], vars_print))
print(io, var_to_str(y) * " = ")
print(io, ode.y_equations[y])
print(io, "\n")
end
end
Expand Down
51 changes: 44 additions & 7 deletions src/discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,17 @@ function _degree_with_common_denom(polys)
end

"""
_assess_local_identifiability_discrete(dds::ODE{P}, funcs_to_check::Array{<: Any, 1}, known_ic, p::Float64=0.99) where P <: MPolyElem{Nemo.fmpq}
_assess_local_identifiability_discrete_aux(dds::ODE{P}, funcs_to_check::Array{<: Any, 1}, known_ic, p::Float64=0.99) where P <: MPolyElem{Nemo.fmpq}
Checks the local identifiability/observability of the functions in `funcs_to_check` treating `dds` as a discrete-time system.
Checks the local identifiability/observability of the functions in `funcs_to_check` treating `dds` as a discrete-time system with **shift**
instead of derivative in the right-hand side.
The result is correct with probability at least `p`.
`known_ic` can take one of the following
* `:none` - no initial conditions are assumed to be known
* `:all` - all initial conditions are assumed to be known
* a list of rational functions in states and parameters assumed to be known at t = 0
"""
function _assess_local_identifiability_discrete(
function _assess_local_identifiability_discrete_aux(
dds::ODE{P},
funcs_to_check::Array{<:Any, 1},
known_ic = :none,
Expand All @@ -208,6 +209,7 @@ function _assess_local_identifiability_discrete(

@debug "Computing the observability matrix"
prec = length(dds.x_vars) + length(dds.parameters)
@debug "The truncation order is $prec"

# Computing the bound from the Schwartz-Zippel-DeMilo-Lipton lemma
deg_x = _degree_with_common_denom(values(dds.x_equations))
Expand Down Expand Up @@ -294,7 +296,7 @@ end
p::Float64=0.99)
Input:
- `dds` - the DiscreteSystem object from ModelingToolkit
- `dds` - the DiscreteSystem object from ModelingToolkit (with **difference** operator in the right-hand side)
- `measured_quantities` - the measurable outputs of the model
- `funcs_to_check` - functions of parameters for which to check identifiability (all parameters and states if not specified)
- `known_ic` - functions (of states and parameter) whose initial conditions are assumed to be known
Expand All @@ -311,6 +313,26 @@ function assess_local_identifiability(
funcs_to_check = Array{}[],
known_ic = Array{}[],
p::Float64 = 0.99,
loglevel = Logging.Info,
)
restart_logging(loglevel = loglevel)
with_logger(_si_logger[]) do
return _assess_local_identifiability(
dds,
measured_quantities = measured_quantities,
funcs_to_check = funcs_to_check,
known_ic = known_ic,
p = p,
)
end
end

function _assess_local_identifiability(
dds::ModelingToolkit.DiscreteSystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = Array{}[],
known_ic = Array{}[],
p::Float64 = 0.99,
)
if length(measured_quantities) == 0
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(dds))
Expand All @@ -328,21 +350,36 @@ function assess_local_identifiability(
end
end

dds_aux, conversion = mtk_to_si(dds, measured_quantities)
# Converting the finite difference operator in the right-hand side to
# the corresponding shift operator
eqs = filter(eq -> !(ModelingToolkit.isoutput(eq.lhs)), ModelingToolkit.equations(dds))
deltas = [Symbolics.operation(e.lhs).dt for e in eqs]
@assert length(Set(deltas)) == 1
eqs_shift = [e.lhs ~ e.rhs + first(Symbolics.arguments(e.lhs)) for e in eqs]
dds_shift = DiscreteSystem(eqs_shift, name = gensym())
@debug "System transformed from difference to shift: $dds_shift"

dds_aux, conversion = mtk_to_si(dds_shift, measured_quantities)
if length(funcs_to_check) == 0
params = parameters(dds)
params_from_measured_quantities = union(
[filter(s -> !istree(s), get_variables(y)) for y in measured_quantities]...,
)
funcs_to_check = vcat(
[x for x in states(dds) if conversion[x] in dds_aux.x_vars],
parameters(dds),
union(params, params_from_measured_quantities),
)
end
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
known_ic_ = [eval_at_nemo(x, conversion) for x in known_ic]

result = _assess_local_identifiability_discrete(dds_aux, funcs_to_check_, known_ic_, p)
result =
_assess_local_identifiability_discrete_aux(dds_aux, funcs_to_check_, known_ic_, p)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
if length(known_ic) > 0
@warn "Since known initial conditions were provided, identifiability of states (e.g., `x(t)`) is at t = 0 only !"
out_dict = OrderedDict(substitute(k, Dict(t => 0)) => v for (k, v) in out_dict)
end
return out_dict
end
Expand Down
32 changes: 16 additions & 16 deletions test/common_ring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
ode = @ODEmodel(x1'(t) = x2(t), x2'(t) = a * x1(t), y(t) = x1(t))
ioeqs = find_ioequations(ode)
pbr = PBRepresentation(ode, ioeqs)
R, (y_2, y_5, c) = Nemo.PolynomialRing(Nemo.QQ, ["y_2", "y_5", "c"])
R, (y_2, y_5, c) = Nemo.PolynomialRing(Nemo.QQ, ["y(t)_2", "y(t)_5", "c"])
p = y_2^2 + c * y_5
(r, der) = common_ring(p, pbr)
@test Set(map(var_to_str, gens(r))) ==
Set(["y_0", "y_1", "y_2", "y_3", "y_4", "y_5", "c", "a"])
Set(["y(t)_0", "y(t)_1", "y(t)_2", "y(t)_3", "y(t)_4", "y(t)_5", "c", "a"])

ode = @ODEmodel(
x1'(t) = x3(t),
Expand All @@ -17,23 +17,23 @@
)
ioeqs = find_ioequations(ode)
pbr = PBRepresentation(ode, ioeqs)
R, (y1_0, y2_3, u_3) = Nemo.PolynomialRing(Nemo.QQ, ["y1_0", "y2_3", "u_3"])
R, (y1_0, y2_3, u_3) = Nemo.PolynomialRing(Nemo.QQ, ["y1(t)_0", "y2(t)_3", "u(t)_3"])
p = y1_0 + y2_3 + u_3
(r, der) = common_ring(p, pbr)
@test Set([var_to_str(v) for v in gens(r)]) == Set([
"y1_0",
"y1_1",
"y1_2",
"y1_3",
"y1_4",
"y2_0",
"y2_1",
"y2_2",
"y2_3",
"u_0",
"u_1",
"u_2",
"u_3",
"y1(t)_0",
"y1(t)_1",
"y1(t)_2",
"y1(t)_3",
"y1(t)_4",
"y2(t)_0",
"y2(t)_1",
"y2(t)_2",
"y2(t)_3",
"u(t)_0",
"u(t)_1",
"u(t)_2",
"u(t)_3",
"a",
])
end
Loading

0 comments on commit e062848

Please sign in to comment.