Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Identifiability with known generic initial conditions #249

Merged
merged 13 commits into from
Feb 5, 2024
53 changes: 53 additions & 0 deletions src/RationalFunctionFields/RationalFunctionField.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ function poly_ring(F::RationalFunctionField)
return parent(first(first(F.dennums)))
end

function generators(F::RationalFunctionField)
return vcat([[f[i] // f[1] for i in 2:length(f)] for f in F.dennums]...)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe dennums_to_fractions

end

# ------------------------------------------------------------------------------

"""
Expand Down Expand Up @@ -165,6 +169,55 @@ end

# ------------------------------------------------------------------------------

"""
check_algebraicity(field, ratfuncs, p)

Checks whether given rational function `ratfuncs` are algebraic over the field `field`
The result is correct with probability at least `p`

Inputs:
- `field` - a rational function field
- `ratfuncs` - a list of lists of rational functions.
- `p` real number from (0, 1)

Output:
- a list `L[i]` of bools of length `length(rat_funcs)` such that `L[i]` is true iff
the i-th function is algebraic over the `field`
"""
function check_algebraicity(field, ratfuncs, p)
fgens = generators(field)
base_vars = gens(poly_ring(field))
maxdeg = maximum([
max(total_degree(numerator(f)), total_degree(denominator(f))) for
f in vcat(ratfuncs, fgens)
])
# degree of the polynomial whose nonvanishing will be needed for correct result
D = Int(ceil(2 * maxdeg * (length(fgens) + 1)^3 * length(ratfuncs) / (1 - p)))
eval_point = [Nemo.QQ(rand(1:D)) for x in base_vars]

# Filling the jacobain for generators
S = MatrixSpace(Nemo.QQ, length(base_vars), length(fgens) + 1)
J = zero(S)
for (i, f) in enumerate(fgens)
for (j, x) in enumerate(base_vars)
J[j, i] = evaluate(derivative(f, x), eval_point)
end
end
rank = LinearAlgebra.rank(J)

result = Bool[]
for f in ratfuncs
f = parent_ring_change(f, poly_ring(field))
for (j, x) in enumerate(base_vars)
J[j, end] = evaluate(derivative(f, x), eval_point)
end
push!(result, LinearAlgebra.rank(J) == rank)
end
return result
end

# ------------------------------------------------------------------------------

function issubfield(F::RationalFunctionField{T}, E::RationalFunctionField{T}, p) where {T}
return all(field_contains(E, F.dennums, p))
end
Expand Down
1 change: 1 addition & 0 deletions src/StructuralIdentifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ include("lincomp.jl")
include("pb_representation.jl")
include("submodels.jl")
include("discrete.jl")
include("known_ic.jl")

function __init__()
_si_logger[] = @static if VERSION >= v"1.7.0"
Expand Down
158 changes: 158 additions & 0 deletions src/known_ic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
find_identifiable_functions_kic(ode::ODE, known_ic; options...)

Finds all functions of parameters/initial conditions that are identifiable in the given ODE
system under assumptions that the initial conditions for functions in the list
`known_ic` are known and generic.

## Options

This functions takes the following optional arguments:
- `simplify`: The extent to which the output functions are simplified. Stronger
simplification may require more time. Possible options are:
- `:standard`: Default simplification.
- `:weak`: Weak simplification. This option is the fastest, but the output
functions can be quite complex.
- `:strong`: Strong simplification. This option is the slowest, but the output
functions are nice and simple.
- `:absent`: No simplification.
- `p`: A float in the range from 0 to 1, the probability of correctness. Default
is `0.99`.
- `seed`: The rng seed. Default value is `42`.
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)

**This is experimental functionality**

```

"""
function find_identifiable_functions_kic(
ode::ODE{T},
known_ic::Vector{<:Union{T, Generic.Frac{T}}};
p::Float64 = 0.99,
seed = 42,
simplify = :standard,
rational_interpolator = :VanDerHoevenLecerf,
loglevel = Logging.Info,
) where {T <: MPolyElem{fmpq}}
restart_logging(loglevel = loglevel)
reset_timings()
with_logger(_si_logger[]) do
return _find_identifiable_functions_kic(
ode,
known_ic,
p = p,
seed = seed,
simplify = simplify,
rational_interpolator = rational_interpolator,
)
end
end

function _find_identifiable_functions_kic(
ode::ODE{T},
known_ic::Vector{<:Union{T, Generic.Frac{T}}};
p::Float64 = 0.99,
seed = 42,
simplify = :standard,
rational_interpolator = :VanDerHoevenLecerf,
) where {T <: MPolyElem{fmpq}}
Random.seed!(seed)
@assert simplify in (:standard, :weak, :strong, :absent)
half_p = 0.5 + p / 2
runtime_start = time_ns()
id_funcs_general = find_identifiable_functions(
ode,
p = half_p,
with_states = true,
simplify = simplify,
rational_interpolator = rational_interpolator,
seed = seed,
)

id_funcs = simplified_generating_set(
RationalFunctionField(
vcat(id_funcs_general, [f // one(parent(ode)) for f in known_ic]),
),
p = half_p,
seed = seed,
simplify = simplify,
rational_interpolator = rational_interpolator,
)

@info "The search for identifiable functions with known initial conditions concluded in $((time_ns() - runtime_start) / 1e9) seconds"

return replace_with_ic(ode, id_funcs)
end

"""
assess_identifiability_kic(ode, known_ic; funcs_to_check = [], p=0.99, loglevel=Logging.Info)

Input:
- `ode` - the ODE model
- `known_ic` - a list of functions for which initial conditions are assumed to be known and generic
- `funcs_to_check` - list of functions to check identifiability for; if empty, all parameters
and states are taken
- `p` - probability of correctness.
- `loglevel` - the minimal level of log messages to display (`Logging.Info` by default)

Assesses identifiability of parameters and initial conditions of a given ODE model.
The result is guaranteed to be correct with the probability at least `p`.
The function returns an (ordered) dictionary from the functions to check to their identifiability properties
(one of `:nonidentifiable`, `:locally`, `:globally`).
"""
function assess_identifiability_kic(
ode::ODE{P},
known_ic::Vector{<:Union{P, Generic.Frac{P}}};
funcs_to_check = Vector(),
p::Float64 = 0.99,
loglevel = Logging.Info,
) where {P <: MPolyElem{fmpq}}
restart_logging(loglevel = loglevel)
reset_timings()
with_logger(_si_logger[]) do
return _assess_identifiability_kic(
ode,
known_ic,
p = p,
funcs_to_check = funcs_to_check,
)
end
end

function _assess_identifiability_kic(
ode::ODE{P},
known_ic::Vector{<:Union{P, Generic.Frac{P}}};
funcs_to_check = Vector(),
p::Float64 = 0.99,
) where {P <: MPolyElem{fmpq}}
runtime_start = time_ns()
if length(funcs_to_check) == 0
funcs_to_check = vcat(ode.x_vars, ode.parameters)
end
half_p = 0.5 + p / 2
id_funcs = _find_identifiable_functions_kic(ode, known_ic, p = half_p)
funcs_to_check = replace_with_ic(ode, funcs_to_check)
result = OrderedDict(f => :globally for f in funcs_to_check)

half_p = 0.5 + p / 2
local_result = check_algebraicity(
RationalFunctionField([[denominator(f), numerator(f)] for f in id_funcs]),
[f // one(parent(f)) for f in funcs_to_check],
half_p,
)
global_result = field_contains(
RationalFunctionField([[denominator(f), numerator(f)] for f in id_funcs]),
[f // one(parent(f)) for f in funcs_to_check],
half_p,
)
for (i, f) in enumerate(funcs_to_check)
if !local_result[i]
result[f] = :nonidentifiable
elseif !global_result[i]
result[f] = :locally
end
end
@info "Assessing identifiability with known initial conditions concluded in $((time_ns() - runtime_start) / 1e9) seconds"
return result
end
1 change: 1 addition & 0 deletions src/local_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ function _assess_local_identifiability(
p::Float64 = 0.99,
type = :SE,
trbasis = nothing,
known_ic::Array{<:Any, 1} = Array{Any, 1}(),
) where {P <: MPolyElem{Nemo.fmpq}}
if isempty(funcs_to_check)
funcs_to_check = ode.parameters
Expand Down
24 changes: 24 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,27 @@
)
end
end

# -----------------------------------------------------------------------------

"""
replace_with_ic(ode::ODE, funcs)

Takes an ode and a list of functions in the states and parameters and makes a change of variable
names `x(t) -> x(0)`. Function is used to prepare the output for the case of known initial conditions
"""
function replace_with_ic(ode, funcs)
varnames = [(var_to_str(p), var_to_str(p)) for p in ode.parameters]
for x in ode.x_vars
s = var_to_str(x)
if endswith(s, "(t)")
push!(varnames, (s, s[1:(end - 3)] * "(0)"))
else
push!(varnames, (s, s * "(0)"))

Check warning on line 598 in src/util.jl

View check run for this annotation

Codecov / codecov/patch

src/util.jl#L598

Added line #L598 was not covered by tests
end
end
R0, vars0 = PolynomialRing(base_ring(ode.poly_ring), [p[2] for p in varnames])
eval_dict =
Dict(str_to_var(p[1], ode.poly_ring) => str_to_var(p[2], R0) for p in varnames)
return [eval_at_dict(f, eval_dict) for f in funcs]
end
26 changes: 26 additions & 0 deletions test/check_algebraicity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
cases = []

R, (x, y, z) = PolynomialRing(QQ, ["x", "y", "z"])
push!(
cases,
Dict(
:F => RationalFunctionField([[one(R), x + y, x * y]]),
:funcs => [x // one(R), z // one(R), x^3 - y^3 // one(R), x + z // one(R)],
:correct => [true, false, true, false],
),
)

push!(
cases,
Dict(
:F => RationalFunctionField([[x, y], [y, z]]),
:funcs => [x // z, (x + y) // z, x // one(R), y // one(R), z // one(R)],
:correct => [true, true, false, false, false],
),
)

@testset "Algebraicity over a field" begin
for case in cases
@test check_algebraicity(case[:F], case[:funcs], 0.99) == case[:correct]
end
end
67 changes: 67 additions & 0 deletions test/known_ic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
cases = []

ode = @ODEmodel(
x1'(t) = a * x1(t) - b * x1(t) * x2(t),
x2'(t) = -c * x2(t) + d * x1(t) * x2(t),
y(t) = x1(t)
)

push!(
cases,
Dict(
:ode => ode,
:known => [x2],
:to_check => [],
:correct_funcs => [a, b, c, d, x1, x2],
:correct_ident => OrderedDict(x => :globally for x in [x1, x2, a, b, c, d]),
),
)

ode = @ODEmodel(x1'(t) = a + x2(t) + x3(t), x2'(t) = b^2 + c, x3'(t) = -c, y(t) = x1(t))

push!(
cases,
Dict(
:ode => ode,
:known => [x2, x3],
:to_check => [],
:correct_funcs => [a, b^2, x1, x2, x3],
:correct_ident => OrderedDict(
x1 => :globally,
x2 => :globally,
x3 => :globally,
a => :globally,
b => :locally,
c => :nonidentifiable,
),
),
)

push!(
cases,
Dict(
:ode => ode,
:known => [x2, x3],
:to_check => [b^2, x2 * c],
:correct_funcs => [a, b^2, x1, x2, x3],
:correct_ident => OrderedDict(b^2 => :globally, x2 * c => :nonidentifiable),
),
)

@testset "Identifiable functions with known generic initial conditions" begin
for case in cases
ode = case[:ode]
known = case[:known]

result_funcs = find_identifiable_functions_kic(ode, known)
correct_funcs =
replace_with_ic(ode, [f // one(parent(ode)) for f in case[:correct_funcs]])
@test Set(result_funcs) == Set(correct_funcs)

result_ident =
assess_identifiability_kic(ode, known, funcs_to_check = case[:to_check])
@test OrderedDict(
replace_with_ic(ode, [k])[1] => v for (k, v) in case[:correct_ident]
) == result_ident
end
end
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using StructuralIdentifiability:
check_identifiability,
check_primality_zerodim,
check_primality,
check_algebraicity,
det_minor_expansion,
ExpVectTrie,
get_max_below,
Expand Down Expand Up @@ -58,7 +59,10 @@ using StructuralIdentifiability:
extract_coefficients_ratfunc,
lie_derivative,
states_generators,
RationalFunctionField
RationalFunctionField,
find_identifiable_functions_kic,
assess_identifiability_kic,
replace_with_ic

function random_ps(ps_ring, range = 1000)
result = zero(ps_ring)
Expand Down