diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl index f8655247d..4ee8a65ef 100644 --- a/src/rewrite-helpers.jl +++ b/src/rewrite-helpers.jl @@ -35,18 +35,30 @@ function Base.replace(expr::Symbolic, r, rules...) end """ - occursin(x, y) -Checks whether `x` occurs in `y`. Parses `y`, returning `true` on any occurrence of x. + occursin(c, x) +Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression. +If it is a function, returns true if x is true for any part of x. If c is an expression, returns +true if x contains c. -Example: +Examples: ```julia -@syms x y z -occursin(x, y + z*(3+x)) # returns `true`. -occursin(x, log(y) + 10*(z-y) # returns `false`. +@syms x y +Symbolics.occursin(x, log(x) + x + 1) # returns `true`. +Symbolics.occursin(x, log(y) + y + 1) # returns `false`. +``` + +```julia +@variables t X(t) +D = Differential(t) +Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. ``` """ Base.occursin(x::Num, y::Num) = occursin(unwrap(x), unwrap(y)) @wrapped function Base.occursin(r::Any, y::Real) + _occursin(r, y) +end + +function _occursin(r, y) y = unwrap(y) if isequal(r, y) return true @@ -58,14 +70,15 @@ Base.occursin(x::Num, y::Num) = occursin(unwrap(x), unwrap(y)) if istree(y) return r(operation(y)) || - any(y->occursin(r, y), arguments(y)) + any(y->_occursin(r, y), arguments(y)) else return false end end -function filterchildren!(r::Any, y::Union{Num, Symbolic}, acc) +function filterchildren!(r::Any, y, acc) y = unwrap(y) + r = unwrap(r) if isequal(r, y) push!(acc, y) return acc @@ -77,8 +90,10 @@ function filterchildren!(r::Any, y::Union{Num, Symbolic}, acc) end if istree(y) - if r(operation(y)) - push!(acc, y) + if isequal(r, operation(y)) + push!(acc, operation(y)) + elseif r isa Function && r(operation(y)) + push!(acc, operation(y)) end foreach(c->filterchildren!(r, c, acc), arguments(y)) @@ -86,6 +101,26 @@ function filterchildren!(r::Any, y::Union{Num, Symbolic}, acc) end end +""" +filterchildren(c, x) +Returns all parts of `x` that fufills the condition given in c. c can be a function or an expression. +If it is a function, returns everything for which the function is `true`. If c is an expression, returns +all expressions that matches it. + +Examples: +```julia +@syms x +Symbolics.filterchildren(x, log(x) + x + 1) +``` +returns `[x, x]` + +```julia +@variables t X(t) +D = Differential(t) +Symbolics.filterchildren(Symbolics.is_derivative, X + D(X) + D(X^2)) +``` +returns `[Differential(t)(X(t)^2), Differential(t)(X(t))]` +""" filterchildren(r, y) = filterchildren!(r, y, []) module RewriteHelpers