Skip to content

Commit

Permalink
Merge pull request #1026 from TorkelE/tests_for_rewrite_hellper
Browse files Browse the repository at this point in the history
Tests for rewrite hellper
  • Loading branch information
ChrisRackauckas authored Mar 18, 2024
2 parents c657268 + 4d37631 commit df632fd
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/manual/derivatives.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ the differentials down to basic one-variable expressions.
```@docs
Differential
expand_derivatives
is_derivative
```

!!! note
Expand Down
3 changes: 3 additions & 0 deletions docs/src/manual/expression_manipulation.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ Symbolics.diff2term
Symbolics.solve_for
Symbolics.degree
Symbolics.coeff
Symbolics.replace
Symbolics.occursin
Symbolics.filterchildren
```
3 changes: 2 additions & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines
include("wrapper-types.jl")

include("num.jl")
include("rewrite-helpers.jl")
include("complex.jl")

"""
Expand Down Expand Up @@ -96,7 +97,7 @@ using DiffRules, SpecialFunctions, NaNMath

using SparseArrays

export Differential, expand_derivatives
export Differential, expand_derivatives, is_derivative

include("diff.jl")

Expand Down
134 changes: 134 additions & 0 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
replace(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.
Rules will be applied left-to-right simultaneously,
so only one pattern will be applied to any subexpression,
and the patterns will only be applied to the input text,
not the replacements.
Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function Base.replace(expr::Num, r::Pair, rules::Pair...)
_replace(unwrap(expr), r, rules...)
end

# Fix ambiguity
function Base.replace(expr::Num, rules...)
_replace(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, rules...)
_replace(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, r::Pair, rules::Pair...)
_replace(expr, r, rules...)
end

function _replace(expr::Symbolic, rules...; fixpoint=false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Fixpoint(R)(expr)
else
R(expr)
end
end

"""
occursin(c, x)
Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression.
If it is a function, returns true if x is true for any part of x. If c is an expression, returns
true if x contains c.
Examples:
```julia
@syms x y
Symbolics.occursin(x, log(x) + x + 1) # returns `true`.
Symbolics.occursin(x, log(y) + y + 1) # returns `false`.
```
```julia
@variables t X(t)
D = Differential(t)
Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
```
"""
function Base.occursin(r::Function, y::Union{Num, Symbolic})
_occursin(r, y)
end

Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y))
Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y))

function _occursin(r, y)
y = unwrap(y)
if r isa Function
if r(y)
return true
end
end

if istree(y)
return r(operation(y)) ||
any(y->_occursin(r, y), arguments(y))
else
return false
end
end

"""
filterchildren(c, x)
Returns all parts of `x` that fufills the condition given in c. c can be a function or an expression.
If it is a function, returns everything for which the function is `true`. If c is an expression, returns
all expressions that matches it.
Examples:
```julia
@syms x
Symbolics.filterchildren(x, log(x) + x + 1)
```
returns `[x, x]`
```julia
@variables t X(t)
D = Differential(t)
Symbolics.filterchildren(Symbolics.is_derivative, X + D(X) + D(X^2))
```
returns `[Differential(t)(X(t)^2), Differential(t)(X(t))]`
"""
filterchildren(r, y) = filterchildren!(r, y, [])

function filterchildren!(r::Any, y, acc)
y = unwrap(y)
r = unwrap(r)
if isequal(r, y)
push!(acc, y)
return acc
elseif r isa Function
if r(y)
push!(acc, y)
return acc
end
end

if istree(y)
if isequal(r, operation(y))
push!(acc, operation(y))
elseif r isa Function && r(operation(y))
push!(acc, operation(y))
end
foreach(c->filterchildren!(r, c, acc),
arguments(y))
return acc
end
end

module RewriteHelpers
import Symbolics: filterchildren, unwrap
export replace, occursin, filterchildren, unwrap
end
23 changes: 23 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,26 @@ let
@test isequal(expand_derivatives(Differential(t)(im*t)), im)
@test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im)
end


# Check `is_derivative` function
let
@variables t X(t) Y(t)
@syms a b
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Single expressions.
@test !Symbolics.is_derivative(Symbolics.unwrap(D))
@test !Symbolics.is_derivative(Symbolics.unwrap(t))
@test !Symbolics.is_derivative(Symbolics.unwrap(X))
@test !Symbolics.is_derivative(Symbolics.unwrap(a))
@test !Symbolics.is_derivative(Symbolics.unwrap(1))

# Composite expressions.
@test Symbolics.is_derivative(Symbolics.unwrap(D(X)))
@test !Symbolics.is_derivative(Symbolics.unwrap(D(X) + 3))
@test Symbolics.is_derivative(Symbolics.unwrap(D(X + 2a*Y)))
@test !Symbolics.is_derivative(Symbolics.unwrap(D(X) + D(Y)))
@test !Symbolics.is_derivative(Symbolics.unwrap(my_f(X, D(Y))))
end
102 changes: 102 additions & 0 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@

# Fetch packages.
using Symbolics
using Symbolics.RewriteHelpers
using Test

@variables t X(t) Y(t) Z(t)
@syms a b
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check replace function.
let
@test isequal(replace(X + X + X, X =>1), 3)
@test isequal(replace(X + X + X, Y => 1), 3X)
@test isequal(replace(X + X + X, X => Y), 3Y)
@test isequal(replace(X + Y^2 - Z, Y^2 => Z), X)
end

# Test occursin function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
ex3 = sin(X) + sin(Y) + a*a*a*(1-X)
ex4 = exp(a)/(pi*a) + D(Y) + D(my_f(1,Z))
ex5 = a + 5b^2

# Test for variables.
@test occursin(X, ex1)
@test occursin(X, ex2)
@test occursin(X, ex3)
@test !occursin(X, ex4)
@test occursin(Y, ex1)
@test occursin(Y, ex2)
@test occursin(Y, ex3)
@test occursin(Y, ex4)
@test !occursin(Z, ex1)
@test occursin(Z, ex2)
@test !occursin(Z, ex3)
@test occursin(Z, ex4)

# Test for variables.
@test_broken occursin(a, ex1)
@test_broken occursin(a, ex2)
@test_broken occursin(a, ex3)
@test_broken occursin(a, ex4)
@test occursin(a, ex5)
@test_broken occursin(b, ex1)
@test_broken occursin(b, ex2)
@test !occursin(b, ex3)
@test !occursin(b, ex4)
@test occursin(b, ex5)

# Test for function.
@test !occursin(is_derivative, ex1)
@test !occursin(is_derivative, ex2)
@test !occursin(is_derivative, ex3)
@test occursin(is_derivative, ex4)
end

# Check filterchildren function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
ex3 = sin(X) + sin(Y) + a*a*a*(1-X)
ex4 = exp(a)/(pi*a) + D(Y) + D(my_f(1,Z))
ex5 = a + 5b^2

# Test for variables.
@test isequal(filterchildren(X, ex1), [X])
@test isequal(filterchildren(X, ex2), [X])
@test isequal(filterchildren(X, ex3), [X, X])
@test isequal(filterchildren(X, ex4), [])
@test isequal(filterchildren(Y, ex1), [Y, Y])
@test isequal(filterchildren(Y, ex2), [Y])
@test isequal(filterchildren(Y, ex3), [Y])
@test isequal(filterchildren(Y, ex4), [Y])
@test isequal(filterchildren(Z, ex1), [])
@test isequal(filterchildren(Z, ex2), [Z])
@test isequal(filterchildren(Z, ex3), [])
@test isequal(filterchildren(Z, ex4), [Z])

# Test for variables.

@test isequal(filterchildren(a, ex1), [a])
@test isequal(filterchildren(a, ex2), [a])
@test isequal(filterchildren(a, ex3), [a])
@test isequal(filterchildren(a, ex4), [a, a])
@test isequal(filterchildren(a, ex5), [a])
@test isequal(filterchildren(b, ex1), [b])
@test isequal(filterchildren(b, ex2), [b])
@test isequal(filterchildren(b, ex3), [])
@test isequal(filterchildren(b, ex4), [])
@test isequal(filterchildren(b, ex5), [b])

# Test for function.
@test isequal(filterchildren(is_derivative, ex1), [])
@test isequal(filterchildren(is_derivative, ex2), [])
@test isequal(filterchildren(is_derivative, ex3), [])
@test isequal(filterchildren(is_derivative, ex4), [D(Y), D(my_f(1,Z))])
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
@safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end
@safetestset "Rewrite Helper Function Test" begin include("rewrite_helpers.jl") end
VERSION >= v"1.9" && @safetestset "Build Targets Test" begin include("build_targets.jl") end
@safetestset "Latexify Test" begin include("latexify.jl") end
@safetestset "Domain Test" begin include("domains.jl") end
Expand Down

0 comments on commit df632fd

Please sign in to comment.