Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests for rewrite hellper #1026

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/manual/derivatives.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ the differentials down to basic one-variable expressions.
```@docs
Differential
expand_derivatives
is_derivative
```

!!! note
Expand Down
3 changes: 3 additions & 0 deletions docs/src/manual/expression_manipulation.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ Symbolics.diff2term
Symbolics.solve_for
Symbolics.degree
Symbolics.coeff
Symbolics.replace
Symbolics.occursin
Symbolics.filterchildren
```
3 changes: 2 additions & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines
include("wrapper-types.jl")

include("num.jl")
include("rewrite-helpers.jl")
include("complex.jl")

"""
Expand Down Expand Up @@ -96,7 +97,7 @@ using DiffRules, SpecialFunctions, NaNMath

using SparseArrays

export Differential, expand_derivatives
export Differential, expand_derivatives, is_derivative

include("diff.jl")

Expand Down
134 changes: 134 additions & 0 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably move ModelingToolkit.fixedpoint_sub here as well?

replace(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.

Rules will be applied left-to-right simultaneously,
so only one pattern will be applied to any subexpression,
and the patterns will only be applied to the input text,
not the replacements.

Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function Base.replace(expr::Num, r::Pair, rules::Pair...)
_replace(unwrap(expr), r, rules...)
end

# Fix ambiguity
function Base.replace(expr::Num, rules...)
_replace(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, rules...)
_replace(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, r::Pair, rules::Pair...)
_replace(expr, r, rules...)
end

function _replace(expr::Symbolic, rules...; fixpoint=false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Fixpoint(R)(expr)
else
R(expr)
end
end

"""
occursin(c, x)
Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression.
If it is a function, returns true if x is true for any part of x. If c is an expression, returns
true if x contains c.

Examples:
```julia
@syms x y
Symbolics.occursin(x, log(x) + x + 1) # returns `true`.
Symbolics.occursin(x, log(y) + y + 1) # returns `false`.
```

```julia
@variables t X(t)
D = Differential(t)
Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
```
"""
function Base.occursin(r::Function, y::Union{Num, Symbolic})
_occursin(r, y)
end

Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y))
Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y))

function _occursin(r, y)
y = unwrap(y)
if r isa Function
if r(y)
return true
end
end

if istree(y)
return r(operation(y)) ||
any(y->_occursin(r, y), arguments(y))
else
return false
end
end

"""
filterchildren(c, x)
Returns all parts of `x` that fufills the condition given in c. c can be a function or an expression.
If it is a function, returns everything for which the function is `true`. If c is an expression, returns
all expressions that matches it.

Examples:
```julia
@syms x
Symbolics.filterchildren(x, log(x) + x + 1)
```
returns `[x, x]`

```julia
@variables t X(t)
D = Differential(t)
Symbolics.filterchildren(Symbolics.is_derivative, X + D(X) + D(X^2))
```
returns `[Differential(t)(X(t)^2), Differential(t)(X(t))]`
"""
filterchildren(r, y) = filterchildren!(r, y, [])

function filterchildren!(r::Any, y, acc)
y = unwrap(y)
r = unwrap(r)
if isequal(r, y)
push!(acc, y)
return acc
elseif r isa Function
if r(y)
push!(acc, y)
return acc
end
end

if istree(y)
if isequal(r, operation(y))
push!(acc, operation(y))
elseif r isa Function && r(operation(y))
push!(acc, operation(y))
end
foreach(c->filterchildren!(r, c, acc),
arguments(y))
return acc
end
end

module RewriteHelpers
import Symbolics: filterchildren, unwrap
export replace, occursin, filterchildren, unwrap
end
23 changes: 23 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,26 @@ let
@test isequal(expand_derivatives(Differential(t)(im*t)), im)
@test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im)
end


# Check `is_derivative` function
let
@variables t X(t) Y(t)
@syms a b
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Single expressions.
@test !Symbolics.is_derivative(Symbolics.unwrap(D))
@test !Symbolics.is_derivative(Symbolics.unwrap(t))
@test !Symbolics.is_derivative(Symbolics.unwrap(X))
@test !Symbolics.is_derivative(Symbolics.unwrap(a))
@test !Symbolics.is_derivative(Symbolics.unwrap(1))

# Composite expressions.
@test Symbolics.is_derivative(Symbolics.unwrap(D(X)))
@test !Symbolics.is_derivative(Symbolics.unwrap(D(X) + 3))
@test Symbolics.is_derivative(Symbolics.unwrap(D(X + 2a*Y)))
@test !Symbolics.is_derivative(Symbolics.unwrap(D(X) + D(Y)))
@test !Symbolics.is_derivative(Symbolics.unwrap(my_f(X, D(Y))))
end
102 changes: 102 additions & 0 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@

# Fetch packages.
using Symbolics
using Symbolics.RewriteHelpers
using Test

@variables t X(t) Y(t) Z(t)
@syms a b
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check replace function.
let
@test isequal(replace(X + X + X, X =>1), 3)
@test isequal(replace(X + X + X, Y => 1), 3X)
@test isequal(replace(X + X + X, X => Y), 3Y)
@test isequal(replace(X + Y^2 - Z, Y^2 => Z), X)
end

# Test occursin function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
ex3 = sin(X) + sin(Y) + a*a*a*(1-X)
ex4 = exp(a)/(pi*a) + D(Y) + D(my_f(1,Z))
ex5 = a + 5b^2

# Test for variables.
@test occursin(X, ex1)
@test occursin(X, ex2)
@test occursin(X, ex3)
@test !occursin(X, ex4)
@test occursin(Y, ex1)
@test occursin(Y, ex2)
@test occursin(Y, ex3)
@test occursin(Y, ex4)
@test !occursin(Z, ex1)
@test occursin(Z, ex2)
@test !occursin(Z, ex3)
@test occursin(Z, ex4)

# Test for variables.
@test_broken occursin(a, ex1)
@test_broken occursin(a, ex2)
@test_broken occursin(a, ex3)
@test_broken occursin(a, ex4)
@test occursin(a, ex5)
@test_broken occursin(b, ex1)
@test_broken occursin(b, ex2)
@test !occursin(b, ex3)
@test !occursin(b, ex4)
@test occursin(b, ex5)

# Test for function.
@test !occursin(is_derivative, ex1)
@test !occursin(is_derivative, ex2)
@test !occursin(is_derivative, ex3)
@test occursin(is_derivative, ex4)
end

# Check filterchildren function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
ex3 = sin(X) + sin(Y) + a*a*a*(1-X)
ex4 = exp(a)/(pi*a) + D(Y) + D(my_f(1,Z))
ex5 = a + 5b^2

# Test for variables.
@test isequal(filterchildren(X, ex1), [X])
@test isequal(filterchildren(X, ex2), [X])
@test isequal(filterchildren(X, ex3), [X, X])
@test isequal(filterchildren(X, ex4), [])
@test isequal(filterchildren(Y, ex1), [Y, Y])
@test isequal(filterchildren(Y, ex2), [Y])
@test isequal(filterchildren(Y, ex3), [Y])
@test isequal(filterchildren(Y, ex4), [Y])
@test isequal(filterchildren(Z, ex1), [])
@test isequal(filterchildren(Z, ex2), [Z])
@test isequal(filterchildren(Z, ex3), [])
@test isequal(filterchildren(Z, ex4), [Z])

# Test for variables.

@test isequal(filterchildren(a, ex1), [a])
@test isequal(filterchildren(a, ex2), [a])
@test isequal(filterchildren(a, ex3), [a])
@test isequal(filterchildren(a, ex4), [a, a])
@test isequal(filterchildren(a, ex5), [a])
@test isequal(filterchildren(b, ex1), [b])
@test isequal(filterchildren(b, ex2), [b])
@test isequal(filterchildren(b, ex3), [])
@test isequal(filterchildren(b, ex4), [])
@test isequal(filterchildren(b, ex5), [b])

# Test for function.
@test isequal(filterchildren(is_derivative, ex1), [])
@test isequal(filterchildren(is_derivative, ex2), [])
@test isequal(filterchildren(is_derivative, ex3), [])
@test isequal(filterchildren(is_derivative, ex4), [D(Y), D(my_f(1,Z))])
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
@safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end
@safetestset "Rewrite Helper Function Test" begin include("rewrite_helpers.jl") end
VERSION >= v"1.9" && @safetestset "Build Targets Test" begin include("build_targets.jl") end
@safetestset "Latexify Test" begin include("latexify.jl") end
@safetestset "Domain Test" begin include("domains.jl") end
Expand Down
Loading