Skip to content

Commit

Permalink
replacenode -> replace
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jun 3, 2024
1 parent 2602483 commit 0197338
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
18 changes: 9 additions & 9 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
replacenode(expr::Symbolic, rules...)
replace(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.
Expand All @@ -12,15 +12,15 @@ not the replacements.
Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function replacenode(expr::Num, r::Pair, rules::Pair...)
_replacenode(unwrap(expr), r, rules...)
function replace(expr::Num, r::Pair, rules::Pair...; fixpoint = false)
_replace(unwrap(expr), r, rules...)
end
# Fix ambiguity
replacenode(expr::Num, rules...) = _replacenode(unwrap(expr), rules...)
replacenode(expr::Symbolic, rules...) = _replacenode(unwrap(expr), rules...)
replacenode(expr::Symbolic, r::Pair, rules::Pair...) = _replacenode(expr, r, rules...)
replace(expr::Num, rules...; fixpoint = false) = _replace(unwrap(expr), rules...; fixpoint)
replace(expr::Symbolic, rules...; fixpoint = false) = _replace(unwrap(expr), rules...; fixpoint)
replace(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replace(expr, r, rules...; fixpoint)

function _replacenode(expr::Symbolic, rules...; fixpoint=false)
function _replace(expr::Symbolic, rules...; fixpoint = false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Expand Down Expand Up @@ -119,7 +119,7 @@ function filterchildren!(r::Any, y, acc)
end

module RewriteHelpers
import Symbolics: replacenode, hasnode, filterchildren, unwrap
export replacenode, hasnode, filterchildren, unwrap
import Symbolics: replace, hasnode, filterchildren, unwrap
export replace, hasnode, filterchildren, unwrap

end
14 changes: 7 additions & 7 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ using Test
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check replacenode function.
# Check `replace` function.
let
@test isequal(replacenode(X + X + X, X =>1), 3)
@test isequal(replacenode(X + X + X, Y => 1), 3X)
@test isequal(replacenode(X + X + X, X => Y), 3Y)
@test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X)
@test isequal(replace(X + X + X, X =>1), 3)
@test isequal(replace(X + X + X, Y => 1), 3X)
@test isequal(replace(X + X + X, X => Y), 3Y)
@test isequal(replace(X + Y^2 - Z, Y^2 => Z), X)
end

# Test hasnode function.
# Test `hasnode` function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand Down Expand Up @@ -58,7 +58,7 @@ let
@test hasnode(is_derivative, ex4)
end

# Check filterchildren function.
# Check `filterchildren` function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand Down

0 comments on commit 0197338

Please sign in to comment.