diff --git a/ext/ModelingToolkitExt.jl b/ext/ModelingToolkitExt.jl index 0e1a0002..4ca8de85 100644 --- a/ext/ModelingToolkitExt.jl +++ b/ext/ModelingToolkitExt.jl @@ -404,7 +404,7 @@ end prob_threshold::Float64=0.99) Input: -- `dds` - the DiscreteSystem object from ModelingToolkit (with **difference** operator in the right-hand side) +- `dds` - the DiscreteSystem object from ModelingToolkit (with **difference** operator in the left-hand side) - `measured_quantities` - the measurable outputs of the model - `funcs_to_check` - functions of parameters for which to check identifiability (all parameters and states if not specified) - `known_ic` - functions (of states and parameter) whose initial conditions are assumed to be known @@ -467,14 +467,18 @@ function _assess_local_identifiability( dds_shift = DiscreteSystem(eqs_shift, name = gensym()) @debug "System transformed from difference to shift: $dds_shift" - dds_aux, conversion = mtk_to_si(dds_shift, measured_quantities) + dds_aux_ode, conversion = mtk_to_si(dds_shift, measured_quantities) + dds_aux = StructuralIdentifiability.DDS{QQMPolyRingElem}(dds_aux_ode) if length(funcs_to_check) == 0 params = parameters(dds) params_from_measured_quantities = union( [filter(s -> !istree(s), get_variables(y)) for y in measured_quantities]..., ) funcs_to_check = vcat( - [x for x in states(dds) if conversion[x] in dds_aux.x_vars], + [ + x for x in states(dds) if + conversion[x] in StructuralIdentifiability.x_vars(dds_aux) + ], union(params, params_from_measured_quantities), ) end diff --git a/src/ODE.jl b/src/ODE.jl index 83ba44ae..addf9f93 100644 --- a/src/ODE.jl +++ b/src/ODE.jl @@ -261,248 +261,6 @@ end #------------------------------------------------------------------------------ -function _extract_aux!(funcs, all_symb, eq, ders_ok = false) - aux_symb = Set([:(+), :(-), :(=), :(*), :(^), :t, :(/), :(//)]) - MacroTools.postwalk( - x -> begin - if @capture(x, f_'(t)) - if !ders_ok - throw( - Base.ArgumentError( - "Derivative are not allowed in the right-hand side", - ), - ) - end - push!(all_symb, f) - elseif @capture(x, f_(t)) - push!(funcs, f) - elseif (x isa Symbol) && !(x in aux_symb) - push!(all_symb, x) - end - return x - end, - eq, - ) -end - -""" - For an expression of the form f'(t) or f(t) returns (f, true) and (f, false), resp -""" -function _get_var(expr) - if @capture(expr, f_'(t)) - return (f, true) - end - if @capture(expr, f_(t)) - return (f, false) - end - error("cannot extract the single function name from $expr") -end - -function macrohelper_extract_vars(equations::Array{Expr, 1}) - funcs, all_symb = Set(), Set() - x_vars, y_vars = Vector(), Vector() - aux_symb = Set([:(+), :(-), :(=), :(*), :(^), :t, :(/), :(//)]) - for eq in equations - if eq.head != :(=) - _extract_aux!(funcs, all_symb, eq) - else - lhs, rhs = eq.args[1:2] - _extract_aux!(funcs, all_symb, lhs, true) - _extract_aux!(funcs, all_symb, rhs) - (v, is_state) = _get_var(lhs) - if is_state - push!(x_vars, v) - else - push!(y_vars, v) - end - end - end - u_vars = setdiff(funcs, vcat(x_vars, y_vars)) - all_symb = collect(all_symb) - return x_vars, y_vars, collect(u_vars), collect(all_symb) -end - -function macrohelper_extract_vars(equations::Array{Symbol, 1}) - return macrohelper_extract_vars(map(Expr, equations)) -end - -#------------------------------------------------------------------------------ - -function macrohelper_clean(ex::Expr) - ex = MacroTools.postwalk(x -> @capture(x, f_'(t)) ? f : x, ex) - ex = MacroTools.postwalk(x -> @capture(x, f_(t)) ? f : x, ex) - ex = MacroTools.postwalk(x -> x == :(/) ? :(//) : x, ex) - ex = MacroTools.postwalk(x -> x isa Float64 ? rationalize(x) : x, ex) - return ex -end - -#------------------------------------------------------------------------------ - -""" - macro ODEmodel - -Macro for creating an ODE from a list of equations. -It also injects all variables into the global scope. - -## Example - -Creating a simple `ODE`: - -```jldoctest -using StructuralIdentifiability - -ode = @ODEmodel( - x1'(t) = a * x1(t) + u(t), - x2'(t) = b * x2(t) + c*x1(t)*x2(t), - y(t) = x1(t) -) -``` - -Here, -- `x1`, `x2` are state variables -- `y` is an output variable -- `u` is an input variable -- `a`, `b`, `c` are time-indepdendent parameters - -""" -macro ODEmodel(ex::Expr...) - equations = [ex...] - x_vars, y_vars, u_vars, all_symb = macrohelper_extract_vars(equations) - time_dependent = vcat(x_vars, y_vars, u_vars) - params = sort([s for s in all_symb if !(s in time_dependent)]) - all_symb_no_t = vcat(time_dependent, params) - all_symb_with_t = vcat([:($s(t)) for s in time_dependent], params) - - # creating the polynomial ring - vars_list = :([$(all_symb_with_t...)]) - R = gensym() - vars_aux = gensym() - exp_ring = :( - ($R, $vars_aux) = StructuralIdentifiability.Nemo.polynomial_ring( - StructuralIdentifiability.Nemo.QQ, - map(string, $all_symb_with_t), - ) - ) - assignments = [:($(all_symb_no_t[i]) = $vars_aux[$i]) for i in 1:length(all_symb_no_t)] - - # setting x_vars and y_vars in the right order - vx = gensym() - vy = gensym() - x_var_expr = - :($vx = Vector{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(x_vars...)])) - y_var_expr = - :($vy = Vector{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(y_vars...)])) - - # preparing equations - equations = map(macrohelper_clean, equations) - x_dict = gensym() - y_dict = gensym() - x_dict_create_expr = :( - $x_dict = Dict{ - StructuralIdentifiability.Nemo.QQMPolyRingElem, - Union{ - StructuralIdentifiability.Nemo.QQMPolyRingElem, - StructuralIdentifiability.AbstractAlgebra.Generic.Frac{ - StructuralIdentifiability.Nemo.QQMPolyRingElem, - }, - }, - }() - ) - y_dict_create_expr = :( - $y_dict = Dict{ - StructuralIdentifiability.Nemo.QQMPolyRingElem, - Union{ - StructuralIdentifiability.Nemo.QQMPolyRingElem, - StructuralIdentifiability.AbstractAlgebra.Generic.Frac{ - StructuralIdentifiability.Nemo.QQMPolyRingElem, - }, - }, - }() - ) - eqs_expr = [] - for eq in equations - if eq.head != :(=) - throw("Problem with parsing at $eq") - end - lhs, rhs = eq.args[1:2] - loc_all_symb = macrohelper_extract_vars([rhs])[4] - to_insert = undef - if lhs in x_vars - to_insert = x_dict - elseif lhs in y_vars - to_insert = y_dict - else - throw("Unknown left-hand side $lhs") - end - - uniqueness_check_expr = quote - if haskey($to_insert, $lhs) - throw( - DomainError( - $lhs, - "The variable occurs twice in the left-hand-side of the ODE system", - ), - ) - end - end - push!(eqs_expr, uniqueness_check_expr) - if isempty(loc_all_symb) - push!(eqs_expr, :($to_insert[$lhs] = $R($rhs))) - else - push!(eqs_expr, :($to_insert[$lhs] = ($rhs))) - end - end - - for n in all_symb_no_t - if !Base.isidentifier(n) - throw( - ArgumentError( - "The names of the variables will be injected into the global scope, so their name must be allowed Julia names, $n is not", - ), - ) - end - end - - logging_exprs = [ - :( - StructuralIdentifiability.Logging.with_logger( - StructuralIdentifiability._si_logger[], - ) do - @info "Summary of the model:" - @info "State variables: " * $(join(map(string, collect(x_vars)), ", ")) - @info "Parameters: " * $(join(map(string, collect(params)), ", ")) - @info "Inputs: " * $(join(map(string, collect(u_vars)), ", ")) - @info "Outputs: " * $(join(map(string, collect(y_vars)), ", ")) - end - ), - ] - # creating the ode object - ode_expr = - :(StructuralIdentifiability.ODE{StructuralIdentifiability.Nemo.QQMPolyRingElem}( - $vx, - $vy, - $x_dict, - $y_dict, - Array{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(u_vars...)]), - )) - - result = Expr( - :block, - logging_exprs..., - exp_ring, - assignments..., - x_var_expr, - y_var_expr, - x_dict_create_expr, - y_dict_create_expr, - eqs_expr..., - ode_expr, - ) - return esc(result) -end - -#------------------------------------------------------------------------------ - function Base.show(io::IO, ode::ODE) for x in ode.x_vars if endswith(var_to_str(x), "(t)") diff --git a/src/StructuralIdentifiability.jl b/src/StructuralIdentifiability.jl index 2793ce16..a18c9828 100644 --- a/src/StructuralIdentifiability.jl +++ b/src/StructuralIdentifiability.jl @@ -21,7 +21,7 @@ using ParamPunPam: reduce_mod_p!, specialize_mod_p, AbstractBlackboxIdeal ParamPunPam.enable_progressbar(false) # defining a model -export ODE, @ODEmodel, mtk_to_si +export ODE, @ODEmodel, @DDSmodel, mtk_to_si # assessing identifiability export assess_local_identifiability, assess_identifiability @@ -70,6 +70,7 @@ include("lincomp.jl") include("pb_representation.jl") include("submodels.jl") include("discrete.jl") +include("input_macro.jl") function __init__() _si_logger[] = @static if VERSION >= v"1.7.0" diff --git a/src/discrete.jl b/src/discrete.jl index 9d911793..aa3ea73c 100644 --- a/src/discrete.jl +++ b/src/discrete.jl @@ -1,8 +1,95 @@ +""" +The structue to represent a discrete dynamical system +with respect to *shift*. Internally just stores an ODE structur + +Can be constructed with @DDSmodel macro +""" +struct DDS{P} # P is the type of polynomials in the rhs of the DDS system + ode::ODE{P} + + function DDS{P}( + x_vars::Array{P, 1}, + y_vars::Array{P, 1}, + x_eqs::Dict{P, <:Union{P, Generic.Frac{P}}}, + y_eqs::Dict{P, <:Union{P, Generic.Frac{P}}}, + inputs::Array{P, 1}, + ) where {P <: MPolyRingElem{<:FieldElem}} + new{P}(ODE{P}(x_vars, y_vars, x_eqs, y_eqs, inputs)) + end + + function DDS{P}(ode::ODE{P}) where {P <: MPolyRingElem{<:FieldElem}} + new{P}(ode) + end +end + +#------------------------------------------------------------------------------ + +# getters + +function x_vars(dds::DDS) + return dds.ode.x_vars +end + +function y_vars(dds::DDS) + return dds.ode.y_vars +end + +function parameters(dds::DDS) + return dds.ode.parameters +end + +function inputs(dds::DDS) + return dds.ode.u_vars +end + +function x_equations(dds::DDS) + return dds.ode.x_equations +end + +function y_equations(dds::DDS) + return dds.ode.y_equations +end + +function Base.parent(dds::DDS) + return parent(dds.ode) +end + +#------------------------------------------------------------------------------ +# Some functions to transform DDS's + +function add_outputs( + dds::DDS{P}, + extra_y::Dict{String, <:RingElem}, +) where {P <: MPolyRingElem} + return DDS{P}(add_outputs(dds.ode, extra_y)) +end + +#------------------------------------------------------------------------------ + +function Base.show(io::IO, dds::DDS) + for x in x_vars(dds) + if endswith(var_to_str(x), "(t)") + print(io, var_to_str(x)[1:(end - 3)] * "(t + 1) = ") + else + print(io, var_to_str(x) * "(t + 1) = ") + end + print(io, x_equations(dds)[x]) + print(io, "\n") + end + for y in y_vars(dds) + print(io, var_to_str(y) * " = ") + print(io, y_equations(dds)[y]) + print(io, "\n") + end +end + +#------------------------------------------------------------------------------ + """ sequence_solution(dds, param_values, initial_conditions, input_values, num_terms) Input: -- `dds` - a discrete dynamical system to solve (represented as an ODE struct) +- `dds` - a discrete dynamical system to solve - `param_values` - parameter values, must be a dictionary mapping parameter to a value - `initial_conditions` - initial conditions of `ode`, must be a dictionary mapping state variable to a value - `input_values` - input sequences in the form input => list of terms; length of the lists must be at least @@ -13,21 +100,21 @@ Output: - computes a sequence solution with teh required number of terms prec presented as a dictionary state_variable => corresponding sequence """ function sequence_solution( - dds::ODE{P}, + dds::DDS{P}, param_values::Dict{P, T}, initial_conditions::Dict{P, T}, input_values::Dict{P, Array{T, 1}}, num_terms::Int, ) where {T <: FieldElem, P <: MPolyRingElem{T}} - result = Dict(x => [initial_conditions[x]] for x in dds.x_vars) + result = Dict(x => [initial_conditions[x]] for x in x_vars(dds)) for i in 2:num_terms eval_dict = merge( param_values, Dict(k => v[end] for (k, v) in result), Dict(u => val[i - 1] for (u, val) in input_values), ) - for x in dds.x_vars - push!(result[x], eval_at_dict(dds.x_equations[x], eval_dict)) + for x in x_vars(dds) + push!(result[x], eval_at_dict(x_equations(dds)[x], eval_dict)) end end return result @@ -36,13 +123,13 @@ end #------------------------------------------------------------------------------ function sequence_solution( - dds::ODE{P}, + dds::DDS{P}, param_values::Dict{P, Int}, initial_conditions::Dict{P, Int}, input_values::Dict{P, Array{Int, 1}}, num_terms::Int, ) where {P <: MPolyRingElem{<:FieldElem}} - bring = base_ring(dds.poly_ring) + bring = base_ring(parent(dds)) return sequence_solution( dds, Dict(p => bring(v) for (p, v) in param_values), @@ -66,7 +153,7 @@ Output: the function `u` w.r.t. `v` evaluated at the solution """ function differentiate_sequence_solution( - dds::ODE{P}, + dds::DDS{P}, params::Dict{P, T}, ic::Dict{P, T}, inputs::Dict{P, Array{T, 1}}, @@ -74,16 +161,16 @@ function differentiate_sequence_solution( ) where {T <: Generic.FieldElem, P <: MPolyRingElem{T}} @debug "Computing the power series solution of the system" seq_sol = sequence_solution(dds, params, ic, inputs, num_terms) - generalized_params = vcat(dds.x_vars, dds.parameters) - bring = base_ring(dds.poly_ring) + generalized_params = vcat(x_vars(dds), parameters(dds)) + bring = base_ring(parent(dds)) @debug "Solving the variational system at the solution" part_diffs = Dict( - (x, p) => derivative(dds.x_equations[x], p) for x in dds.x_vars for + (x, p) => derivative(x_equations(dds)[x], p) for x in x_vars(dds) for p in generalized_params ) result = Dict( - (x, p) => [x == p ? one(bring) : zero(bring)] for x in dds.x_vars for + (x, p) => [x == p ? one(bring) : zero(bring)] for x in x_vars(dds) for p in generalized_params ) for i in 2:num_terms @@ -93,13 +180,13 @@ function differentiate_sequence_solution( Dict(u => val[i - 1] for (u, val) in inputs), ) for p in generalized_params - local_eval = Dict(x => result[(x, p)][end] for x in dds.x_vars) - for x in dds.x_vars + local_eval = Dict(x => result[(x, p)][end] for x in x_vars(dds)) + for x in x_vars(dds) res = sum([ eval_at_dict(part_diffs[(x, x2)], eval_dict) * local_eval[x2] for - x2 in dds.x_vars + x2 in x_vars(dds) ]) - if p in dds.parameters + if p in parameters(dds) res += eval_at_dict(part_diffs[(x, p)], eval_dict) end push!(result[(x, p)], res) @@ -120,7 +207,7 @@ returns a dictionary of the form `y_function => Dict(var => dy/dvar)` where `dy/ of `y_function` with respect to `var`. """ function differentiate_sequence_output( - dds::ODE{P}, + dds::DDS{P}, params::Dict{P, T}, ic::Dict{P, T}, inputs::Dict{P, Array{T, 1}}, @@ -130,13 +217,13 @@ function differentiate_sequence_output( seq_sol, sol_diff = differentiate_sequence_solution(dds, params, ic, inputs, num_terms) @debug "Evaluating the partial derivatives of the outputs" - generalized_params = vcat(dds.x_vars, dds.parameters) + generalized_params = vcat(x_vars(dds), parameters(dds)) part_diffs = Dict( - (y, p) => derivative(dds.y_equations[y], p) for y in dds.y_vars for + (y, p) => derivative(y_equations(dds)[y], p) for y in y_vars(dds) for p in generalized_params ) - result = Dict((y, p) => [] for y in dds.y_vars for p in generalized_params) + result = Dict((y, p) => [] for y in y_vars(dds) for p in generalized_params) for i in 1:num_terms eval_dict = merge( params, @@ -144,14 +231,14 @@ function differentiate_sequence_output( Dict(u => val[i] for (u, val) in inputs), ) - for p in vcat(dds.x_vars, dds.parameters) - local_eval = Dict(x => sol_diff[(x, p)][i] for x in dds.x_vars) - for (y, y_eq) in dds.y_equations + for p in generalized_params + local_eval = Dict(x => sol_diff[(x, p)][i] for x in x_vars(dds)) + for (y, y_eq) in y_equations(dds) res = sum([ eval_at_dict(part_diffs[(y, x)], eval_dict) * local_eval[x] for - x in dds.x_vars + x in x_vars(dds) ]) - if p in dds.parameters + if p in parameters(dds) res += eval_at_dict(part_diffs[(y, p)], eval_dict) end push!(result[(y, p)], res) @@ -178,10 +265,9 @@ function _degree_with_common_denom(polys) end """ - _assess_local_identifiability_discrete_aux(dds::ODE{P}, funcs_to_check::Array{<: Any, 1}, known_ic, prob_threshold::Float64=0.99) where P <: MPolyRingElem{Nemo.QQFieldElem} + _assess_local_identifiability_discrete_aux(dds::DDS{P}, funcs_to_check::Array{<: Any, 1}, known_ic, prob_threshold::Float64=0.99) where P <: MPolyRingElem{Nemo.QQFieldElem} -Checks the local identifiability/observability of the functions in `funcs_to_check` treating `dds` as a discrete-time system with **shift** -instead of derivative in the right-hand side. +Checks the local identifiability/observability of the functions in `funcs_to_check`. The result is correct with probability at least `prob_threshold`. `known_ic` can take one of the following * `:none` - no initial conditions are assumed to be known @@ -189,12 +275,12 @@ The result is correct with probability at least `prob_threshold`. * a list of rational functions in states and parameters assumed to be known at t = 0 """ function _assess_local_identifiability_discrete_aux( - dds::ODE{P}, + dds::DDS{P}, funcs_to_check::Array{<:Any, 1}, known_ic = :none, prob_threshold::Float64 = 0.99, ) where {P <: MPolyRingElem{Nemo.QQFieldElem}} - bring = base_ring(dds.poly_ring) + bring = base_ring(parent(dds)) @debug "Extending the model" dds_ext = @@ -204,16 +290,16 @@ function _assess_local_identifiability_discrete_aux( known_ic = [] end if known_ic == :all - known_ic = dds_ext.x_vars + known_ic = x_vars(dds_ext) end @debug "Computing the observability matrix" - prec = length(dds.x_vars) + length(dds.parameters) + prec = length(x_vars(dds)) + length(x_vars(dds)) @debug "The truncation order is $prec" # Computing the bound from the Schwartz-Zippel-DeMilo-Lipton lemma - deg_x = _degree_with_common_denom(values(dds.x_equations)) - deg_y = _degree_with_common_denom(values(dds.y_equations)) + deg_x = _degree_with_common_denom(values(x_equations(dds))) + deg_y = _degree_with_common_denom(values(y_equations(dds))) deg_known = reduce(+, map(total_degree, known_ic), init = 0) deg_to_check = max(map(total_degree, funcs_to_check)...) Jac_degree = deg_to_check + deg_known @@ -226,11 +312,12 @@ function _assess_local_identifiability_discrete_aux( @debug "Sampling range $D" # Parameter values are the same across all the replicas - params_vals = Dict(p => bring(rand(1:D)) for p in dds_ext.parameters) - ic = Dict(x => bring(rand(1:D)) for x in dds_ext.x_vars) + params_vals = Dict(p => bring(rand(1:D)) for p in parameters(dds_ext)) + ic = Dict(x => bring(rand(1:D)) for x in x_vars(dds_ext)) # TODO: parametric type instead of QQFieldElem inputs = Dict{P, Array{QQFieldElem, 1}}( - u => [bring(rand(1:D)) for i in 1:prec] for u in dds_ext.u_vars + u => [bring(rand(1:D)) for i in 1:prec] for + u in StructuralIdentifiability.inputs(dds_ext) ) @debug "Computing the output derivatives" @@ -241,28 +328,28 @@ function _assess_local_identifiability_discrete_aux( Jac = zero( Nemo.matrix_space( bring, - length(dds.x_vars) + length(dds.parameters), - 1 + prec * length(dds.y_vars) + length(known_ic), + length(x_vars(dds)) + length(parameters(dds)), + 1 + prec * length(y_vars(dds)) + length(known_ic), ), ) - xs_params = vcat(dds_ext.x_vars, dds_ext.parameters) - for (i, y) in enumerate(dds.y_vars) - y = switch_ring(y, dds_ext.poly_ring) + xs_params = vcat(x_vars(dds_ext), parameters(dds_ext)) + for (i, y) in enumerate(y_vars(dds)) + y = switch_ring(y, parent(dds_ext)) for j in 1:prec - for (k, p) in enumerate(dds_ext.parameters) + for (k, p) in enumerate(parameters(dds_ext)) Jac[k, 1 + (i - 1) * prec + j] = output_derivatives[(y, p)][j] end - for (k, x) in enumerate(dds_ext.x_vars) + for (k, x) in enumerate(x_vars(dds_ext)) Jac[end - k + 1, 1 + (i - 1) * prec + j] = output_derivatives[(y, x)][j] end end end eval_point = merge(params_vals, ic) for (i, v) in enumerate(known_ic) - for (k, p) in enumerate(dds_ext.parameters) + for (k, p) in enumerate(parameters(dds_ext)) Jac[k, end - i + 1] = eval_at_dict(derivative(v, p), eval_point) end - for (k, x) in enumerate(dds_ext.x_vars) + for (k, x) in enumerate(x_vars(dds_ext)) Jac[end - k + 1, end - i + 1] = eval_at_dict(derivative(v, x), eval_point) end end @@ -271,13 +358,13 @@ function _assess_local_identifiability_discrete_aux( base_rank = LinearAlgebra.rank(Jac) result = OrderedDict{Any, Bool}() for i in 1:length(funcs_to_check) - for (k, p) in enumerate(dds_ext.parameters) + for (k, p) in enumerate(parameters(dds_ext)) Jac[k, 1] = - output_derivatives[(str_to_var("loc_aux_$i", dds_ext.poly_ring), p)][1] + output_derivatives[(str_to_var("loc_aux_$i", parent(dds_ext)), p)][1] end - for (k, x) in enumerate(dds_ext.x_vars) + for (k, x) in enumerate(x_vars(dds_ext)) Jac[end - k + 1, 1] = - output_derivatives[(str_to_var("loc_aux_$i", dds_ext.poly_ring), x)][1] + output_derivatives[(str_to_var("loc_aux_$i", parent(dds_ext)), x)][1] end result[funcs_to_check[i]] = LinearAlgebra.rank(Jac) == base_rank end diff --git a/src/input_macro.jl b/src/input_macro.jl new file mode 100644 index 00000000..2a2cf16f --- /dev/null +++ b/src/input_macro.jl @@ -0,0 +1,309 @@ +function _extract_aux!(funcs, all_symb, eq; ders_ok = false, type = :ode) + aux_symb = Set([:(+), :(-), :(=), :(*), :(^), :t, :(/), :(//)]) + MacroTools.postwalk( + x -> begin + if @capture(x, f_'(t)) + if !ders_ok + throw( + Base.ArgumentError( + "Derivative are not allowed in the right-hand side", + ), + ) + end + if type != :ode + throw( + Base.ArgumentError( + "Derivative are not expected in the discrete case", + ), + ) + end + push!(all_symb, f) + elseif @capture(x, f_(t + 1)) + if !ders_ok + throw(Base.ArgumentError("Shifts are not allowed in the right-hand side")) + end + if type != :dds + throw( + Base.ArgumentError( + "Shifts are not expected in the differential case", + ), + ) + end + push!(all_symb, f) + elseif @capture(x, f_(t)) + push!(funcs, f) + elseif (x isa Symbol) && !(x in aux_symb) + push!(all_symb, x) + end + return x + end, + eq, + ) +end + +""" + For an expression of the form f'(t)/f(t + 1) or f(t) returns (f, true) and (f, false), resp +""" +function _get_var(expr, type = :ode) + if @capture(expr, f_'(t)) + @assert type == :ode + return (f, true) + end + if @capture(expr, f_(t + 1)) + @assert type == :dds + return (f, true) + end + if @capture(expr, f_(t)) + return (f, false) + end + error("cannot extract the single function name from $expr") +end + +function macrohelper_extract_vars(equations::Array{Expr, 1}, type = :ode) + funcs, all_symb = Set(), Set() + x_vars, y_vars = Vector(), Vector() + aux_symb = Set([:(+), :(-), :(=), :(*), :(^), :t, :(/), :(//)]) + for eq in equations + if eq.head != :(=) + _extract_aux!(funcs, all_symb, eq, type = type) + else + lhs, rhs = eq.args[1:2] + _extract_aux!(funcs, all_symb, lhs, ders_ok = true, type = type) + _extract_aux!(funcs, all_symb, rhs, type = type) + (v, is_state) = _get_var(lhs, type) + if is_state + push!(x_vars, v) + else + push!(y_vars, v) + end + end + end + u_vars = setdiff(funcs, vcat(x_vars, y_vars)) + all_symb = collect(all_symb) + return x_vars, y_vars, collect(u_vars), collect(all_symb) +end + +function macrohelper_extract_vars(equations::Array{Symbol, 1}, type = :ode) + return macrohelper_extract_vars(map(Expr, equations), type) +end + +#------------------------------------------------------------------------------ + +function macrohelper_clean(ex::Expr) + ex = MacroTools.postwalk(x -> @capture(x, f_'(t)) ? f : x, ex) + ex = MacroTools.postwalk(x -> @capture(x, f_(t + 1)) ? f : x, ex) + ex = MacroTools.postwalk(x -> @capture(x, f_(t)) ? f : x, ex) + ex = MacroTools.postwalk(x -> x == :(/) ? :(//) : x, ex) + ex = MacroTools.postwalk(x -> x isa Float64 ? rationalize(x) : x, ex) + return ex +end + +#------------------------------------------------------------------------------ + +function generate_model_code(type, ex::Expr...) + @assert type in (:ode, :dds) + equations = [ex...] + x_vars, y_vars, u_vars, all_symb = macrohelper_extract_vars(equations, type) + time_dependent = vcat(x_vars, y_vars, u_vars) + params = sort([s for s in all_symb if !(s in time_dependent)]) + all_symb_no_t = vcat(time_dependent, params) + all_symb_with_t = vcat([:($s(t)) for s in time_dependent], params) + + # creating the polynomial ring + vars_list = :([$(all_symb_with_t...)]) + R = gensym() + vars_aux = gensym() + exp_ring = :( + ($R, $vars_aux) = StructuralIdentifiability.Nemo.polynomial_ring( + StructuralIdentifiability.Nemo.QQ, + map(string, $all_symb_with_t), + ) + ) + assignments = [:($(all_symb_no_t[i]) = $vars_aux[$i]) for i in 1:length(all_symb_no_t)] + + # setting x_vars and y_vars in the right order + vx = gensym() + vy = gensym() + x_var_expr = + :($vx = Vector{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(x_vars...)])) + y_var_expr = + :($vy = Vector{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(y_vars...)])) + + # preparing equations + equations = map(macrohelper_clean, equations) + x_dict = gensym() + y_dict = gensym() + x_dict_create_expr = :( + $x_dict = Dict{ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + Union{ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + StructuralIdentifiability.AbstractAlgebra.Generic.Frac{ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + }, + }, + }() + ) + y_dict_create_expr = :( + $y_dict = Dict{ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + Union{ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + StructuralIdentifiability.AbstractAlgebra.Generic.Frac{ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + }, + }, + }() + ) + eqs_expr = [] + for eq in equations + if eq.head != :(=) + throw("Problem with parsing at $eq") + end + lhs, rhs = eq.args[1:2] + loc_all_symb = macrohelper_extract_vars([rhs], type)[4] + to_insert = undef + if lhs in x_vars + to_insert = x_dict + elseif lhs in y_vars + to_insert = y_dict + else + throw("Unknown left-hand side $lhs") + end + + uniqueness_check_expr = quote + if haskey($to_insert, $lhs) + throw( + DomainError( + $lhs, + "The variable occurs twice in the left-hand-side of the ODE system", + ), + ) + end + end + push!(eqs_expr, uniqueness_check_expr) + if isempty(loc_all_symb) + push!(eqs_expr, :($to_insert[$lhs] = $R($rhs))) + else + push!(eqs_expr, :($to_insert[$lhs] = ($rhs))) + end + end + + for n in all_symb_no_t + if !Base.isidentifier(n) + throw( + ArgumentError( + "The names of the variables will be injected into the global scope, so their name must be allowed Julia names, $n is not", + ), + ) + end + end + + logging_exprs = [ + :( + StructuralIdentifiability.Logging.with_logger( + StructuralIdentifiability._si_logger[], + ) do + @info "Summary of the model:" + @info "State variables: " * $(join(map(string, collect(x_vars)), ", ")) + @info "Parameters: " * $(join(map(string, collect(params)), ", ")) + @info "Inputs: " * $(join(map(string, collect(u_vars)), ", ")) + @info "Outputs: " * $(join(map(string, collect(y_vars)), ", ")) + end + ), + ] + # creating the ode/dds object + obj_type = Dict(:ode => :ODE, :dds => :DDS) + ds_expr = :(StructuralIdentifiability.$(obj_type[type]){ + StructuralIdentifiability.Nemo.QQMPolyRingElem, + }( + $vx, + $vy, + $x_dict, + $y_dict, + Array{StructuralIdentifiability.Nemo.QQMPolyRingElem}([$(u_vars...)]), + )) + + result = Expr( + :block, + logging_exprs..., + exp_ring, + assignments..., + x_var_expr, + y_var_expr, + x_dict_create_expr, + y_dict_create_expr, + eqs_expr..., + ds_expr, + ) + return result +end + +#------------------------------------------------------------------------------ + +""" + macro ODEmodel + +Macro for creating an ODE from a list of equations. +It also injects all variables into the global scope. + +## Example + +Creating a simple `ODE`: + +```jldoctest +using StructuralIdentifiability + +ode = @ODEmodel( + x1'(t) = a * x1(t) + u(t), + x2'(t) = b * x2(t) + c*x1(t)*x2(t), + y(t) = x1(t) +) +``` + +Here, +- `x1`, `x2` are state variables +- `y` is an output variable +- `u` is an input variable +- `a`, `b`, `c` are time-indepdendent parameters + +""" +macro ODEmodel(ex::Expr...) + return esc(generate_model_code(:ode, ex...)) +end + +#------------------------------------------------------------------------------ + +""" + macro DDSmodel + +Macro for creating a DDS (discrete dynamical system) +from a list of equations. +It also injects all variables into the global scope. + +## Example + +Creating a simple `DDS`: + +```jldoctest +using StructuralIdentifiability + +ode = @ODEmodel( + x1(t + 1) = a * x1(t) + u(t), + x2(t + 1) = b * x2(t) + c*x1(t)*x2(t), + y(t) = x1(t) +) +``` + +Here, +- `x1`, `x2` are state variables +- `y` is an output variable +- `u` is an input variable +- `a`, `b`, `c` are time-indepdendent parameters + +""" +macro DDSmodel(ex::Expr...) + return esc(generate_model_code(:dds, ex...)) +end + +#------------------------------------------------------------------------------ diff --git a/test/extensions/modelingtoolkit.jl b/test/extensions/modelingtoolkit.jl index bfafa363..c6299879 100644 --- a/test/extensions/modelingtoolkit.jl +++ b/test/extensions/modelingtoolkit.jl @@ -437,14 +437,14 @@ correct end - @testset "Discrete local identifiability, internal function" begin + @testset "Discrete local identifiability, ModelingToolkit interface" begin cases = [] @parameters α β @variables t S(t) I(t) R(t) y(t) D = Difference(t; dt = 1.0) - eqs = [D(S) ~ S - β * S * I, D(I) ~ I + β * S * I - α * I, D(R) ~ R + α * I] + eqs = [D(S) ~ -β * S * I, D(I) ~ β * S * I - α * I, D(R) ~ α * I] @named sir = DiscreteSystem(eqs) push!( cases, @@ -463,7 +463,7 @@ @variables t x(t) y(t) D = Difference(t; dt = 1.0) - eqs = [D(x) ~ θ * x^3] + eqs = [D(x) ~ θ * x^3 - x] @named eqs = DiscreteSystem(eqs) push!( diff --git a/test/local_identifiability_discrete_aux.jl b/test/local_identifiability_discrete_aux.jl index 11981614..0c41ad9d 100644 --- a/test/local_identifiability_discrete_aux.jl +++ b/test/local_identifiability_discrete_aux.jl @@ -2,7 +2,7 @@ if GROUP == "All" || GROUP == "Core" @testset "Discrete local identifiability, internal function" begin cases = [] - dds = @ODEmodel(a'(t) = (b + c) * a(t) + 1, y(t) = a(t)) + dds = @DDSmodel(a(t + 1) = (b + c) * a(t) + 1, y(t) = a(t)) push!( cases, @@ -24,7 +24,7 @@ if GROUP == "All" || GROUP == "Core" #--------------------- - dds = @ODEmodel(a'(t) = b(t) * a(t) + c, b'(t) = d * a(t), y(t) = b(t)) + dds = @DDSmodel(a(t + 1) = b(t) * a(t) + c, b(t + 1) = d * a(t), y(t) = b(t)) push!( cases, @@ -63,7 +63,7 @@ if GROUP == "All" || GROUP == "Core" # ------------------- # Example 4 from https://doi.org/10.1016/j.automatica.2016.01.054 - dds = @ODEmodel(x'(t) = theta^3 * x(t), y(t) = x(t)) + dds = @DDSmodel(x(t + 1) = theta^3 * x(t), y(t) = x(t)) push!( cases,