Skip to content

Commit

Permalink
creating a dedicated structure for discrete dynamical systems (but so…
Browse files Browse the repository at this point in the history
…me tests are reluctant to pass for the moment)
  • Loading branch information
gpogudin committed Jan 23, 2024
1 parent 2c1e748 commit b655cc0
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 302 deletions.
10 changes: 7 additions & 3 deletions ext/ModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ end
prob_threshold::Float64=0.99)
Input:
- `dds` - the DiscreteSystem object from ModelingToolkit (with **difference** operator in the right-hand side)
- `dds` - the DiscreteSystem object from ModelingToolkit (with **difference** operator in the left-hand side)
- `measured_quantities` - the measurable outputs of the model
- `funcs_to_check` - functions of parameters for which to check identifiability (all parameters and states if not specified)
- `known_ic` - functions (of states and parameter) whose initial conditions are assumed to be known
Expand Down Expand Up @@ -467,14 +467,18 @@ function _assess_local_identifiability(
dds_shift = DiscreteSystem(eqs_shift, name = gensym())
@debug "System transformed from difference to shift: $dds_shift"

Check warning on line 468 in ext/ModelingToolkitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ModelingToolkitExt.jl#L463-L468

Added lines #L463 - L468 were not covered by tests

dds_aux, conversion = mtk_to_si(dds_shift, measured_quantities)
dds_aux_ode, conversion = mtk_to_si(dds_shift, measured_quantities)
dds_aux = StructuralIdentifiability.DDS{QQMPolyRingElem}(dds_aux_ode)
if length(funcs_to_check) == 0
params = parameters(dds)
params_from_measured_quantities = union(
[filter(s -> !istree(s), get_variables(y)) for y in measured_quantities]...,

Check warning on line 475 in ext/ModelingToolkitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ModelingToolkitExt.jl#L470-L475

Added lines #L470 - L475 were not covered by tests
)
funcs_to_check = vcat(

Check warning on line 477 in ext/ModelingToolkitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ModelingToolkitExt.jl#L477

Added line #L477 was not covered by tests
[x for x in states(dds) if conversion[x] in dds_aux.x_vars],
[
x for x in states(dds) if
conversion[x] in StructuralIdentifiability.x_vars(dds_aux)
],
union(params, params_from_measured_quantities),
)
end
Expand Down
242 changes: 0 additions & 242 deletions src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,248 +261,6 @@ end

#------------------------------------------------------------------------------

function _extract_aux!(funcs, all_symb, eq, ders_ok = false)
aux_symb = Set([:(+), :(-), :(=), :(*), :(^), :t, :(/), :(//)])
MacroTools.postwalk(
x -> begin
if @capture(x, f_'(t))
if !ders_ok
throw(
Base.ArgumentError(
"Derivative are not allowed in the right-hand side",
),
)
end
push!(all_symb, f)
elseif @capture(x, f_(t))
push!(funcs, f)
elseif (x isa Symbol) && !(x in aux_symb)
push!(all_symb, x)
end
return x
end,
eq,
)
end

"""
For an expression of the form f'(t) or f(t) returns (f, true) and (f, false), resp
"""
function _get_var(expr)
if @capture(expr, f_'(t))
return (f, true)
end
if @capture(expr, f_(t))
return (f, false)
end
error("cannot extract the single function name from $expr")
end

function macrohelper_extract_vars(equations::Array{Expr, 1})
funcs, all_symb = Set(), Set()
x_vars, y_vars = Vector(), Vector()
aux_symb = Set([:(+), :(-), :(=), :(*), :(^), :t, :(/), :(//)])
for eq in equations
if eq.head != :(=)
_extract_aux!(funcs, all_symb, eq)
else
lhs, rhs = eq.args[1:2]
_extract_aux!(funcs, all_symb, lhs, true)
_extract_aux!(funcs, all_symb, rhs)
(v, is_state) = _get_var(lhs)
if is_state
push!(x_vars, v)
else
push!(y_vars, v)
end
end
end
u_vars = setdiff(funcs, vcat(x_vars, y_vars))
all_symb = collect(all_symb)
return x_vars, y_vars, collect(u_vars), collect(all_symb)
end

function macrohelper_extract_vars(equations::Array{Symbol, 1})
return macrohelper_extract_vars(map(Expr, equations))
end

#------------------------------------------------------------------------------

function macrohelper_clean(ex::Expr)
ex = MacroTools.postwalk(x -> @capture(x, f_'(t)) ? f : x, ex)
ex = MacroTools.postwalk(x -> @capture(x, f_(t)) ? f : x, ex)
ex = MacroTools.postwalk(x -> x == :(/) ? :(//) : x, ex)
ex = MacroTools.postwalk(x -> x isa Float64 ? rationalize(x) : x, ex)
return ex
end

#------------------------------------------------------------------------------

"""
macro ODEmodel
Macro for creating an ODE from a list of equations.
It also injects all variables into the global scope.
## Example
Creating a simple `ODE`:
```jldoctest
using StructuralIdentifiability
ode = @ODEmodel(
x1'(t) = a * x1(t) + u(t),
x2'(t) = b * x2(t) + c*x1(t)*x2(t),
y(t) = x1(t)
)
```
Here,
- `x1`, `x2` are state variables
- `y` is an output variable
- `u` is an input variable
- `a`, `b`, `c` are time-indepdendent parameters
"""
macro ODEmodel(ex::Expr...)
equations = [ex...]
x_vars, y_vars, u_vars, all_symb = macrohelper_extract_vars(equations)
time_dependent = vcat(x_vars, y_vars, u_vars)
params = sort([s for s in all_symb if !(s in time_dependent)])
all_symb_no_t = vcat(time_dependent, params)
all_symb_with_t = vcat([:($s(t)) for s in time_dependent], params)

# creating the polynomial ring
vars_list = :([$(all_symb_with_t...)])
R = gensym()
vars_aux = gensym()
exp_ring = :(
($R, $vars_aux) = StructuralIdentifiability.Nemo.polynomial_ring(
StructuralIdentifiability.Nemo.QQ,
map(string, $all_symb_with_t),
)
)
assignments = [:($(all_symb_no_t[i]) = $vars_aux[$i]) for i in 1:length(all_symb_no_t)]

# setting x_vars and y_vars in the right order
vx = gensym()
vy = gensym()
x_var_expr =
:($vx = Vector{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(x_vars...)]))
y_var_expr =
:($vy = Vector{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(y_vars...)]))

# preparing equations
equations = map(macrohelper_clean, equations)
x_dict = gensym()
y_dict = gensym()
x_dict_create_expr = :(
$x_dict = Dict{
StructuralIdentifiability.Nemo.QQMPolyRingElem,
Union{
StructuralIdentifiability.Nemo.QQMPolyRingElem,
StructuralIdentifiability.AbstractAlgebra.Generic.Frac{
StructuralIdentifiability.Nemo.QQMPolyRingElem,
},
},
}()
)
y_dict_create_expr = :(
$y_dict = Dict{
StructuralIdentifiability.Nemo.QQMPolyRingElem,
Union{
StructuralIdentifiability.Nemo.QQMPolyRingElem,
StructuralIdentifiability.AbstractAlgebra.Generic.Frac{
StructuralIdentifiability.Nemo.QQMPolyRingElem,
},
},
}()
)
eqs_expr = []
for eq in equations
if eq.head != :(=)
throw("Problem with parsing at $eq")
end
lhs, rhs = eq.args[1:2]
loc_all_symb = macrohelper_extract_vars([rhs])[4]
to_insert = undef
if lhs in x_vars
to_insert = x_dict
elseif lhs in y_vars
to_insert = y_dict
else
throw("Unknown left-hand side $lhs")
end

uniqueness_check_expr = quote
if haskey($to_insert, $lhs)
throw(
DomainError(
$lhs,
"The variable occurs twice in the left-hand-side of the ODE system",
),
)
end
end
push!(eqs_expr, uniqueness_check_expr)
if isempty(loc_all_symb)
push!(eqs_expr, :($to_insert[$lhs] = $R($rhs)))
else
push!(eqs_expr, :($to_insert[$lhs] = ($rhs)))
end
end

for n in all_symb_no_t
if !Base.isidentifier(n)
throw(
ArgumentError(
"The names of the variables will be injected into the global scope, so their name must be allowed Julia names, $n is not",
),
)
end
end

logging_exprs = [
:(
StructuralIdentifiability.Logging.with_logger(
StructuralIdentifiability._si_logger[],
) do
@info "Summary of the model:"
@info "State variables: " * $(join(map(string, collect(x_vars)), ", "))
@info "Parameters: " * $(join(map(string, collect(params)), ", "))
@info "Inputs: " * $(join(map(string, collect(u_vars)), ", "))
@info "Outputs: " * $(join(map(string, collect(y_vars)), ", "))
end
),
]
# creating the ode object
ode_expr =
:(StructuralIdentifiability.ODE{StructuralIdentifiability.Nemo.QQMPolyRingElem}(
$vx,
$vy,
$x_dict,
$y_dict,
Array{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(u_vars...)]),
))

result = Expr(
:block,
logging_exprs...,
exp_ring,
assignments...,
x_var_expr,
y_var_expr,
x_dict_create_expr,
y_dict_create_expr,
eqs_expr...,
ode_expr,
)
return esc(result)
end

#------------------------------------------------------------------------------

function Base.show(io::IO, ode::ODE)
for x in ode.x_vars
if endswith(var_to_str(x), "(t)")
Expand Down
3 changes: 2 additions & 1 deletion src/StructuralIdentifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using ParamPunPam: reduce_mod_p!, specialize_mod_p, AbstractBlackboxIdeal
ParamPunPam.enable_progressbar(false)

# defining a model
export ODE, @ODEmodel, mtk_to_si
export ODE, @ODEmodel, @DDSmodel, mtk_to_si

# assessing identifiability
export assess_local_identifiability, assess_identifiability
Expand Down Expand Up @@ -70,6 +70,7 @@ include("lincomp.jl")
include("pb_representation.jl")
include("submodels.jl")
include("discrete.jl")
include("input_macro.jl")

function __init__()
_si_logger[] = @static if VERSION >= v"1.7.0"
Expand Down
Loading

0 comments on commit b655cc0

Please sign in to comment.