Skip to content

Commit

Permalink
Merge pull request #3192 from AayushSabharwal/as/hc-cache-startsys
Browse files Browse the repository at this point in the history
feat: cache start system and solver in HomotopyContinuation interface
  • Loading branch information
ChrisRackauckas authored Nov 10, 2024
2 parents d25a060 + d284978 commit b52bce7
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 34 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
Expand Down Expand Up @@ -94,6 +95,7 @@ Distributions = "0.23, 0.24, 0.25"
DocStringExtensions = "0.7, 0.8, 0.9"
DomainSets = "0.6, 0.7"
DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
EnumX = "1.0.4"
ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
Expand Down
149 changes: 126 additions & 23 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,104 @@ function contains_variable(x, wrt)
any(y -> occursin(y, x), wrt)
end

"""
Possible reasons why a term is not polynomial
"""
MTK.EnumX.@enumx NonPolynomialReason begin
NonIntegerExponent
ExponentContainsUnknowns
BaseNotPolynomial
UnrecognizedOperation
end

function display_reason(reason::NonPolynomialReason.T, sym)
if reason == NonPolynomialReason.NonIntegerExponent
pow = arguments(sym)[2]
"In $sym: Exponent $pow is not an integer"
elseif reason == NonPolynomialReason.ExponentContainsUnknowns
pow = arguments(sym)[2]
"In $sym: Exponent $pow contains unknowns of the system"
elseif reason == NonPolynomialReason.BaseNotPolynomial
base = arguments(sym)[1]
"In $sym: Base $base is not a polynomial in the unknowns"
elseif reason == NonPolynomialReason.UnrecognizedOperation
op = operation(sym)
"""
In $sym: Operation $op is not recognized. Allowed polynomial operations are \
`*, /, +, -, ^`.
"""
else
error("This should never happen. Please open an issue in ModelingToolkit.jl.")
end
end

mutable struct PolynomialData
non_polynomial_terms::Vector{BasicSymbolic}
reasons::Vector{NonPolynomialReason.T}
has_parametric_exponent::Bool
end

PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)

struct NotPolynomialError <: Exception
eq::Equation
data::PolynomialData
end

function Base.showerror(io::IO, err::NotPolynomialError)
println(io,
"Equation $(err.eq) is not a polynomial in the unknowns for the following reasons:")
for (term, reason) in zip(err.data.non_polynomial_terms, err.data.reasons)
println(io, display_reason(reason, term))
end
end

function is_polynomial!(data, y, wrt)
process_polynomial!(data, y, wrt)
isempty(data.reasons)
end

"""
$(TYPEDSIGNATURES)
Check if `x` is polynomial with respect to the variables in `wrt`.
Return information about the polynmial `x` with respect to variables in `wrt`,
writing said information to `data`.
"""
function is_polynomial(x, wrt)
function process_polynomial!(data::PolynomialData, x, wrt)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return true
iscall(x) || return true
contains_variable(x, wrt) || return true
any(isequal(x), wrt) && return true

if operation(x) in (*, +, -, /)
return all(y -> is_polynomial(y, wrt), arguments(x))
return all(y -> is_polynomial!(data, y, wrt), arguments(x))
end
if operation(x) == (^)
b, p = arguments(x)
is_pow_integer = symtype(p) <: Integer
if !is_pow_integer
if symbolic_type(p) == NotSymbolic()
@warn "In $x: Exponent $p is not an integer"
else
@warn "In $x: Exponent $p is not an integer. Use `@parameters p::Integer` to declare integer parameters."
end
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.NonIntegerExponent)
end
if symbolic_type(p) != NotSymbolic()
data.has_parametric_exponent = true
end

exponent_has_unknowns = contains_variable(p, wrt)
if exponent_has_unknowns
@warn "In $x: Exponent $p cannot contain unknowns of the system."
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
end
base_polynomial = is_polynomial(b, wrt)
base_polynomial = is_polynomial!(data, b, wrt)
if !base_polynomial
@warn "In $x: Base is not a polynomial"
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.BaseNotPolynomial)
end
return base_polynomial && !exponent_has_unknowns && is_pow_integer
end
@warn "In $x: Unrecognized operation $(operation(x)). Allowed polynomial operations are `*, +, -, ^`"
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.UnrecognizedOperation)
return false
end

Expand Down Expand Up @@ -179,21 +241,39 @@ Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial
The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution`
will contain the polynomial root closest to the point specified by `u0map` (if real roots
exist for the system).
Keyword arguments:
- `eval_expression`: Whether to `eval` the generated functions or use a `RuntimeGeneratedFunction`.
- `eval_module`: The module to use for `eval`/`@RuntimeGeneratedFunction`
- `warn_parametric_exponent`: Whether to warn if the system contains a parametric
exponent preventing the homotopy from being cached.
All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
"""
function MTK.HomotopyContinuationProblem(
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
eval_module = ModelingToolkit, kwargs...)
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
end

dvs = unknowns(sys)
eqs = equations(sys)
# we need to consider `full_equations` because observed also should be
# polynomials (if used in equations) and we don't know if observed is used
# in denominator.
# This is not the most efficient, and would be improved significantly with
# CSE/hashconsing.
eqs = full_equations(sys)

denoms = []
has_parametric_exponents = false
eqs2 = map(eqs) do eq
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
data = PolynomialData()
process_polynomial!(data, eq.lhs, dvs)
process_polynomial!(data, eq.rhs, dvs)
has_parametric_exponents |= data.has_parametric_exponent
if !isempty(data.non_polynomial_terms)
throw(NotPolynomialError(eq, data))
end
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)

Expand All @@ -212,6 +292,9 @@ function MTK.HomotopyContinuationProblem(
end

sys2 = MTK.@set sys.eqs = eqs2
# remove observed equations to avoid adding them in codegen
MTK.@set! sys2.observed = Equation[]
MTK.@set! sys2.substitutions = nothing

nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
jac = true, eval_expression, eval_module)
Expand All @@ -223,29 +306,49 @@ function MTK.HomotopyContinuationProblem(

obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)

return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
if has_parametric_exponents
if warn_parametric_exponent
@warn """
The system has parametric exponents, preventing caching of the homotopy. \
This will cause `solve` to be slower. Pass `warn_parametric_exponent \
= false` to turn off this warning
"""
end
solver_and_starts = nothing
else
solver_and_starts = HomotopyContinuation.solver_startsolutions(mtkhsys; kwargs...)
end
return MTK.HomotopyContinuationProblem(
u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
end

"""
$(TYPEDSIGNATURES)
Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
uses `HomotopyContinuation.jl`. The original solution as returned by
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
`NonlinearSolution`.
All keyword arguments have their default values in HomotopyContinuation.jl, except
`show_progress` which defaults to `false`.
All keyword arguments except the ones listed below are forwarded to
`HomotopyContinuation.solve`. Note that the solver and start solutions are precomputed,
and only keyword arguments related to the solve process are valid. All keyword
arguments have their default values in HomotopyContinuation.jl, except `show_progress`
which defaults to `false`.
Extra keyword arguments:
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
the denominator to be below `denominator_abstol` will be discarded.
"""
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
alg = nothing; show_progress = false, denominator_abstol = 1e-7, kwargs...)
sol = HomotopyContinuation.solve(
prob.homotopy_continuation_system; show_progress, kwargs...)
if prob.solver_and_starts === nothing
sol = HomotopyContinuation.solve(
prob.homotopy_continuation_system; show_progress, kwargs...)
else
solver, starts = prob.solver_and_starts
sol = HomotopyContinuation.solve(solver, starts; show_progress, kwargs...)
end
realsols = HomotopyContinuation.results(sol; only_real = true)
if isempty(realsols)
u = state_values(prob)
Expand Down
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ using RecursiveArrayTools
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
import CommonSolve
import EnumX

using RuntimeGeneratedFunctions
using RuntimeGeneratedFunctions: drop_expr
Expand Down
7 changes: 6 additions & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
create and solve.
"""
struct HomotopyContinuationProblem{uType, H, D, O} <:
struct HomotopyContinuationProblem{uType, H, D, O, SS} <:
SciMLBase.AbstractNonlinearProblem{uType, true}
"""
The initial values of states in the system. If there are multiple real roots of
Expand All @@ -716,6 +716,11 @@ struct HomotopyContinuationProblem{uType, H, D, O} <:
A function which generates and returns observed expressions for the given system.
"""
obsfn::O
"""
The HomotopyContinuation.jl solver and start system, obtained through
`HomotopyContinuation.solver_startsystems`.
"""
solver_and_starts::SS
end

function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...)
Expand Down
38 changes: 29 additions & 9 deletions test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,32 @@ end
@test sol.retcode == ReturnCode.ConvergenceFailure
end

@testset "Parametric exponent" begin
@testset "Parametric exponents" begin
@variables x = 1.0
@parameters n::Integer = 4
@mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
prob = HomotopyContinuationProblem(sys, [])
prob = @test_warn ["parametric", "exponent"] HomotopyContinuationProblem(sys, [])
@test prob.solver_and_starts === nothing
@test_nowarn HomotopyContinuationProblem(sys, []; warn_parametric_exponent = false)
sol = solve(prob; threading = false)
@test SciMLBase.successful_retcode(sol)
end

@testset "Polynomial check and warnings" begin
@variables x = 1.0
@parameters n = 4
@mtkbuild sys = NonlinearSystem([x^n + x^2 - 1 ~ 0])
@test_warn ["Exponent", "not an integer", "@parameters"] @test_throws "not a polynomial" HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([x^1.5 + x^2 - 1 ~ 0])
@test_warn ["Exponent", "not an integer"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@test_throws ["Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
@test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@test_throws ["Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
@test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
sys, [])

@variables y = 2.0
@mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)])
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
sys, [])
end

Expand Down Expand Up @@ -131,4 +134,21 @@ end
end
end
@test prob.denominator([2.0, 4.0], p)[1] <= 1e-8

@testset "Rational function in observed" begin
@variables x=1 y=1
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
@test_nowarn solve(prob; threading = false)
end
end

@testset "Non-polynomial observed not used in equations" begin
@variables x=1 y
@mtkbuild sys = NonlinearSystem([x^2 - 2 ~ 0, y ~ sin(x)])
prob = HomotopyContinuationProblem(sys, [])
sol = @test_nowarn solve(prob; threading = false)
@test sol[x] 2.0
@test sol[y] sin(2.0)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ end

if GROUP == "All" || GROUP == "Extensions"
activate_extensions_env()
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
@safetestset "HomotopyContinuation Extension Test" include("extensions/homotopy_continuation.jl")
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
end
end

0 comments on commit b52bce7

Please sign in to comment.