diff --git a/Project.toml b/Project.toml index d34145063..b2b24b829 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ SpecialFunctions = "2" StaticArrays = "1.1" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.0" -SymbolicUtils = "1.4" +SymbolicUtils = "1.7" julia = "1.10" [extras] diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index 24f9b95cc..d1a75be20 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -20,7 +20,7 @@ using Symbolics ``` After defining variables as symbolic, symbolic expressions, which we call a -`istree` object, can be generated by utilizing Julia expressions. For example: +`iscall` object, can be generated by utilizing Julia expressions. For example: ```@example symbolic_basics z = x^2 + y @@ -35,7 +35,7 @@ A = [x^2 + y 0 2x y^2 + x 0 0] ``` -Note that by default, `@variables` returns `Sym` or `istree` objects wrapped in +Note that by default, `@variables` returns `Sym` or `iscall` objects wrapped in `Num` to make them behave like subtypes of `Real`. Any operation on these `Num` objects will return a new `Num` object, wrapping the result of computing symbolically on the underlying values. diff --git a/docs/src/manual/variables.md b/docs/src/manual/variables.md index ee05c74d4..97d9b7266 100644 --- a/docs/src/manual/variables.md +++ b/docs/src/manual/variables.md @@ -3,7 +3,7 @@ Symbolics IR mirrors the Julia AST but allows for easy mathematical manipulation by itself following mathematical semantics. The base of the IR is the `Sym` type, which defines a symbolic variable. Registered (mathematical) -functions on `Sym`s (or `istree` objects) return an expression that `istree`. +functions on `Sym`s (or `iscall` objects) return an expression that `iscall`. For example, `op1 = x+y` is one symbolic object and `op2 = 2z` is another, and so `op1*op2` is another tree object. Then, at the top, an `Equation`, normally written as `op1 ~ op2`, defines the symbolic equality between two operations. @@ -13,10 +13,10 @@ written as `op1 ~ op2`, defines the symbolic equality between two operations. `Sym`, `Term`, and `FnType` are from [SymbolicUtils.jl](https://symbolicutils.juliasymbolics.org/api/). Note that in Symbolics, we always use `Sym{Real}`, `Term{Real}`, and -`FnType{Tuple{Any}, Real}`. To get the arguments of an `istree` object, use +`FnType{Tuple{Any}, Real}`. To get the arguments of an `iscall` object, use `arguments(t::Term)`, and to get the operation, use `operation(t::Term)`. However, note that one should never dispatch on `Term` or test `isa Term`. -Instead, one needs to use `SymbolicUtils.istree` to check if `arguments` and +Instead, one needs to use `SymbolicUtils.iscall` to check if `arguments` and `operation` is defined. ```@docs @@ -80,7 +80,7 @@ Control flow can be expressed in Symbolics.jl in the following ways: ## Inspection Functions ```@docs -SymbolicUtils.istree +SymbolicUtils.iscall SymbolicUtils.operation SymbolicUtils.arguments ``` diff --git a/ext/SymbolicsSymPyExt.jl b/ext/SymbolicsSymPyExt.jl index 630cd3ce7..9a92e06d7 100644 --- a/ext/SymbolicsSymPyExt.jl +++ b/ext/SymbolicsSymPyExt.jl @@ -9,13 +9,13 @@ end using Symbolics: value using Symbolics.SymbolicUtils -using SymbolicUtils: istree, operation, arguments, symtype, +using SymbolicUtils: iscall, operation, arguments, symtype, FnType, Symbolic function Symbolics.symbolics_to_sympy(expr) expr = value(expr) expr isa Symbolic || return expr - if istree(expr) + if iscall(expr) sop = symbolics_to_sympy(operation(expr)) sargs = map(symbolics_to_sympy, arguments(expr)) if sop === (^) && length(sargs) == 2 && sargs[2] isa Number diff --git a/src/Symbolics.jl b/src/Symbolics.jl index d316fe608..f1c4f6ee8 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -19,7 +19,7 @@ using PrecompileTools import DomainSets: Domain - import SymbolicUtils: similarterm, istree, operation, arguments, symtype, metadata + import SymbolicUtils: similarterm, iscall, operation, arguments, symtype, metadata import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic, FnType, @rule, Rewriters, substitute, diff --git a/src/array-lib.jl b/src/array-lib.jl index cc5923032..242d8eff8 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -20,7 +20,7 @@ end function Base.getindex(x::SymArray, idx...) idx = unwrap.(idx) meta = metadata(unwrap(x)) - if istree(x) && (op = operation(x)) isa Operator + if iscall(x) && (op = operation(x)) isa Operator args = arguments(x) return op(only(args)[idx...]) elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx) @@ -111,7 +111,7 @@ end import Base: +, -, * tup(c::CartesianIndex) = Tuple(c) -tup(c::Symbolic{CartesianIndex}) = istree(c) ? arguments(c) : error("Cartesian index not found") +tup(c::Symbolic{CartesianIndex}) = iscall(c) ? arguments(c) : error("Cartesian index not found") @wrapped function -(x::CartesianIndex, y::CartesianIndex) CartesianIndex((tup(x) .- tup(y))...) @@ -251,7 +251,7 @@ isadjointvec(A::Adjoint) = ndims(parent(A)) == 1 isadjointvec(A::Transpose) = ndims(parent(A)) == 1 function isadjointvec(A) - if istree(A) + if iscall(A) (operation(A) === (adjoint) || operation(A) == (transpose)) && ndims(arguments(A)[1]) == 1 else @@ -305,7 +305,7 @@ function _matvec(A, b) end @wrapped (*)(A::AbstractMatrix, b::AbstractVector) = _matvec(A, b) -# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for +# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for # arrays of (possibly unwrapped) Symbolic types, see issue #831 @wrapped LinearAlgebra.dot(x::Number, y::Number) = conj(x) * y diff --git a/src/arrays.jl b/src/arrays.jl index 3b2f069f2..242cb05fe 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -94,7 +94,7 @@ shape(aop::ArrayOp) = aop.shape const show_arrayop = Ref{Bool}(false) function Base.show(io::IO, aop::ArrayOp) - if istree(aop.term) && !show_arrayop[] + if iscall(aop.term) && !show_arrayop[] show(io, aop.term) else print(io, "@arrayop") @@ -117,7 +117,7 @@ function Base.showarg(io::IO, aop::ArrayOp, toplevel) end symtype(a::ArrayOp{T}) where {T} = T -istree(a::ArrayOp) = true +iscall(a::ArrayOp) = true function operation(a::ArrayOp) isnothing(a.term) ? typeof(a) : operation(a.term) end @@ -332,7 +332,7 @@ function get_extents(xs) if all(iszero∘wrap, boundaries) get(first(xs)) else - ii = findfirst(x->issym(x) || istree(x), boundaries) + ii = findfirst(x->issym(x) || iscall(x), boundaries) if !isnothing(ii) error("Could not find the boundary from symbolic index $(xs[ii]). Please manually specify the range of indices.") end @@ -355,11 +355,11 @@ get_extents(x::AbstractRange) = x # boundary: how much padding is this indexing requiring, for example # boundary is 2 for x[i + 2], and boundary = -2 for x[i - 2] function idx_to_axes(expr, dict=Dict{Any, Vector}(), ranges=Dict()) - if istree(expr) + if iscall(expr) if operation(expr) === (getindex) args = arguments(expr) for (axis, idx_expr) in enumerate(@views args[2:end]) - if issym(idx_expr) || istree(idx_expr) + if issym(idx_expr) || iscall(idx_expr) vs = get_variables(idx_expr) isempty(vs) && continue sym = only(get_variables(idx_expr)) @@ -529,9 +529,9 @@ wrapper_type(::Type{<:AbstractVector{T}}) where {T} = Arr{maybewrap(T), 1} function Base.show(io::IO, arr::Arr) x = unwrap(arr) - istree(x) && print(io, "(") + iscall(x) && print(io, "(") print(io, unwrap(arr)) - istree(x) && print(io, ")") + iscall(x) && print(io, ")") if !(shape(x) isa Unknown) print(io, "[", join(string.(axes(arr)), ","), "]") end @@ -618,7 +618,7 @@ function replace_by_scalarizing(ex, dict) end function rewrite_operation(x) - if istree(x) && istree(operation(x)) + if iscall(x) && iscall(operation(x)) f = operation(x) ff = replace_by_scalarizing(f, dict) if metadata(x) !== nothing @@ -638,7 +638,7 @@ end function prewalk_if(cond, f, t, similarterm) t′ = cond(t) ? f(t) : return t - if istree(t′) + if iscall(t′) return similarterm(t′, operation(t′), map(x->prewalk_if(cond, f, x, similarterm), arguments(t′))) else @@ -652,7 +652,7 @@ function scalarize(arr::AbstractArray, idx) end function scalarize(arr, idx) - if istree(arr) + if iscall(arr) scalarize_op(operation(arr), arr, idx) else error("scalarize is not defined for $arr at idx=$idx") @@ -761,20 +761,20 @@ eval_array_term(op) = eval_array_term(operation(op), op) function scalarize(arr) if arr isa Arr || arr isa Symbolic{<:AbstractArray} - if istree(arr) + if iscall(arr) arr = eval_array_term(arr) end map(Iterators.product(axes(arr)...)) do i scalarize(arr[i...]) # Use arr[i...] here to trigger any getindex hooks end - elseif istree(arr) && operation(arr) == getindex + elseif iscall(arr) && operation(arr) == getindex args = arguments(arr) scalarize(args[1], (args[2:end]...,)) elseif arr isa Num wrap(scalarize(unwrap(arr))) - elseif istree(arr) && symtype(arr) <: Number + elseif iscall(arr) && symtype(arr) <: Number t = similarterm(arr, operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata=metadata(arr)) - istree(t) ? scalarize_op(operation(t), t) : t + iscall(t) ? scalarize_op(operation(t), t) : t else arr end @@ -804,7 +804,7 @@ function arraymaker(T, shape, views, seq...) ArrayMaker{T}(shape, [(views .=> seq)...], nothing) end -istree(x::ArrayMaker) = true +iscall(x::ArrayMaker) = true operation(x::ArrayMaker) = arraymaker arguments(x::ArrayMaker) = [eltype(x), shape(x), map(first, x.sequence), map(last, x.sequence)...] @@ -965,7 +965,7 @@ end function scalarize(x::ArrayMaker, idx) for (vw, arr) in reverse(x.sequence) # last one wins - if any(x->issym(x) || istree(x), idx) + if any(x->issym(x) || iscall(x), idx) return term(getindex, x, idx...) end if all(in.(idx, vw)) @@ -979,7 +979,7 @@ function scalarize(x::ArrayMaker, idx) end end end - if !any(x->issym(x) || istree(x), idx) && all(in.(idx, axes(x))) + if !any(x->issym(x) || iscall(x), idx) && all(in.(idx, axes(x))) throw(UndefRefError()) end @@ -992,7 +992,7 @@ end function SymbolicUtils.Code.toexpr(x::ArrayOp, st) haskey(st.symbolify, x) && return st.symbolify[x] - if istree(x.term) + if iscall(x.term) toexpr(x.term, st) else _array_toexpr(x, st) @@ -1048,7 +1048,7 @@ end function inplace_builtin(term, outsym) isarr(n) = x->symtype(x) <: AbstractArray{<:Any, n} - if istree(term) && operation(term) == (*) && length(arguments(term)) == 2 + if iscall(term) && operation(term) == (*) && length(arguments(term)) == 2 A, B = arguments(term) isarr(2)(A) && (isarr(1)(B) || isarr(2)(B)) && return :($mul!($outsym, $A, $B)) end @@ -1058,7 +1058,7 @@ end function find_inter(acc, expr) if !issym(expr) && symtype(expr) <: AbstractArray push!(acc, expr) - elseif istree(expr) + elseif iscall(expr) foreach(x -> find_inter(acc, x), arguments(expr)) end acc diff --git a/src/build_function.jl b/src/build_function.jl index 4c0b4db8d..b0fd5f288 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -160,7 +160,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; N = length(shape(op)) op = unwrap(op) - if op isa ArrayOp && istree(op.term) + if op isa ArrayOp && iscall(op.term) op_body = op.term else op_body = :(let $outsym = zeros(Float64, map(length, ($(shape(op)...),))) @@ -221,7 +221,7 @@ Build function target: `JuliaTarget` ```julia function _build_function(target::JuliaTarget, rhss, args...; - conv = toexpr, + conv = toexpr, expression = Val{true}, checkbounds = false, linenumbers = false, @@ -584,13 +584,13 @@ function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1], states = LazyState(), lhsname=:du,rhsnames=[Symbol("MTK$i") for i in 1:length(args)]) O = value(O) - if (issym(O) || issym(operation(O))) || (istree(O) && operation(O) == getindex) + if (issym(O) || issym(operation(O))) || (iscall(O) && operation(O) == getindex) (j,i) = get(varnumbercache, O, (nothing, nothing)) if !isnothing(j) return i==0 ? :($(rhsnames[j])) : :($(rhsnames[j])[$(i+offset)]) end end - if istree(O) + if iscall(O) if operation(O) === getindex args = arguments(O) Expr(:ref, toexpr(args[1], states), toexpr.(args[2:end] .+ offset, (states,))...) diff --git a/src/complex.jl b/src/complex.jl index d1e7b4769..27124d3b8 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -18,10 +18,10 @@ function wrapper_type(::Type{Complex{T}}) where T end symtype(a::ComplexTerm{T}) where T = Complex{T} -istree(a::ComplexTerm) = true +iscall(a::ComplexTerm) = true operation(a::ComplexTerm{T}) where T = Complex{T} arguments(a::ComplexTerm) = [a.re, a.im] -metadata(a::ComplexTerm) = a.re.metadata +metadata(a::ComplexTerm) = metadata(a.re) function similarterm(t::ComplexTerm, f, args, symtype; metadata=nothing) if f <: Complex @@ -41,8 +41,8 @@ function Base.show(io::IO, a::Complex{Num}) rr = unwrap(real(a)) ii = unwrap(imag(a)) - if istree(rr) && (operation(rr) === real) && - istree(ii) && (operation(ii) === imag) && + if iscall(rr) && (operation(rr) === real) && + iscall(ii) && (operation(ii) === imag) && isequal(arguments(rr)[1], arguments(ii)[1]) return print(io, arguments(rr)[1]) diff --git a/src/diff.jl b/src/diff.jl index 7e91d19d8..aeb2cc78e 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -46,7 +46,7 @@ end (D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) SymbolicUtils.promote_symtype(::Differential, T) = T -is_derivative(x) = istree(x) ? operation(x) isa Differential : false +is_derivative(x) = iscall(x) ? operation(x) isa Differential : false Base.:*(D1, D2::Differential) = D1 ∘ D2 Base.:*(D1::Differential, D2) = D1 ∘ D2 @@ -59,7 +59,7 @@ Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x) Base.hash(D::Differential, u::UInt) = hash(D.x, xor(u, 0xdddddddddddddddd)) _isfalse(occ::Bool) = occ === false -_isfalse(occ::Symbolic) = istree(occ) && _isfalse(operation(occ)) +_isfalse(occ::Symbolic) = iscall(occ) && _isfalse(operation(occ)) function occursin_info(x, expr, fail = true) if symtype(expr) <: AbstractArray @@ -72,8 +72,8 @@ function occursin_info(x, expr, fail = true) # Allow scalarized expressions function is_scalar_indexed(ex) - (istree(ex) && operation(ex) == getindex && !(symtype(ex) <: AbstractArray)) || - (istree(ex) && (issym(operation(ex)) || istree(operation(ex))) && + (iscall(ex) && operation(ex) == getindex && !(symtype(ex) <: AbstractArray)) || + (iscall(ex) && (issym(operation(ex)) || iscall(operation(ex))) && is_scalar_indexed(operation(ex))) end @@ -93,7 +93,7 @@ function occursin_info(x, expr, fail = true) return false end - !istree(expr) && return isequal(x, expr) + !iscall(expr) && return isequal(x, expr) if isequal(x, expr) true else @@ -132,7 +132,7 @@ An internal function that contains the logic for [`hasderiv`](@ref) and [`hasdif Return true if `O` contains a term with `Operator` `op`. """ function recursive_hasoperator(op, O) - istree(O) || return false + iscall(O) || return false if operation(O) isa op return true else @@ -154,14 +154,14 @@ $(SIGNATURES) Expands derivatives within a symbolic expression `O`. This function recursively traverses a symbolic expression, applying the chain rule -and other derivative rules to expand any derivatives it encounters. +and other derivative rules to expand any derivatives it encounters. # Arguments - `O::Symbolic`: The symbolic expression to expand. -- `simplify::Bool=false`: Whether to simplify the resulting expression using +- `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). -- `occurrences=nothing`: Information about the occurrences of the independent - variable in the argument of the derivative. This is used internally for +- `occurrences=nothing`: Information about the occurrences of the independent + variable in the argument of the derivative. This is used internally for optimization purposes. # Examples @@ -179,7 +179,7 @@ julia> dfx=expand_derivatives(Dx(f)) ``` """ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) - if istree(O) && isa(operation(O), Differential) + if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) arg = expand_derivatives(arg, false) @@ -192,7 +192,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) D = operation(O) - if !istree(arg) + if !iscall(arg) return D(arg) # Cannot expand elseif (op = operation(arg); issym(op)) inner_args = arguments(arg) @@ -218,7 +218,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) else inner = expand_derivatives(D(arguments(arg)[1]), false) # if the inner expression is not expandable either, return - if istree(inner) && operation(inner) isa Differential + if iscall(inner) && operation(inner) isa Differential return D(arg) else return expand_derivatives(op(inner), simplify) @@ -230,12 +230,12 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) a, b = DomainSets.endpoints(domain) c = 0 inner_function = expand_derivatives(arguments(arg)[1]) - if istree(value(a)) + if iscall(value(a)) t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) t2 = D(a) c -= t1*t2 end - if istree(value(b)) + if iscall(value(b)) t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b))) t2 = D(b) c += t1*t2 @@ -283,7 +283,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing) x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...) return simplify ? SymbolicUtils.simplify(x) : x end - elseif istree(O) && isa(operation(O), Integral) + elseif iscall(O) && isa(operation(O), Integral) return operation(O)(expand_derivatives(arguments(O)[1])) elseif !hasderiv(O) return O @@ -342,7 +342,7 @@ sin(x) """ derivative_idx(O::Any, ::Any) = 0 function derivative_idx(O::Symbolic, idx) - istree(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0 + iscall(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0 end # Indicate that no derivative is defined. diff --git a/src/domains.jl b/src/domains.jl index 965213c38..02d23f3bb 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -17,10 +17,10 @@ Base.:∈(variable::DomainedVar,domain::NTuple{2,Real}) = VarDomainPairing(varia # Multiple variables Base.:∈(variables::NTuple{N,DomainedVar},domain::Domain) where N = VarDomainPairing(value.(variables),domain) -function infimum(d::AbstractInterval{T}) where T <: Num +function infimum(d::AbstractInterval{<:Num}) leftendpoint(d) end -function supremum(d::AbstractInterval{T}) where T <: Num +function supremum(d::AbstractInterval{<:Num}) rightendpoint(d) end diff --git a/src/extra_functions.jl b/src/extra_functions.jl index 040f14f72..53fc0d2f7 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -4,7 +4,7 @@ function _binomial(nothing, n, k) args = [n, k] unwrapped_args = map(Symbolics.unwrap, args) res = if !(any((x->begin - SymbolicUtils.issym(x) || SymbolicUtils.istree(x) + SymbolicUtils.issym(x) || SymbolicUtils.iscall(x) end), unwrapped_args)) Base.binomial(unwrapped_args...) else diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 8fb2b578b..3174dff8f 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -3,7 +3,7 @@ prettify_expr(f::Function) = nameof(f) prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...) function cleanup_exprs(ex) - return postwalk(x -> istree(x) && length(arguments(x)) == 0 ? operation(x) : x, ex) + return postwalk(x -> iscall(x) && length(arguments(x)) == 0 ? operation(x) : x, ex) end function latexify_derivatives(ex) @@ -58,7 +58,7 @@ end @latexrecipe function f(z::Complex{Num}) env --> :equation cdot --> false - + iszero(z.im) && return :($(recipe(z.re))) iszero(z.re) && return :($(recipe(z.im)) * $im) return :($(recipe(z.re)) + $(recipe(z.im)) * $im) @@ -142,7 +142,7 @@ function _toexpr(O) base = term.base pow = term.exp - isneg = (pow isa Number && pow < 0) || (istree(pow) && operation(pow) === (-) && length(arguments(pow)) == 1) + isneg = (pow isa Number && pow < 0) || (iscall(pow) && operation(pow) === (-) && length(arguments(pow)) == 1) if !isneg if _isone(pow) pushfirst!(numer, _toexpr(base)) @@ -179,7 +179,7 @@ function _toexpr(O) end end issym(O) && return nameof(O) - !istree(O) && return O + !iscall(O) && return O op = operation(O) args = arguments(O) @@ -233,7 +233,7 @@ _toexpr(eqs::AbstractArray) = map(eq->_toexpr(eq), eqs) _toexpr(x::Num) = _toexpr(value(x)) function getindex_to_symbol(t) - @assert istree(t) && operation(t) === getindex && symtype(arguments(t)[1]) <: AbstractArray + @assert iscall(t) && operation(t) === getindex && symtype(arguments(t)[1]) <: AbstractArray args = arguments(t) idxs = args[2:end] try @@ -258,4 +258,3 @@ function diffdenom(e) e end end - diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index f13bc8854..a3d9a706f 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,5 +1,5 @@ function nterms(t) - if istree(t) + if iscall(t) return reduce(+, map(nterms, arguments(t)), init=0) else return 1 @@ -258,7 +258,7 @@ function _linear_expansion(t::Equation, x) end trivial_linear_expansion(t, x) = isequal(t, x) ? (1, 0, true) : (0, t, true) -is_expansion_leaf(t) = !istree(t) || (operation(t) isa Operator) +is_expansion_leaf(t) = !iscall(t) || (operation(t) isa Operator) @noinline expansion_check(op) = op isa Operator && error("The operation is an Operator. This should never happen.") function _linear_expansion(t, x) t = value(t) diff --git a/src/register.jl b/src/register.jl index 51c4152ac..05ddf0d7e 100644 --- a/src/register.jl +++ b/src/register.jl @@ -32,7 +32,7 @@ macro register_symbolic(expr, define_promotion = true, Ts = :([])) fexpr = :(Symbolics.@wrapped function $f($(args′...)) args = [$(argnames...),] unwrapped_args = map($unwrap, args) - res = if !any(x->$issym(x) || $istree(x), unwrapped_args) + res = if !any(x->$issym(x) || $iscall(x), unwrapped_args) $f(unwrapped_args...) # partial-eval if all args are unwrapped else $Term{$ret_type}($f, unwrapped_args) @@ -94,7 +94,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs @wrapped function $f($(args′...)) args = [$(argnames...),] unwrapped_args = map($unwrap, args) - res = if !any(x->$issym(x) || $istree(x), unwrapped_args) + res = if !any(x->$issym(x) || $iscall(x), unwrapped_args) $f(unwrapped_args...) # partial-eval if all args are unwrapped elseif $ret_type == nothing || ($ret_type <: AbstractArray) $array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...) diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl index b2a7d0249..cfb917d50 100644 --- a/src/rewrite-helpers.jl +++ b/src/rewrite-helpers.jl @@ -73,7 +73,7 @@ function _occursin(r, y) end end - if istree(y) + if iscall(y) return r(operation(y)) || any(y->_occursin(r, y), arguments(y)) else @@ -116,7 +116,7 @@ function filterchildren!(r::Any, y, acc) end end - if istree(y) + if iscall(y) if isequal(r, operation(y)) push!(acc, operation(y)) elseif r isa Function && r(operation(y)) diff --git a/src/semipoly.jl b/src/semipoly.jl index 2593ae4ae..46b635ba3 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -23,7 +23,7 @@ function Base.:+(a::SemiMonomial, b::SemiMonomial) Term(+, [a, b]) end function Base.:+(m::SemiMonomial, t) - if istree(t) && operation(t) == (+) + if iscall(t) && operation(t) == (+) return Term(+, [unsorted_arguments(t); m]) end Term(+, [m, t]) @@ -36,7 +36,7 @@ function Base.:*(a::SemiMonomial, b::SemiMonomial) end Base.:*(m::SemiMonomial, n::Number) = SemiMonomial(m.p, m.coeff * n) function Base.:*(m::SemiMonomial, t::Symbolic) - if istree(t) + if iscall(t) op = operation(t) if op == (+) args = collect(all_terms(t)) @@ -70,7 +70,7 @@ function pdegrees(x) dict = pdegrees(x.base) degrees = map(degree -> degree * x.exp, values(dict)) Dict(keys(dict) .=> degrees) - elseif issym(x) || istree(x) + elseif issym(x) || iscall(x) return Dict(x=>1) elseif x isa Number return Dict() @@ -133,7 +133,7 @@ issym(::SemiMonomial) = true Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff) -isop(x, op) = istree(x) && operation(x) === op +isop(x, op) = iscall(x) && operation(x) === op isop(op) = Base.Fix2(isop, op) bareterm(x, f, args; kw...) = Term{symtype(x)}(f, args) @@ -158,7 +158,7 @@ end function semipolyform_terms(expr, vars) expr = mark_and_exponentiate(expr, vars) - if istree(expr) && operation(expr) == (+) + if iscall(expr) && operation(expr) == (+) args = collect(all_terms(expr)) return args elseif isreal(expr) && iszero(real(expr)) # when `expr` is just a 0 @@ -177,7 +177,7 @@ Return true if `expr` contains any variables in `vars`. function has_vars(expr, vars)::Bool if expr in vars return true - elseif istree(expr) + elseif iscall(expr) for arg in unsorted_arguments(expr) if has_vars(arg, vars) return true @@ -190,7 +190,7 @@ end function mark_vars(expr, vars) if expr in vars return SemiMonomial(expr, 1) - elseif !istree(expr) + elseif !iscall(expr) return SemiMonomial(1, expr) end op = operation(expr) @@ -403,7 +403,7 @@ end ## Utilities -all_terms(x) = istree(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, unsorted_arguments(x)))) : (x,) +all_terms(x) = iscall(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, unsorted_arguments(x)))) : (x,) function unwrap_sp(m::SemiMonomial) degree_dict = pdegrees(m.p) @@ -424,7 +424,7 @@ function unwrap_sp(m::SemiMonomial) end function unwrap_sp(x) x = unwrap(x) - istree(x) ? similarterm(x, operation(x), map(unwrap_sp, unsorted_arguments(x))) : x + iscall(x) ? similarterm(x, operation(x), map(unwrap_sp, unsorted_arguments(x))) : x end function cautious_sum(nls) diff --git a/src/solver.jl b/src/solver.jl index 020c3e58b..420670083 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -81,7 +81,7 @@ function get_parts_list(a, b, a_list = Vector{Any}(), b_list = Vector{Any}()) if SymbolicUtils.issym(a) push!(a_list, a) push!(b_list, b) - elseif istree(a) && istree(b) && isequal(operation(a), operation(b)) + elseif iscall(a) && iscall(b) && isequal(operation(a), operation(b)) a_args = arguments(a) b_args = arguments(b) @@ -160,7 +160,7 @@ end function replace_term(expr, dic::Dict) if SymbolicUtils.issym(expr) && haskey(dic, expr) return dic[expr] - elseif istree(expr) + elseif iscall(expr) args = Any[] for arg in arguments(expr) @@ -202,9 +202,9 @@ end function expr_similar(ref_expr, expr, check_matches = true) SymbolicUtils.issym(ref_expr) && return true - SymbolicUtils.issym(expr) && istree(ref_expr) && return false + SymbolicUtils.issym(expr) && iscall(ref_expr) && return false - if istree(ref_expr) + if iscall(ref_expr) ref_args = arguments(ref_expr) ref_len = length(ref_args) ref_op = operation(ref_expr) @@ -249,12 +249,12 @@ function expr_similar(ref_expr, expr, check_matches = true) end function get_base(expr) - (!istree(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr") + (!iscall(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr") return arguments(expr)[1] end function get_exp(expr) - (!istree(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr") + (!iscall(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr") return arguments(expr)[2] end @@ -270,7 +270,7 @@ function solve_single_eq_unchecked( while (true) oldState = eq - if (istree(eq.lhs)) + if (iscall(eq.lhs)) potential_solution = solve_quadratic(eq, var, single_solution) if potential_solution isa Equation @@ -366,7 +366,7 @@ example move_to_other_side(x+a~z,x) returns x~z-a =# function move_to_other_side(eq::Equation, var) - !istree(eq.lhs) && return eq#make sure left side is tree form + !iscall(eq.lhs) && return eq#make sure left side is tree form op = operation(eq.lhs) @@ -421,14 +421,14 @@ function special_strategy(eq::Equation, var) end - !istree(eq.lhs) && return eq#make sure left side is tree form + !iscall(eq.lhs) && return eq#make sure left side is tree form op = operation(eq.lhs) elements = arguments(eq.lhs) if (op == +) && length(elements) == 2 && - sum(istree.(elements)) == length(elements) && + sum(iscall.(elements)) == length(elements) && isequal(operation.(elements), [sqrt for el = 1:length(elements)]) #check for sqrt(a)+sqrt(b)=c form , to solve this sqrt(a)+sqrt(b)=c -> 4*a*b-full_expand((c^2-b-a)^2)=0 then solve using quadratics #grab values @@ -445,10 +445,10 @@ function special_strategy(eq::Equation, var) elseif (op == +) && isequal(eq.rhs, 0) && length(elements) == 2 && - sum(istree.(elements)) == length(elements) && + sum(iscall.(elements)) == length(elements) && length(arguments(elements[1])) == 2 && isequal(arguments(elements[1])[1], -1) && - istree(arguments(elements[1])[2]) && + iscall(arguments(elements[1])[2]) && operation(elements[2]) == operation(arguments(elements[1])[2])#-f(y)+f(x)=0 -> x-y=0 x = arguments(elements[2])[1] @@ -473,16 +473,16 @@ function reduce_root(a) a = term(^, a.base, a.exp) end - if istree(a) && (operation(a) == sqrt) + if iscall(a) && (operation(a) == sqrt) a = SymbolicUtils.Pow(arguments(a)[1], 1 // 2) - elseif istree(a) && + elseif iscall(a) && (operation(a) == ^) && isequal(arguments(a)[2], 1 // 2) && !(arguments(a)[1] isa Number) a = term(sqrt, arguments(a)[1]) end - if istree(a) && + if iscall(a) && (operation(a) == ^) && arguments(a)[2] isa Rational && isequal((arguments(a)[2]).num, 1) @@ -531,7 +531,7 @@ if in quadratic form returns solutions =# function solve_quadratic(eq::Equation, var, single_solution) - !istree(eq.lhs) && return eq#make sure left side is tree form + !iscall(eq.lhs) && return eq#make sure left side is tree form op = operation(eq.lhs) @@ -579,7 +579,7 @@ end #reverse certain functions function inverse_funcs(eq::Equation, var) - !istree(eq.lhs) && return eq#make sure left side is tree form + !iscall(eq.lhs) && return eq#make sure left side is tree form op = operation(eq.lhs) #reverse functions @@ -611,7 +611,7 @@ end #solves for powers function reverse_powers(eq::Equation, var, single_solution) - !istree(eq.lhs) && return eq#make sure left side is tree form + !iscall(eq.lhs) && return eq#make sure left side is tree form op = operation(eq.lhs) if (op == ^) diff --git a/src/utils.jl b/src/utils.jl index 38c861877..b6988f6ea 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -49,10 +49,10 @@ get_variables!(vars, e::Num, varlist=nothing) = get_variables!(vars, value(e), v get_variables!(vars, e, varlist=nothing) = vars function is_singleton(e) - if istree(e) + if iscall(e) op = operation(e) op === getindex && return true - istree(op) && return is_singleton(op) # recurse to reach getindex for array element variables + iscall(op) && return is_singleton(op) # recurse to reach getindex for array element variables return issym(op) else return issym(e) @@ -100,7 +100,7 @@ var"z(t)[1]ˍt" ``` """ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothing) - istree(O) || return O + iscall(O) || return O if is_derivative(O) ds = "" while is_derivative(O) @@ -119,7 +119,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi oldop = operation(O) if issym(oldop) opname = string(nameof(oldop)) - elseif istree(oldop) && operation(oldop) === getindex + elseif iscall(oldop) && operation(oldop) === getindex opname = string(nameof(arguments(oldop)[1])) args = arguments(O) elseif oldop == getindex @@ -160,7 +160,7 @@ julia> Symbolics.tosymbol(z; escape=false) function tosymbol(t; states=nothing, escape=true) if issym(t) return nameof(t) - elseif istree(t) + elseif iscall(t) if issym(operation(t)) if states !== nothing && !(t in states) return nameof(operation(t)) @@ -224,7 +224,7 @@ function var_from_nested_derivative(x,i=0) x = unwrap(x) if issym(x) (x, i) - elseif istree(x) + elseif iscall(x) operation(x) isa Differential ? var_from_nested_derivative(first(arguments(x)), i + 1) : (x, i) else diff --git a/src/variable.jl b/src/variable.jl index d2b461d80..8722fbb04 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -236,7 +236,7 @@ struct CallWithMetadata{T,M} <: Symbolic{T} metadata::M end -for f in [:istree, :operation, :arguments] +for f in [:iscall, :operation, :arguments] @eval SymbolicUtils.$f(x::CallWithMetadata) = $f(x.f) end @@ -383,10 +383,10 @@ _getname(x, _) = nameof(x) _getname(x::Symbol, _) = x function _getname(x::Symbolic, val) issym(x) && return nameof(x) - if istree(x) && issym(operation(x)) + if iscall(x) && issym(operation(x)) return nameof(operation(x)) end - if !hasmetadata(x, Symbolics.GetindexParent) && istree(x) && operation(x) == getindex + if !hasmetadata(x, Symbolics.GetindexParent) && iscall(x) && operation(x) == getindex return _getname(arguments(x)[1], val) end ss = getsource(x, nothing) @@ -405,7 +405,7 @@ SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x)) function SymbolicIndexingInterface.hasname(x::Symbolic) - issym(x) || !istree(x) || istree(x) && (issym(operation(x)) || operation(x) == getindex) + issym(x) || !iscall(x) || iscall(x) && (issym(operation(x)) || operation(x) == getindex) end SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) @@ -474,7 +474,7 @@ function fast_substitute(expr, subs; operator = Nothing) if (_val = get(subs, expr, nothing)) !== nothing return _val end - istree(expr) || return expr + iscall(expr) || return expr op = fast_substitute(operation(expr), subs; operator) args = SymbolicUtils.unsorted_arguments(expr) if !(op isa operator) @@ -502,7 +502,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) expr = fast_substitute(expr, ai => bi; operator) end end - istree(expr) || return expr + iscall(expr) || return expr op = fast_substitute(operation(expr), pair; operator) args = SymbolicUtils.unsorted_arguments(expr) if !(op isa operator) @@ -528,7 +528,7 @@ function getparent(x, val=_fail) if maybe_parent !== nothing return maybe_parent else - if istree(x) && operation(x) === getindex + if iscall(x) && operation(x) === getindex return arguments(x)[1] end end @@ -640,9 +640,9 @@ function rename(x::Symbolic, name) xx = @set! x.name = name xx = rename_metadata(x, xx, name) symtype(xx) <: AbstractArray ? rename_getindex_source(xx) : xx - elseif istree(x) && operation(x) === getindex + elseif iscall(x) && operation(x) === getindex rename(arguments(x)[1], name)[arguments(x)[2:end]...] - elseif istree(x) && symtype(operation(x)) <: FnType || operation(x) isa CallWithMetadata + elseif iscall(x) && symtype(operation(x)) <: FnType || operation(x) isa CallWithMetadata xx = @set x.f = rename(operation(x), name) @set! xx.hash = Ref{UInt}(0) return rename_metadata(x, xx, name) diff --git a/test/complex.jl b/test/complex.jl index 47a10b9c6..1b4b7646f 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -34,4 +34,6 @@ end @test_nowarn substitute(z1, z=>1.0im) @test metadata(z1) == unwrap(z1.im).metadata @test metadata(z1) == unwrap(z1.re).metadata + z2 = 1.0 + z*im + @test isnothing(metadata(unwrap(z1.re))) end diff --git a/test/solver.jl b/test/solver.jl index 1a510c7f7..7a976955b 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -11,7 +11,7 @@ using LambertW return !isinteger(expr) && expr != float(pi) && expr != exp(1.0) elseif expr isa Equation return hasFloat(expr.lhs) || hasFloat(expr.rhs) - elseif istree(expr) + elseif iscall(expr) elements = arguments(expr) for element in elements if hasFloat(element)