From c4a57999e5d9181790e9b314434f32aebcccdd5d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 16:29:20 +0530 Subject: [PATCH] feat: add iteration limit for `fixpoint_sub` --- src/variable.jl | 12 ++++++++---- test/utils.jl | 8 ++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index b5230f07b..d7374b507 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -503,21 +503,25 @@ function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symboli end """ - fixpoint_sub(expr, dict; operator = Nothing) + fixpoint_sub(expr, dict; operator = Nothing, maxiters = 10000) Given a symbolic expression, equation or inequality `expr` perform the substitutions in `dict` recursively until the expression does not change. Substitutions that depend on one another will thus be recursively expanded. For example, `fixpoint_sub(x, Dict(x => y, y => 3))` will return `3`. The `operator` keyword can be -specified to prevent substitution of expressions inside operators of the given type. +specified to prevent substitution of expressions inside operators of the given type. The +`maxiters` keyword is used to limit the number of times the substitution can occur to avoid +infinite loops in cases where the substitutions in `dict` are circular +(e.g. `[x => y, y => x]`). See also: [`fast_substitute`](@ref). """ -function fixpoint_sub(x, dict; operator = Nothing) +function fixpoint_sub(x, dict; operator = Nothing, maxiters = 10000) y = fast_substitute(x, dict; operator) - while !isequal(x, y) + while !isequal(x, y) && maxiters > 0 y = x x = fast_substitute(y, dict; operator) + maxiters -= 1 end return x diff --git a/test/utils.jl b/test/utils.jl index 616dee642..6c876faba 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -38,3 +38,11 @@ end @test var_from_nested_derivative(p) == (p, 0) @test var_from_nested_derivative(D(p(x))) == (p(x), 1) end + +@testset "fixpoint_sub maxiters" begin + @variables x y + expr = fixpoint_sub(x, Dict(x => y, y => x)) + @test isequal(expr, x) + expr = fixpoint_sub(x, Dict(x => y, y => x); maxiters = 9) + @test isequal(expr, y) +end