diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 598a91509..f78bc9931 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -55,6 +55,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines include("wrapper-types.jl") include("num.jl") + include("rewrite-helpers.jl") include("complex.jl") diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl index cfb917d50..2afc6d4b0 100644 --- a/src/rewrite-helpers.jl +++ b/src/rewrite-helpers.jl @@ -1,35 +1,28 @@ """ - replace(expr::Symbolic, rules...) -Walk the expression and replace subexpressions according to `rules`. `rules` +replacenode(expr::Symbolic, rules...) +Walk the expression and replacenode subexpressions according to `rules`. `rules` could be rules constructed with `@rule`, a function, or a pair where the -left hand side is matched with equality (using `isequal`) and is replaced by the right hand side. +left hand side is matched with equality (using `isequal`) and is replacenoded by the right hand side. Rules will be applied left-to-right simultaneously, so only one pattern will be applied to any subexpression, and the patterns will only be applied to the input text, -not the replacements. +not the replacenodements. Set `fixpoint = true` to repeatedly apply rules until no change to the expression remains to be made. """ -function Base.replace(expr::Num, r::Pair, rules::Pair...) - _replace(unwrap(expr), r, rules...) +function replacenode(expr::Num, r::Pair, rules::Pair...; fixpoint = false) + _replacenode(unwrap(expr), r, rules...) end - # Fix ambiguity -function Base.replace(expr::Num, rules...) - _replace(unwrap(expr), rules...) -end - -function Base.replace(expr::Symbolic, rules...) - _replace(unwrap(expr), rules...) -end +replacenode(expr::Num, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) +replacenode(expr::Symbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) +replacenode(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint) +replacenode(expr::Number, rules...; fixpoint = false) = expr +replacenode(expr::Number, r::Pair, rules::Pair...; fixpoint = false) = expr -function Base.replace(expr::Symbolic, r::Pair, rules::Pair...) - _replace(expr, r, rules...) -end - -function _replace(expr::Symbolic, rules...; fixpoint=false) +function _replacenode(expr::Symbolic, rules...; fixpoint = false) rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules) R = Prewalk(Chain(rs)) if fixpoint @@ -40,7 +33,7 @@ function _replace(expr::Symbolic, rules...; fixpoint=false) end """ - occursin(c, x) + hasnode(c, x) Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression. If it is a function, returns true if x is true for any part of x. If c is an expression, returns true if x contains c. @@ -48,24 +41,24 @@ true if x contains c. Examples: ```julia @syms x y -Symbolics.occursin(x, log(x) + x + 1) # returns `true`. -Symbolics.occursin(x, log(y) + y + 1) # returns `false`. +hasnode(x, log(x) + x + 1) # returns `true`. +hasnode(x, log(y) + y + 1) # returns `false`. ``` ```julia @variables t X(t) D = Differential(t) -Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. +hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. ``` """ -function Base.occursin(r::Function, y::Union{Num, Symbolic}) - _occursin(r, y) +function hasnode(r::Function, y::Union{Num, Symbolic}) + _hasnode(r, y) end +hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Union{Num, Symbolic, Function}, y::Number) = false -Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y)) -Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y)) - -function _occursin(r, y) +function _hasnode(r, y) y = unwrap(y) if r isa Function if r(y) @@ -75,7 +68,7 @@ function _occursin(r, y) if iscall(y) return r(operation(y)) || - any(y->_occursin(r, y), arguments(y)) + any(y->_hasnode(r, y), arguments(y)) else return false end @@ -129,6 +122,7 @@ function filterchildren!(r::Any, y, acc) end module RewriteHelpers -import Symbolics: filterchildren, unwrap -export replace, occursin, filterchildren, unwrap +import Symbolics: replacenode, hasnode, filterchildren, unwrap +export replacenode, hasnode, filterchildren, unwrap + end diff --git a/test/rewrite_helpers.jl b/test/rewrite_helpers.jl index d74abe8e5..1a804ac5a 100644 --- a/test/rewrite_helpers.jl +++ b/test/rewrite_helpers.jl @@ -9,15 +9,26 @@ using Test D = Differential(t) my_f(x, y) = x^3 + 2y -# Check replace function. +# Check `replacenode` function. let - @test isequal(replace(X + X + X, X =>1), 3) - @test isequal(replace(X + X + X, Y => 1), 3X) - @test isequal(replace(X + X + X, X => Y), 3Y) - @test isequal(replace(X + Y^2 - Z, Y^2 => Z), X) + # Simple replacements. + @test isequal(replacenode(X + X + X, X =>1), 3) + @test isequal(replacenode(X + X + X, Y => 1), 3X) + @test isequal(replacenode(X + X + my_f(X, Z), X => Y), Y^3 + 2Y + 2Z) + @test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X) + + # When the rule is a function. + rep_func(expr) = Symbolics.is_derivative(expr) ? b : expr + @test isequal(replacenode(D(X + Y) - log(a*Z), rep_func), b - log(a*Z)) + @test isequal(replacenode(D(Z^2) + my_f(D(X), D(Y)) + Z, rep_func), b^3 + 3b + Z) + @test isequal(replacenode(X + sin(Y + a) + a, rep_func), X + sin(Y + a) + a) + + # On non-symbolic inputs. + @test isequal(replacenode(1, X =>2.0), 1) + @test isequal(replacenode(1, rep_func), 1) end -# Test occursin function. +# Test `hasnode` function. let ex1 = 2X^a - log(b + my_f(Y,Y)) - 3 ex2 = X^(Y^(Z-a)) +log(log(log(b))) @@ -26,39 +37,44 @@ let ex5 = a + 5b^2 # Test for variables. - @test occursin(X, ex1) - @test occursin(X, ex2) - @test occursin(X, ex3) - @test !occursin(X, ex4) - @test occursin(Y, ex1) - @test occursin(Y, ex2) - @test occursin(Y, ex3) - @test occursin(Y, ex4) - @test !occursin(Z, ex1) - @test occursin(Z, ex2) - @test !occursin(Z, ex3) - @test occursin(Z, ex4) + @test hasnode(X, ex1) + @test hasnode(X, ex2) + @test hasnode(X, ex3) + @test !hasnode(X, ex4) + @test hasnode(Y, ex1) + @test hasnode(Y, ex2) + @test hasnode(Y, ex3) + @test hasnode(Y, ex4) + @test !hasnode(Z, ex1) + @test hasnode(Z, ex2) + @test !hasnode(Z, ex3) + @test hasnode(Z, ex4) # Test for variables. - @test_broken occursin(a, ex1) - @test_broken occursin(a, ex2) - @test_broken occursin(a, ex3) - @test_broken occursin(a, ex4) - @test occursin(a, ex5) - @test_broken occursin(b, ex1) - @test_broken occursin(b, ex2) - @test !occursin(b, ex3) - @test !occursin(b, ex4) - @test occursin(b, ex5) + @test hasnode(a, ex1) + @test hasnode(a, ex2) + @test hasnode(a, ex3) + @test hasnode(a, ex4) + @test hasnode(a, ex5) + @test hasnode(b, ex1) + @test hasnode(b, ex2) + @test !hasnode(b, ex3) + @test !hasnode(b, ex4) + @test hasnode(b, ex5) # Test for function. - @test !occursin(is_derivative, ex1) - @test !occursin(is_derivative, ex2) - @test !occursin(is_derivative, ex3) - @test occursin(is_derivative, ex4) + @test !hasnode(is_derivative, ex1) + @test !hasnode(is_derivative, ex2) + @test !hasnode(is_derivative, ex3) + @test hasnode(is_derivative, ex4) + + # On non symbolic inputs: + @test !hasnode(X, 1) + @test !hasnode(a, 1) + @test !hasnode(is_derivative, 1) end -# Check filterchildren function. +# Check `filterchildren` function. let ex1 = 2X^a - log(b + my_f(Y,Y)) - 3 ex2 = X^(Y^(Z-a)) +log(log(log(b))) @@ -80,8 +96,7 @@ let @test isequal(filterchildren(Z, ex3), []) @test isequal(filterchildren(Z, ex4), [Z]) - # Test for variables. - + # Test for syms. @test isequal(filterchildren(a, ex1), [a]) @test isequal(filterchildren(a, ex2), [a]) @test isequal(filterchildren(a, ex3), [a])