Skip to content

Commit

Permalink
replace -> repalcenode
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jun 4, 2024
1 parent 09d82d1 commit a8e987f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
27 changes: 15 additions & 12 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
"""
replace(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
replacenode(expr::Symbolic, rules...)
Walk the expression and replacenode subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.
left hand side is matched with equality (using `isequal`) and is replacenoded by the right hand side.
Rules will be applied left-to-right simultaneously,
so only one pattern will be applied to any subexpression,
and the patterns will only be applied to the input text,
not the replacements.
not the replacenodements.
Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function replace(expr::Num, r::Pair, rules::Pair...; fixpoint = false)
_replace(unwrap(expr), r, rules...)
function replacenode(expr::Num, r::Pair, rules::Pair...; fixpoint = false)
_replacenode(unwrap(expr), r, rules...)
end
# Fix ambiguity
replace(expr::Num, rules...; fixpoint = false) = _replace(unwrap(expr), rules...; fixpoint)
replace(expr::Symbolic, rules...; fixpoint = false) = _replace(unwrap(expr), rules...; fixpoint)
replace(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replace(expr, r, rules...; fixpoint)
replacenode(expr::Num, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint)
replacenode(expr::Symbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint)
replacenode(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint)
replacenode(expr::Number, rules...; fixpoint = false) = expr
replacenode(expr::Number, r::Pair, rules::Pair...; fixpoint = false) = expr

function _replace(expr::Symbolic, rules...; fixpoint = false)
function _replacenode(expr::Symbolic, rules...; fixpoint = false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Expand Down Expand Up @@ -54,6 +56,7 @@ function hasnode(r::Function, y::Union{Num, Symbolic})
end
hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Number, y::Union{Num, Symbolic}) = false

function _hasnode(r, y)
y = unwrap(y)
Expand Down Expand Up @@ -119,7 +122,7 @@ function filterchildren!(r::Any, y, acc)
end

module RewriteHelpers
import Symbolics: replace, hasnode, filterchildren, unwrap
export replace, hasnode, filterchildren, unwrap
import Symbolics: replacenode, hasnode, filterchildren, unwrap
export replacenode, hasnode, filterchildren, unwrap

end
26 changes: 18 additions & 8 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,23 @@ using Test
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check `replace` function.
# Check `replacenode` function.
let
# Simple replacements.
@test isequal(Symbolics.replace(X + X + X, X =>1), 3)
@test isequal(Symbolics.replace(X + X + X, Y => 1), 3X)
@test isequal(Symbolics.replace(X + X + X, X => Y), 3Y)
@test isequal(Symbolics.replace(X + Y^2 - Z, Y^2 => Z), X)
@test isequal(replacenode(X + X + X, X =>1), 3)
@test isequal(replacenode(X + X + X, Y => 1), 3X)
@test isequal(replacenode(X + X + X, X => Y), 3Y)
@test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X)

# When the rule is a function.
rep_func(expr) = Symbolics.is_derivative(expr) ? b : expr
@test isequal(Symbolics.replace(D(X + Y) - log(a*Z), rep_func), b - log(a*Z))
@test isequal(Symbolics.replace(D(Z^2) + D(X + Y) + Z, rep_func), 2b + Z)
@test isequal(Symbolics.replace(X + sin(Y + a) + a, rep_func), X + sin(Y + a) + a)
@test isequal(replacenode(D(X + Y) - log(a*Z), rep_func), b - log(a*Z))
@test isequal(replacenode(D(Z^2) + D(X + Y) + Z, rep_func), 2b + Z)
@test isequal(replacenode(X + sin(Y + a) + a, rep_func), X + sin(Y + a) + a)

# On non-symbolic inputs.
@test isequal(replacenode(1, X =>2.0), 1)
@test isequal(replacenode(1, rep_func(expr)), 1)
end

# Test `hasnode` function.
Expand Down Expand Up @@ -63,6 +67,12 @@ let
@test !hasnode(is_derivative, ex2)
@test !hasnode(is_derivative, ex3)
@test hasnode(is_derivative, ex4)

# On non symbolic inputs:
@test !hasnode(X, 1)
@test !hasnode(a, 1)
@test !hasnode(is_derivative, 1)

end

# Check `filterchildren` function.
Expand Down

0 comments on commit a8e987f

Please sign in to comment.