Skip to content

Commit

Permalink
Merge pull request #262 from SciML/renaming_p
Browse files Browse the repository at this point in the history
renaming `p` to `prob_threschold`
  • Loading branch information
pogudingleb authored Jan 8, 2024
2 parents d97a160 + 8c30f65 commit a763206
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 107 deletions.
4 changes: 2 additions & 2 deletions docs/src/tutorials/discrete_time.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ The `assess_local_identifiability` function has three important keyword argument
assess_local_identifiability(sir; measured_quantities = [I], funcs_to_check = [β * S])
```

- `p` is the probability of correctness (default value `0.99`, i.e., 99%). The underlying algorithm is a Monte-Carlo algorithm, so in
principle it may produce incorrect result but the probability of correctness of the returned result is guaranteed to be at least `p`
- `prob_threshold` is the probability of correctness (default value `0.99`, i.e., 99%). The underlying algorithm is a Monte-Carlo algorithm, so in
principle it may produce incorrect result but the probability of correctness of the returned result is guaranteed to be at least `prob_threshold`
(in fact, the employed bounds are quite conservative, so in practice incorrect result is almost never produced).

- `known_ic` is a list of the states for which initial conditions are known. In this case, the identifiability results will be valid not
Expand Down
6 changes: 3 additions & 3 deletions docs/src/tutorials/identifiability.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ Function `assess_local_identifiability` has several optional parameters
- `funcs_to_check` a list of specific functions of parameters and states to check identifiability for (see an example below).
If not provided, the identifiability is assessed for all parameters and states.

- `p` (default $0.99$) is the probability of correctness. The algorithm can, in theory, produce wrong result, but the probability that it is correct
is guaranteed to be at least `p`. However, the probability bounds we use are quite conservative, so the actual probability of correctness is
- `prob_threshold` (default $0.99$, i.e. 99%) is the probability of correctness. The algorithm can, in theory, produce wrong result, but the probability that it is correct
is guaranteed to be at least `prob_threshold`. However, the probability bounds we use are quite conservative, so the actual probability of correctness is
likely to be much higher.
- `type` (default `:SE`). By default, the algorithm checks the standard single-experiment identifiability. If one sets `type = :ME`, then the algorithm
checks multi-experiment identifiability, that is, identifiability from several experiments with independent initial conditions (the algorithm from [^2] is used).
Expand Down Expand Up @@ -105,7 +105,7 @@ Similarly to `assess_local_identifiability`, this function has optional paramete
more involved than for the parameters, so one may want to call the function with `funcs_to_check = ode.parameters` if the
call `assess_identifiability(ode)` takes too long.

- `p` (default $0.99$) is the probability of correctness. Same story as above: the probability estimates are very conservative, so the actual
- `prob_threshold` (default $0.99$, i.e. 99%) is the probability of correctness. Same story as above: the probability estimates are very conservative, so the actual
error probability is much lower than 1%.
Also, currently, the probability of correctness does not include the probability of correctness of the modular reconstruction for Groebner bases.
This probability is ensured by an additional check modulo a large prime, and can be neglected for practical purposes.
Expand Down
55 changes: 34 additions & 21 deletions src/RationalFunctionFields/RationalFunctionField.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,17 @@ end
# ------------------------------------------------------------------------------

"""
field_contains(field, ratfuncs, p)
field_contains(field, ratfuncs, prob_threshold)
Checks whether given rational function field `field` contains given rational
functions `ratfuncs` (represented as a list of lists). The result is correct with
probability at least `p`
probability at least `prob_threshold`
Inputs:
- `field` - a rational function field
- `ratfuncs` - a list of lists of polynomials. Each of the lists, say, `[f1, ..., fn]`,
defines generators `f2/f1, ..., fn/f1`.
- `p` real number from (0, 1)
- `prob_threshold` real number from (0, 1)
Output:
- a list `L[i]` of bools of length `length(rat_funcs)` such that `L[i]` is true iff
Expand All @@ -95,7 +95,7 @@ Output:
@timeit _to function field_contains(
field::RationalFunctionField{T},
ratfuncs::Vector{Vector{T}},
p,
prob_threshold,
) where {T}
if isempty(ratfuncs)
return Bool[]
Expand Down Expand Up @@ -133,7 +133,7 @@ Output:
3 *
BigInt(degree)^(length(total_vars) + 3) *
(length(ratfuncs)) *
ceil(1 / (1 - p)),
ceil(1 / (1 - prob_threshold)),
)
@debug "\tSampling from $(-sampling_bound) to $(sampling_bound)"

Expand All @@ -153,24 +153,36 @@ end
function field_contains(
field::RationalFunctionField{T},
ratfuncs::Vector{Generic.Frac{T}},
p,
prob_threshold,
) where {T}
return field_contains(field, fractions_to_dennums(ratfuncs), p)
return field_contains(field, fractions_to_dennums(ratfuncs), prob_threshold)
end

function field_contains(field::RationalFunctionField{T}, polys::Vector{T}, p) where {T}
function field_contains(
field::RationalFunctionField{T},
polys::Vector{T},
prob_threshold,
) where {T}
id = one(parent(first(polys)))
return field_contains(field, [[id, p] for p in polys], p)
return field_contains(field, [[id, p] for p in polys], prob_threshold)
end

# ------------------------------------------------------------------------------

function issubfield(F::RationalFunctionField{T}, E::RationalFunctionField{T}, p) where {T}
return all(field_contains(E, F.dennums, p))
function issubfield(
F::RationalFunctionField{T},
E::RationalFunctionField{T},
prob_threshold,
) where {T}
return all(field_contains(E, F.dennums, prob_threshold))
end

function fields_equal(F::RationalFunctionField{T}, E::RationalFunctionField{T}, p) where {T}
new_p = 1 - (1 - p) / 2
function fields_equal(
F::RationalFunctionField{T},
E::RationalFunctionField{T},
prob_threshold,
) where {T}
new_p = 1 - (1 - prob_threshold) / 2
return issubfield(F, E, new_p) && issubfield(E, F, new_p)
end

Expand Down Expand Up @@ -475,14 +487,14 @@ function monomial_generators_up_to_degree(
end

"""
simplified_generating_set(rff; p = 0.99, seed = 42)
simplified_generating_set(rff; prob_threshold = 0.99, seed = 42)
Returns a simplified set of generators for `rff`.
Result is correct (in Monte-Carlo sense) with probability at least `p`.
Result is correct (in the Monte-Carlo sense) with probability at least `prob_threshold`.
"""
@timeit _to function simplified_generating_set(
rff::RationalFunctionField;
p = 0.99,
prob_threshold = 0.99,
seed = 42,
simplify = :standard,
check_variables = false, # almost always slows down and thus turned off
Expand All @@ -500,8 +512,8 @@ Result is correct (in Monte-Carlo sense) with probability at least `p`.
# Checking membership of particular variables and adding them to the field
if check_variables
vars = gens(poly_ring(rff))
containment = field_contains(rff, vars, (1.0 + p) / 2)
p = (1.0 + p) / 2
containment = field_contains(rff, vars, (1.0 + prob_threshold) / 2)
prob_threshold = (1.0 + prob_threshold) / 2
if all(containment)
return [v // one(poly_ring(rff)) for v in vars]
end
Expand Down Expand Up @@ -557,14 +569,15 @@ Final cleaning and simplification of generators.
Out of $(length(new_fracs)) fractions $(length(new_fracs_unique)) are syntactically unique."""
runtime =
@elapsed new_fracs = beautifuly_generators(RationalFunctionField(new_fracs_unique))
@debug "Checking inclusion with probability $p"
runtime = @elapsed result = issubfield(rff, RationalFunctionField(new_fracs), p)
@debug "Checking inclusion with probability $prob_threshold"
runtime =
@elapsed result = issubfield(rff, RationalFunctionField(new_fracs), prob_threshold)
_runtime_logger[:id_inclusion_check] = runtime
if !result
@warn "Field membership check failed. Error will follow."
throw("The new subfield generators are not correct.")
end
@info "Inclusion checked with probability $p in $(_runtime_logger[:id_inclusion_check]) seconds"
@info "Inclusion checked with probability $prob_threshold in $(_runtime_logger[:id_inclusion_check]) seconds"
@debug "Out of $(length(rff.mqs.nums_qq)) initial generators there are $(length(new_fracs)) indepdendent"
ranking = generating_set_rank(new_fracs)
_runtime_logger[:id_ranking] = ranking
Expand Down
40 changes: 24 additions & 16 deletions src/StructuralIdentifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,40 +82,44 @@ function __init__()
end

"""
assess_identifiability(ode; funcs_to_check = [], p=0.99, loglevel=Logging.Info)
assess_identifiability(ode; funcs_to_check = [], prob_threshold=0.99, loglevel=Logging.Info)
Input:
- `ode` - the ODE model
- `funcs_to_check` - list of functions to check identifiability for; if empty, all parameters
and states are taken
- `p` - probability of correctness.
- `prob_threshold` - probability of correctness.
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)
Assesses identifiability of a given ODE model. The result is guaranteed to be correct with the probability
at least `p`.
at least `prob_threshold`.
The function returns an (ordered) dictionary from the functions to check to their identifiability properties
(one of `:nonidentifiable`, `:locally`, `:globally`).
"""
function assess_identifiability(
ode::ODE{P};
funcs_to_check = Vector(),
p::Float64 = 0.99,
prob_threshold::Float64 = 0.99,
loglevel = Logging.Info,
) where {P <: MPolyElem{fmpq}}
restart_logging(loglevel = loglevel)
reset_timings()
with_logger(_si_logger[]) do
return _assess_identifiability(ode, funcs_to_check = funcs_to_check, p = p)
return _assess_identifiability(
ode,
funcs_to_check = funcs_to_check,
prob_threshold = prob_threshold,
)
end
end

function _assess_identifiability(
ode::ODE{P};
funcs_to_check = Vector(),
p::Float64 = 0.99,
prob_threshold::Float64 = 0.99,
) where {P <: MPolyElem{fmpq}}
p_glob = 1 - (1 - p) * 0.9
p_loc = 1 - (1 - p) * 0.1
p_glob = 1 - (1 - prob_threshold) * 0.9
p_loc = 1 - (1 - prob_threshold) * 0.1

if isempty(funcs_to_check)
funcs_to_check = vcat(ode.x_vars, ode.parameters)
Expand All @@ -126,7 +130,7 @@ function _assess_identifiability(
runtime = @elapsed local_result = _assess_local_identifiability(
ode,
funcs_to_check = funcs_to_check,
p = p_loc,
prob_threshold = p_loc,
type = :SE,
trbasis = trbasis,
)
Expand Down Expand Up @@ -167,23 +171,23 @@ function _assess_identifiability(
end

"""
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], p = 0.99, loglevel=Logging.Info)
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], prob_threshold = 0.99, loglevel=Logging.Info)
Input:
- `ode` - the ModelingToolkit.ODESystem object that defines the model
- `measured_quantities` - the output functions of the model
- `funcs_to_check` - functions of parameters for which to check the identifiability
- `p` - probability of correctness.
- `prob_threshold` - probability of correctness.
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)
Assesses identifiability (both local and global) of a given ODE model (parameters detected automatically). The result is guaranteed to be correct with the probability
at least `p`.
at least `prob_threshold`.
"""
function assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = [],
p = 0.99,
prob_threshold = 0.99,
loglevel = Logging.Info,
)
restart_logging(loglevel = loglevel)
Expand All @@ -192,7 +196,7 @@ function assess_identifiability(
ode,
measured_quantities = measured_quantities,
funcs_to_check = funcs_to_check,
p = p,
prob_threshold = prob_threshold,
)
end
end
Expand All @@ -201,7 +205,7 @@ function _assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = [],
p = 0.99,
prob_threshold = 0.99,
)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(ode)
Expand All @@ -214,7 +218,11 @@ function _assess_identifiability(
end
funcs_to_check_ = [eval_at_nemo(each, conversion) for each in funcs_to_check]

result = _assess_identifiability(ode, funcs_to_check = funcs_to_check_, p = p)
result = _assess_identifiability(
ode,
funcs_to_check = funcs_to_check_,
prob_threshold = prob_threshold,
)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return out_dict
Expand Down
28 changes: 16 additions & 12 deletions src/discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@ function _degree_with_common_denom(polys)
end

"""
_assess_local_identifiability_discrete_aux(dds::ODE{P}, funcs_to_check::Array{<: Any, 1}, known_ic, p::Float64=0.99) where P <: MPolyElem{Nemo.fmpq}
_assess_local_identifiability_discrete_aux(dds::ODE{P}, funcs_to_check::Array{<: Any, 1}, known_ic, prob_threshold::Float64=0.99) where P <: MPolyElem{Nemo.fmpq}
Checks the local identifiability/observability of the functions in `funcs_to_check` treating `dds` as a discrete-time system with **shift**
instead of derivative in the right-hand side.
The result is correct with probability at least `p`.
The result is correct with probability at least `prob_threshold`.
`known_ic` can take one of the following
* `:none` - no initial conditions are assumed to be known
* `:all` - all initial conditions are assumed to be known
Expand All @@ -192,7 +192,7 @@ function _assess_local_identifiability_discrete_aux(
dds::ODE{P},
funcs_to_check::Array{<:Any, 1},
known_ic = :none,
p::Float64 = 0.99,
prob_threshold::Float64 = 0.99,
) where {P <: MPolyElem{Nemo.fmpq}}
bring = base_ring(dds.poly_ring)

Expand Down Expand Up @@ -222,7 +222,7 @@ function _assess_local_identifiability_discrete_aux(
else
Jac_degree += 2 * deg_y * prec
end
D = Int(ceil(Jac_degree * length(funcs_to_check) / (1 - p)))
D = Int(ceil(Jac_degree * length(funcs_to_check) / (1 - prob_threshold)))
@debug "Sampling range $D"

# Parameter values are the same across all the replicas
Expand Down Expand Up @@ -293,26 +293,26 @@ end
measured_quantities=Array{ModelingToolkit.Equation}[],
funcs_to_check=Array{}[],
known_ic=Array{}[],
p::Float64=0.99)
prob_threshold::Float64=0.99)
Input:
- `dds` - the DiscreteSystem object from ModelingToolkit (with **difference** operator in the right-hand side)
- `measured_quantities` - the measurable outputs of the model
- `funcs_to_check` - functions of parameters for which to check identifiability (all parameters and states if not specified)
- `known_ic` - functions (of states and parameter) whose initial conditions are assumed to be known
- `p` - probability of correctness
- `prob_threshold` - probability of correctness
Output:
- the result is an (ordered) dictionary from each function to to boolean;
The result is correct with probability at least `p`.
The result is correct with probability at least `prob_threshold`.
"""
function assess_local_identifiability(
dds::ModelingToolkit.DiscreteSystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = Array{}[],
known_ic = Array{}[],
p::Float64 = 0.99,
prob_threshold::Float64 = 0.99,
loglevel = Logging.Info,
)
restart_logging(loglevel = loglevel)
Expand All @@ -322,7 +322,7 @@ function assess_local_identifiability(
measured_quantities = measured_quantities,
funcs_to_check = funcs_to_check,
known_ic = known_ic,
p = p,
prob_threshold = prob_threshold,
)
end
end
Expand All @@ -332,7 +332,7 @@ function _assess_local_identifiability(
measured_quantities = Array{ModelingToolkit.Equation}[],
funcs_to_check = Array{}[],
known_ic = Array{}[],
p::Float64 = 0.99,
prob_threshold::Float64 = 0.99,
)
if length(measured_quantities) == 0
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(dds))
Expand Down Expand Up @@ -373,8 +373,12 @@ function _assess_local_identifiability(
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
known_ic_ = [eval_at_nemo(x, conversion) for x in known_ic]

result =
_assess_local_identifiability_discrete_aux(dds_aux, funcs_to_check_, known_ic_, p)
result = _assess_local_identifiability_discrete_aux(
dds_aux,
funcs_to_check_,
known_ic_,
prob_threshold,
)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = OrderedDict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
if length(known_ic) > 0
Expand Down
Loading

0 comments on commit a763206

Please sign in to comment.