Skip to content

Commit

Permalink
Merge pull request #1139 from TorkelE/replaceupdate
Browse files Browse the repository at this point in the history
A fix for `occursin`
  • Loading branch information
shashi authored Jun 4, 2024
2 parents 29aa0fa + dc0f367 commit 8498d6b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 67 deletions.
1 change: 1 addition & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines
include("wrapper-types.jl")

include("num.jl")

include("rewrite-helpers.jl")
include("complex.jl")

Expand Down
58 changes: 26 additions & 32 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,28 @@
"""
replace(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
replacenode(expr::Symbolic, rules...)
Walk the expression and replacenode subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.
left hand side is matched with equality (using `isequal`) and is replacenoded by the right hand side.
Rules will be applied left-to-right simultaneously,
so only one pattern will be applied to any subexpression,
and the patterns will only be applied to the input text,
not the replacements.
not the replacenodements.
Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function Base.replace(expr::Num, r::Pair, rules::Pair...)
_replace(unwrap(expr), r, rules...)
function replacenode(expr::Num, r::Pair, rules::Pair...; fixpoint = false)
_replacenode(unwrap(expr), r, rules...)
end

# Fix ambiguity
function Base.replace(expr::Num, rules...)
_replace(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, rules...)
_replace(unwrap(expr), rules...)
end
replacenode(expr::Num, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint)
replacenode(expr::Symbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint)
replacenode(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint)
replacenode(expr::Number, rules...; fixpoint = false) = expr
replacenode(expr::Number, r::Pair, rules::Pair...; fixpoint = false) = expr

function Base.replace(expr::Symbolic, r::Pair, rules::Pair...)
_replace(expr, r, rules...)
end

function _replace(expr::Symbolic, rules...; fixpoint=false)
function _replacenode(expr::Symbolic, rules...; fixpoint = false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Expand All @@ -40,32 +33,32 @@ function _replace(expr::Symbolic, rules...; fixpoint=false)
end

"""
occursin(c, x)
hasnode(c, x)
Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression.
If it is a function, returns true if x is true for any part of x. If c is an expression, returns
true if x contains c.
Examples:
```julia
@syms x y
Symbolics.occursin(x, log(x) + x + 1) # returns `true`.
Symbolics.occursin(x, log(y) + y + 1) # returns `false`.
hasnode(x, log(x) + x + 1) # returns `true`.
hasnode(x, log(y) + y + 1) # returns `false`.
```
```julia
@variables t X(t)
D = Differential(t)
Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
```
"""
function Base.occursin(r::Function, y::Union{Num, Symbolic})
_occursin(r, y)
function hasnode(r::Function, y::Union{Num, Symbolic})
_hasnode(r, y)
end
hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Union{Num, Symbolic, Function}, y::Number) = false

Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y))
Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y))

function _occursin(r, y)
function _hasnode(r, y)
y = unwrap(y)
if r isa Function
if r(y)
Expand All @@ -75,7 +68,7 @@ function _occursin(r, y)

if iscall(y)
return r(operation(y)) ||
any(y->_occursin(r, y), arguments(y))
any(y->_hasnode(r, y), arguments(y))
else
return false
end
Expand Down Expand Up @@ -129,6 +122,7 @@ function filterchildren!(r::Any, y, acc)
end

module RewriteHelpers
import Symbolics: filterchildren, unwrap
export replace, occursin, filterchildren, unwrap
import Symbolics: replacenode, hasnode, filterchildren, unwrap
export replacenode, hasnode, filterchildren, unwrap

end
85 changes: 50 additions & 35 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,26 @@ using Test
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check replace function.
# Check `replacenode` function.
let
@test isequal(replace(X + X + X, X =>1), 3)
@test isequal(replace(X + X + X, Y => 1), 3X)
@test isequal(replace(X + X + X, X => Y), 3Y)
@test isequal(replace(X + Y^2 - Z, Y^2 => Z), X)
# Simple replacements.
@test isequal(replacenode(X + X + X, X =>1), 3)
@test isequal(replacenode(X + X + X, Y => 1), 3X)
@test isequal(replacenode(X + X + my_f(X, Z), X => Y), Y^3 + 2Y + 2Z)
@test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X)

# When the rule is a function.
rep_func(expr) = Symbolics.is_derivative(expr) ? b : expr
@test isequal(replacenode(D(X + Y) - log(a*Z), rep_func), b - log(a*Z))
@test isequal(replacenode(D(Z^2) + my_f(D(X), D(Y)) + Z, rep_func), b^3 + 3b + Z)
@test isequal(replacenode(X + sin(Y + a) + a, rep_func), X + sin(Y + a) + a)

# On non-symbolic inputs.
@test isequal(replacenode(1, X =>2.0), 1)
@test isequal(replacenode(1, rep_func), 1)
end

# Test occursin function.
# Test `hasnode` function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand All @@ -26,39 +37,44 @@ let
ex5 = a + 5b^2

# Test for variables.
@test occursin(X, ex1)
@test occursin(X, ex2)
@test occursin(X, ex3)
@test !occursin(X, ex4)
@test occursin(Y, ex1)
@test occursin(Y, ex2)
@test occursin(Y, ex3)
@test occursin(Y, ex4)
@test !occursin(Z, ex1)
@test occursin(Z, ex2)
@test !occursin(Z, ex3)
@test occursin(Z, ex4)
@test hasnode(X, ex1)
@test hasnode(X, ex2)
@test hasnode(X, ex3)
@test !hasnode(X, ex4)
@test hasnode(Y, ex1)
@test hasnode(Y, ex2)
@test hasnode(Y, ex3)
@test hasnode(Y, ex4)
@test !hasnode(Z, ex1)
@test hasnode(Z, ex2)
@test !hasnode(Z, ex3)
@test hasnode(Z, ex4)

# Test for variables.
@test_broken occursin(a, ex1)
@test_broken occursin(a, ex2)
@test_broken occursin(a, ex3)
@test_broken occursin(a, ex4)
@test occursin(a, ex5)
@test_broken occursin(b, ex1)
@test_broken occursin(b, ex2)
@test !occursin(b, ex3)
@test !occursin(b, ex4)
@test occursin(b, ex5)
@test hasnode(a, ex1)
@test hasnode(a, ex2)
@test hasnode(a, ex3)
@test hasnode(a, ex4)
@test hasnode(a, ex5)
@test hasnode(b, ex1)
@test hasnode(b, ex2)
@test !hasnode(b, ex3)
@test !hasnode(b, ex4)
@test hasnode(b, ex5)

# Test for function.
@test !occursin(is_derivative, ex1)
@test !occursin(is_derivative, ex2)
@test !occursin(is_derivative, ex3)
@test occursin(is_derivative, ex4)
@test !hasnode(is_derivative, ex1)
@test !hasnode(is_derivative, ex2)
@test !hasnode(is_derivative, ex3)
@test hasnode(is_derivative, ex4)

# On non symbolic inputs:
@test !hasnode(X, 1)
@test !hasnode(a, 1)
@test !hasnode(is_derivative, 1)
end

# Check filterchildren function.
# Check `filterchildren` function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand All @@ -80,8 +96,7 @@ let
@test isequal(filterchildren(Z, ex3), [])
@test isequal(filterchildren(Z, ex4), [Z])

# Test for variables.

# Test for syms.
@test isequal(filterchildren(a, ex1), [a])
@test isequal(filterchildren(a, ex2), [a])
@test isequal(filterchildren(a, ex3), [a])
Expand Down

0 comments on commit 8498d6b

Please sign in to comment.