Skip to content

Commit

Permalink
feat: add iteration limit for fixpoint_sub
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 24, 2024
1 parent 10a61ee commit c4a5799
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,21 +503,25 @@ function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symboli
end

"""
fixpoint_sub(expr, dict; operator = Nothing)
fixpoint_sub(expr, dict; operator = Nothing, maxiters = 10000)
Given a symbolic expression, equation or inequality `expr` perform the substitutions in
`dict` recursively until the expression does not change. Substitutions that depend on one
another will thus be recursively expanded. For example,
`fixpoint_sub(x, Dict(x => y, y => 3))` will return `3`. The `operator` keyword can be
specified to prevent substitution of expressions inside operators of the given type.
specified to prevent substitution of expressions inside operators of the given type. The
`maxiters` keyword is used to limit the number of times the substitution can occur to avoid
infinite loops in cases where the substitutions in `dict` are circular
(e.g. `[x => y, y => x]`).
See also: [`fast_substitute`](@ref).
"""
function fixpoint_sub(x, dict; operator = Nothing)
function fixpoint_sub(x, dict; operator = Nothing, maxiters = 10000)
y = fast_substitute(x, dict; operator)
while !isequal(x, y)
while !isequal(x, y) && maxiters > 0
y = x
x = fast_substitute(y, dict; operator)
maxiters -= 1
end

return x
Expand Down
8 changes: 8 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ end
@test var_from_nested_derivative(p) == (p, 0)
@test var_from_nested_derivative(D(p(x))) == (p(x), 1)
end

@testset "fixpoint_sub maxiters" begin
@variables x y
expr = fixpoint_sub(x, Dict(x => y, y => x))
@test isequal(expr, x)
expr = fixpoint_sub(x, Dict(x => y, y => x); maxiters = 9)
@test isequal(expr, y)
end

0 comments on commit c4a5799

Please sign in to comment.