diff --git a/Project.toml b/Project.toml index 7b7f8373..af4c137a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "0.13.1" +version = "0.14.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -24,6 +24,7 @@ DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils" DynamicExpressionsZygoteExt = "Zygote" [compat] +Aqua = "0.7" Compat = "3.37, 4" LoopVectorization = "0.12" MacroTools = "0.4, 0.5" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 5cda06a9..9b9e8dc2 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,6 +1,11 @@ using DynamicExpressions, BenchmarkTools, Random using DynamicExpressions.EquationUtilsModule: is_constant using Zygote +if PACKAGE_VERSION < v"0.14.0" + @eval using DynamicExpressions: Node as GraphNode +else + @eval using DynamicExpressions: GraphNode +end include("benchmark_utils.jl") @@ -66,13 +71,15 @@ end # These macros make the benchmarks work on older versions: #! format: off -@generated function _convert(::Type{N}, t; preserve_sharing) where {N<:Node} +@generated function _convert(::Type{N}, t; preserve_sharing) where {N} PACKAGE_VERSION < v"0.7.0" && return :(convert(N, t)) - return :(convert(N, t; preserve_sharing=preserve_sharing)) + PACKAGE_VERSION < v"0.14.0" && return :(convert(N, t; preserve_sharing=preserve_sharing)) + return :(convert(N, t)) # Assume type used to infer sharing end @generated function _copy_node(t; preserve_sharing) PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing)) - return :(copy_node(t; preserve_sharing=preserve_sharing)) + PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing)) + return :(copy_node(t)) # Assume type used to infer sharing end @generated function get_set_constants!(tree) !(@isdefined set_constants!) 
&& return :(set_constants(tree, get_constants(tree))) @@ -101,13 +108,36 @@ function benchmark_utilities() :index_constants, :string_tree, ) + has_both_modes = [:copy, :convert] + if PACKAGE_VERSION >= v"0.14.0" + append!( + has_both_modes, + [ + :simplify_tree, + :count_nodes, + :count_constants, + :get_set_constants!, + :index_constants, + :string_tree, + ], + ) + end operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp]) for func_k in all_funcs suite[func_k] = let s = BenchmarkGroup() - for k in (:break_sharing, :preserve_sharing) - has_both_modes = func_k in (:copy, :convert) - k == :preserve_sharing && !has_both_modes && continue + for k in ( + if func_k in has_both_modes + [:break_sharing, :preserve_sharing] + else + [:break_sharing] + end + ) + preprocess = if k == :preserve_sharing && PACKAGE_VERSION >= v"0.14.0" + tree -> GraphNode(tree) + else + identity + end f = if func_k == :copy tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing)) @@ -132,12 +162,9 @@ function benchmark_utilities() setup=( ntrees=100; n=20; - trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees] + trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees] ) ) - if !has_both_modes - s = s[k] - end #! format: on end s diff --git a/docs/src/types.md b/docs/src/types.md index eb418846..56c31bfa 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -48,16 +48,7 @@ Equations are specified as binary trees with the `Node` type, defined as follows: ```@docs -Node{T} -``` - -There are a variety of constructors for `Node` objects, including: - -```@docs -Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T} -Node(op::Integer, l::Node) -Node(op::Integer, l::Node, r::Node) -Node(var_string::String) +Node ``` When you create an `Options` object, the operators @@ -69,23 +60,87 @@ When using these node constructors, types will automatically be promoted. 
You can convert the type of a node using `convert`: ```@docs -convert(::Type{Node{T1}}, tree::Node{T2}) where {T1, T2} +convert(::Type{AbstractExpressionNode{T1}}, tree::AbstractExpressionNode{T2}) where {T1, T2} ``` You can set a `tree` (in-place) with `set_node!`: ```@docs -set_node!(tree::Node{T}, new_tree::Node{T}) where {T} +set_node! ``` You can create a copy of a node with `copy_node`: ```@docs -copy_node(tree::Node) +copy_node +``` + +## Graph-Like Equations + +You can describe an equation as a *graph* rather than a tree +by using the `GraphNode` type: + +```@docs +GraphNode{T} +``` + +This makes it so you can have multiple parents for a given node, +and share parts of an expression. For example: + +```julia +julia> operators = OperatorEnum(; + binary_operators=[+, -, *], unary_operators=[cos, sin, exp] + ); + +julia> x1, x2 = GraphNode(feature=1), GraphNode(feature=2) +(x1, x2) + +julia> y = sin(x1) + 1.5 +sin(x1) + 1.5 + +julia> z = exp(y) + y +exp(sin(x1) + 1.5) + {(sin(x1) + 1.5)} +``` + +Here, the curly braces `{}` indicate that the node +is shared by more than one parent node. + +This means that we only need to change it once +to have changes propagate across the expression: + +```julia +julia> y.r.val *= 0.9 +1.35 + +julia> z +exp(sin(x1) + 1.35) + {(sin(x1) + 1.35)} +``` + +This also means there are fewer nodes to describe an expression: + +```julia +julia> length(z) +6 + +julia> length(convert(Node, z)) +10 +``` + +where we have converted the `GraphNode` to a `Node` type, +which breaks shared connections into separate nodes. + +## Abstract Types + +Both the `Node` and `GraphNode` types are subtypes of the abstract type: + +```@docs +AbstractExpressionNode{T} ``` -There is also an abstract type `AbstractNode` which is a supertype of `Node`: +which can be used to create additional expression-like types. 
+The supertype of this abstract type is the `AbstractNode` type, +which is more generic but does not have all of the same methods: ```@docs -AbstractNode +AbstractNode{T} ``` diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index fa6e20e2..d175adc9 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -1,7 +1,8 @@ module DynamicExpressionsSymbolicUtilsExt using SymbolicUtils -import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE +import DynamicExpressions.EquationModule: + AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node @@ -19,7 +20,9 @@ end subs_bad(x) = isgood(x) ? x : Inf function parse_tree_to_eqs( - tree::Node{T}, operators::AbstractOperatorEnum, index_functions::Bool=false + tree::AbstractExpressionNode{T}, + operators::AbstractOperatorEnum, + index_functions::Bool=false, ) where {T} if tree.degree == 0 # Return constant if needed @@ -27,6 +30,7 @@ function parse_tree_to_eqs( return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)")) end # Collect the next children + # TODO: Type instability! children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,) # Get the operation op = tree.degree == 2 ? 
operators.binops[tree.op] : operators.unaops[tree.op] @@ -66,11 +70,12 @@ convert_to_function(x, operators::AbstractOperatorEnum) = x function split_eq( op, args, - operators::AbstractOperatorEnum; + operators::AbstractOperatorEnum, + ::Type{N}=Node; variable_names::Union{Array{String,1},Nothing}=nothing, # Deprecated: varMap=nothing, -) +) where {N<:AbstractExpressionNode} variable_names = deprecate_varmap(variable_names, varMap, :split_eq) !(op ∈ (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!")) if Symbol(op) == Symbol(sum) @@ -80,10 +85,10 @@ function split_eq( else ind = findoperation(op, operators.binops) end - return Node( + return constructorof(N)( ind, - convert(Node, args[1], operators; variable_names=variable_names), - convert(Node, op(args[2:end]...), operators; variable_names=variable_names), + convert(N, args[1], operators; variable_names=variable_names), + convert(N, op(args[2:end]...), operators; variable_names=variable_names), ) end @@ -96,7 +101,7 @@ end function Base.convert( ::typeof(SymbolicUtils.Symbolic), - tree::Node, + tree::AbstractExpressionNode, operators::AbstractOperatorEnum; variable_names::Union{Array{String,1},Nothing}=nothing, index_functions::Bool=false, @@ -109,20 +114,22 @@ function Base.convert( ) end -function Base.convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...) - return Node(; val=DEFAULT_NODE_TYPE(x)) +function Base.convert( + ::Type{N}, x::Number, operators::AbstractOperatorEnum; kws... 
+) where {N<:AbstractExpressionNode} + return constructorof(N)(; val=DEFAULT_NODE_TYPE(x)) end function Base.convert( - ::typeof(Node), + ::Type{N}, expr::SymbolicUtils.Symbolic, operators::AbstractOperatorEnum; variable_names::Union{Array{String,1},Nothing}=nothing, -) +) where {N<:AbstractExpressionNode} variable_names = deprecate_varmap(variable_names, nothing, :convert) if !SymbolicUtils.istree(expr) - variable_names === nothing && return Node(String(expr.name)) - return Node(String(expr.name), variable_names) + variable_names === nothing && return constructorof(N)(String(expr.name)) + return constructorof(N)(String(expr.name), variable_names) end # First, we remove integer powers: @@ -134,20 +141,21 @@ function Base.convert( op = convert_to_function(SymbolicUtils.operation(expr), operators) args = SymbolicUtils.arguments(expr) - length(args) > 2 && return split_eq(op, args, operators; variable_names=variable_names) + length(args) > 2 && + return split_eq(op, args, operators, N; variable_names=variable_names) ind = if length(args) == 2 findoperation(op, operators.binops) else findoperation(op, operators.unaops) end - return Node( - ind, map(x -> convert(Node, x, operators; variable_names=variable_names), args)... + return constructorof(N)( + ind, map(x -> convert(N, x, operators; variable_names=variable_names), args)... ) end """ - node_to_symbolic(tree::Node, operators::AbstractOperatorEnum; + node_to_symbolic(tree::AbstractExpressionNode, operators::AbstractOperatorEnum; variable_names::Union{Array{String, 1}, Nothing}=nothing, index_functions::Bool=false) @@ -156,17 +164,17 @@ will generate a symbolic equation in SymbolicUtils.jl format. ## Arguments -- `tree::Node`: The equation to convert. +- `tree::AbstractExpressionNode`: The equation to convert. - `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation. - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for each feature. 
Default is [x1, x2, x3, ...]. - `index_functions::Bool=false`: Whether to generate special names for the - operators, which then allows one to convert back to a `Node` format + operators, which then allows one to convert back to an `AbstractExpressionNode` format using `symbolic_to_node`. (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84). """ function node_to_symbolic( - tree::Node, + tree::AbstractExpressionNode, operators::AbstractOperatorEnum; variable_names::Union{Array{String,1},Nothing}=nothing, index_functions::Bool=false, @@ -192,13 +200,14 @@ end function symbolic_to_node( eqn::SymbolicUtils.Symbolic, - operators::AbstractOperatorEnum; + operators::AbstractOperatorEnum, + ::Type{N}=Node; variable_names::Union{Array{String,1},Nothing}=nothing, # Deprecated: varMap=nothing, -)::Node +) where {N<:AbstractExpressionNode} variable_names = deprecate_varmap(variable_names, varMap, :symbolic_to_node) - return convert(Node, eqn, operators; variable_names=variable_names) + return convert(N, eqn, operators; variable_names=variable_names) end function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index ab1a89d0..881bfc57 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -15,6 +15,8 @@ import PackageExtensionCompat: @require_extensions import Reexport: @reexport @reexport import .EquationModule: AbstractNode, + AbstractExpressionNode, + GraphNode, Node, string_tree, print_tree, @@ -22,6 +24,7 @@ import Reexport: @reexport set_node!, tree_mapreduce, filter_map +import .EquationModule: constructorof, preserve_sharing @reexport import .EquationUtilsModule: count_nodes, count_constants, @@ -38,7 +41,7 @@ import Reexport: @reexport @reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array @reexport import .EvaluateEquationDerivativeModule: eval_diff_tree_array, eval_grad_tree_array -@reexport 
import .SimplifyEquationModule: combine_operators, simplify_tree +@reexport import .SimplifyEquationModule: combine_operators, simplify_tree! @reexport import .EvaluationHelpersModule @reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node diff --git a/src/Equation.jl b/src/Equation.jl index 940f6eb9..d0bd4832 100644 --- a/src/Equation.jl +++ b/src/Equation.jl @@ -1,7 +1,7 @@ module EquationModule import ..OperatorEnumModule: AbstractOperatorEnum -import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap +import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 @@ -23,9 +23,32 @@ Abstract type for binary trees. Must have the following fields: """ abstract type AbstractNode end +""" + AbstractExpressionNode{T} <: AbstractNode + +Abstract type for nodes that represent an expression. +Along with the fields required for `AbstractNode`, +this additionally must have fields for: + +- `constant::Bool`: Whether the node is a constant. +- `val::T`: Value of the node. If `degree==0`, and `constant==true`, + this is the value of the constant. It has a type specified by the + overall type of the `Node` (e.g., `Float64`). +- `feature::UInt16`: Index of the feature to use in the + case of a feature node. Only used if `degree==0` and `constant==false`. + Only defined if `degree == 0 && constant == false`. +- `op::UInt8`: If `degree==1`, this is the index of the operator + in `operators.unaops`. If `degree==2`, this is the index of the + operator in `operators.binops`. In other words, this is an enum + of the operators, and is dependent on the specific `OperatorEnum` + object. Only defined if `degree >= 1` +``` +""" +abstract type AbstractExpressionNode{T} <: AbstractNode end + #! format: off """ - Node{T} + Node{T} <: AbstractExpressionNode{T} Node defines a symbolic expression stored in a binary tree. 
A single `Node` instance is one "node" of this tree, and @@ -53,8 +76,44 @@ nodes, you can evaluate or print a given expression. - `r::Node{T}`: Right child of the node. Only defined if `degree == 2`. Same type as the parent node. This is to be passed as the right argument to the binary operator. + +# Constructors + +## Leaves + + Node(; val=nothing, feature::Union{Integer,Nothing}=nothing) + Node{T}(; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T} + +Create a leaf node: either a constant, or a variable. + +- `::Type{T}`, optionally specify the type of the + node, if not already given by the type of + `val`. +- `val`, if you are specifying a constant, pass + the value of the constant here. +- `feature::Integer`, if you are specifying a variable, + pass the index of the variable here. + +You can also create a leaf node from variable names: + + Node(var_string::String, variable_names::Array{String,1}) + Node{T}(var_string::String, variable_names::Array{String,1}) where {T} + +## Unary operator + + Node(op::Integer, l::Node) + +Apply unary operator `op` (enumerating over the order given in `OperatorEnum`) +to `Node` `l`. + +## Binary operator + + Node(op::Integer, l::Node, r::Node) + +Apply binary operator `op` (enumerating over the order given in `OperatorEnum`) +to `Node`s `l` and `r`. """ -mutable struct Node{T} <: AbstractNode +mutable struct Node{T} <: AbstractExpressionNode{T} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. constant::Bool # false if variable val::Union{T,Nothing} # If is a constant, this stores the actual value @@ -74,114 +133,184 @@ mutable struct Node{T} <: AbstractNode Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r) end -################################################################################ -#! 
format: on - -include("base.jl") """ - Node([::Type{T}]; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T} + GraphNode{T} <: AbstractExpressionNode{T} -Create a leaf node: either a constant, or a variable. +Exactly the same as `Node{T}`, but with the assumption that some +nodes will be shared. All copies of this graph-like structure will +be performed with this assumption, to preserve structure of the graph. -# Arguments: +# Examples -- `::Type{T}`, optionally specify the type of the - node, if not already given by the type of - `val`. -- `val`, if you are specifying a constant, pass - the value of the constant here. -- `feature::Integer`, if you are specifying a variable, - pass the index of the variable here. +```julia +julia> operators = OperatorEnum(; + binary_operators=[+, -, *], unary_operators=[cos, sin] + ); + +julia> x = GraphNode(feature=1) +x1 + +julia> y = sin(x) + x +sin(x1) + {x1} + +julia> cos(y) * y +cos(sin(x1) + {x1}) * {(sin(x1) + {x1})} +``` + +Note how the `{}` indicates a node is shared, and this +is the same node as seen earlier in the string. + +This has the same constructors as `Node{T}`. Shared nodes +are created simply by using the same node in multiple places +when constructing or setting properties. """ -function Node(; - val::T1=nothing, feature::T2=nothing -)::Node where {T1,T2<:Union{Integer,Nothing}} - if T1 <: Nothing && T2 <: Nothing - error("You must specify either `val` or `feature` when creating a leaf node.") - elseif !(T1 <: Nothing || T2 <: Nothing) - error( - "You must specify either `val` or `feature` when creating a leaf node, not both.", - ) - elseif T2 <: Nothing - return Node(0, true, val) - else - return Node(DEFAULT_NODE_TYPE, 0, false, nothing, feature) - end +mutable struct GraphNode{T} <: AbstractExpressionNode{T} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. 
+ constant::Bool # false if variable + val::Union{T,Nothing} # If is a constant, this stores the actual value + # ------------------- (possibly undefined below) + feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops + l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2. + r::GraphNode{T} # Right child node. Only defined for degree=2. + + ################# + ## Constructors: + ################# + GraphNode(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v) + GraphNode(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v) + GraphNode(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f)) + GraphNode(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::GraphNode{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l) + GraphNode(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::GraphNode{_T}, r::GraphNode{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r) end -function Node( - ::Type{T}; val::T1=nothing, feature::T2=nothing -)::Node{T} where {T,T1,T2<:Union{Integer,Nothing}} - if T1 <: Nothing && T2 <: Nothing - error("You must specify either `val` or `feature` when creating a leaf node.") - elseif !(T1 <: Nothing || T2 <: Nothing) - error( - "You must specify either `val` or `feature` when creating a leaf node, not both.", - ) - elseif T2 <: Nothing + +################################################################################ +#! 
format: on + +constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper +constructorof(::Type{<:Node}) = Node +constructorof(::Type{<:GraphNode}) = GraphNode + +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} + return constructorof(N){T} +end +with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} +with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} + +"""Trait declaring whether nodes share children or not.""" +preserve_sharing(::Type{<:AbstractNode}) = false +preserve_sharing(::Type{<:Node}) = false +preserve_sharing(::Type{<:GraphNode}) = true + +include("base.jl") + +function (::Type{N})( + ::Type{T}=Undefined; val::T1=nothing, feature::T2=nothing +) where {T,T1,T2<:Union{Integer,Nothing},N<:AbstractExpressionNode} + ((T1 <: Nothing) ⊻ (T2 <: Nothing)) || error( + "You must specify exactly one of `val` or `feature` when creating a leaf node." + ) + Tout = compute_value_output_type(N, T, T1) + if T2 <: Nothing if !(T1 <: T) # Only convert if not already in the type union. 
- val = convert(T, val) + val = convert(Tout, val) end - return Node(T, 0, true, val) + return constructorof(N)(Tout, 0, true, val) else - return Node(T, 0, false, nothing, feature) + return constructorof(N)(Tout, 0, false, nothing, feature) end end - -""" - Node(op::Integer, l::Node) - -Apply unary operator `op` (enumerating over the order given) to `Node` `l` -""" -Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l) - -""" - Node(op::Integer, l::Node, r::Node) - -Apply binary operator `op` (enumerating over the order given) to `Node`s `l` and `r` -""" -function Node(op::Integer, l::Node{T1}, r::Node{T2}) where {T1,T2} +function (::Type{N})( + op::Integer, l::AbstractExpressionNode{T} +) where {T,N<:AbstractExpressionNode} + @assert l isa N + return constructorof(N)(1, false, nothing, 0, op, l) +end +function (::Type{N})( + op::Integer, l::AbstractExpressionNode{T1}, r::AbstractExpressionNode{T2} +) where {T1,T2,N<:AbstractExpressionNode} + @assert l isa N && r isa N # Get highest type: if T1 != T2 T = promote_type(T1, T2) - l = convert(Node{T}, l) - r = convert(Node{T}, r) + # TODO: This might slow things down + l = convert(with_type_parameters(N, T), l) + r = convert(with_type_parameters(N, T), r) end - return Node(2, false, nothing, 0, op, l, r) + return constructorof(N)(2, false, nothing, 0, op, l, r) +end +function (::Type{N})(var_string::String) where {N<:AbstractExpressionNode} + Base.depwarn( + "Creating a node using a string is deprecated and will be removed in a future version.", + :string_tree, + ) + return N(; feature=parse(UInt16, var_string[2:end])) +end +function (::Type{N})( + var_string::String, variable_names::Array{String,1} +) where {N<:AbstractExpressionNode} + i = findfirst(==(var_string), variable_names)::Int + return N(; feature=i) end -""" - Node(var_string::String) - -Create a variable node, using the format `"x1"` to mean feature 1 -""" -Node(var_string::String) = Node(; feature=parse(UInt16, var_string[2:end])) 
+@inline function compute_value_output_type( + ::Type{N}, ::Type{T}, ::Type{T1} +) where {N<:AbstractExpressionNode,T,T1} + !(N isa UnionAll) && + T !== Undefined && + error( + "Ambiguous type for node. Please either use `Node{T}(; val, feature)` or `Node(T; val, feature)`.", + ) -""" - Node(var_string::String, variable_names::Array{String, 1}) + if T === Undefined && N isa UnionAll + if T1 <: Nothing + return DEFAULT_NODE_TYPE + else + return T1 + end + elseif T === Undefined + return eltype(N) + else + return T + end +end -Create a variable node, using a user-passed format -""" -function Node(var_string::String, variable_names::Array{String,1}) - return Node(; - feature=[ - i for (i, _variable) in enumerate(variable_names) if _variable == var_string - ][1]::Int, - ) +function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2} + return Node{promote_type(T1, T2)} +end +function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2} + return GraphNode{promote_type(T1, T2)} +end +function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2} + return GraphNode{promote_type(T1, T2)} end +Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T +Base.eltype(::AbstractExpressionNode{T}) where {T} = T + +# TODO: Verify using this helps with garbage collection +create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N(; feature=zero(UInt16)) """ - set_node!(tree::Node{T}, new_tree::Node{T}) where {T} + set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T} Set every field of `tree` equal to the corresponding field of `new_tree`. 
""" -function set_node!(tree::Node{T}, new_tree::Node{T}) where {T} +function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode) + # First, ensure we free some memory: + if new_tree.degree < 2 && tree.degree == 2 + tree.r = create_dummy_node(typeof(tree)) + end + if new_tree.degree < 1 && tree.degree >= 1 + tree.l = create_dummy_node(typeof(tree)) + end + tree.degree = new_tree.degree if new_tree.degree == 0 tree.constant = new_tree.constant if new_tree.constant - tree.val = new_tree.val::T + tree.val = new_tree.val::eltype(new_tree) else tree.feature = new_tree.feature end @@ -195,7 +324,7 @@ function set_node!(tree::Node{T}, new_tree::Node{T}) where {T} return nothing end -const OP_NAMES = Dict( +const OP_NAMES = Base.ImmutableDict( "safe_log" => "log", "safe_log2" => "log2", "safe_log10" => "log10", @@ -205,68 +334,106 @@ const OP_NAMES = Dict( "safe_pow" => "^", ) -get_op_name(op::String) = op -@generated function get_op_name(op::F) where {F} +function dispatch_op_name(::Val{2}, ::Nothing, idx)::Vector{Char} + return vcat(collect("binary_operator["), collect(string(idx)), [']']) +end +function dispatch_op_name(::Val{1}, ::Nothing, idx)::Vector{Char} + return vcat(collect("unary_operator["), collect(string(idx)), [']']) +end +function dispatch_op_name(::Val{2}, operators::AbstractOperatorEnum, idx)::Vector{Char} + return get_op_name(operators.binops[idx]) +end +function dispatch_op_name(::Val{1}, operators::AbstractOperatorEnum, idx)::Vector{Char} + return get_op_name(operators.unaops[idx]) +end + +@generated function get_op_name(op::F)::Vector{Char} where {F} try # Bit faster to just cache the name of the operator: op_s = string(F.instance) - out = get(OP_NAMES, op_s, op_s) + out = collect(get(OP_NAMES, op_s, op_s)) return :($out) catch end return quote op_s = string(op) - out = get(OP_NAMES, op_s, op_s) + out = collect(get(OP_NAMES, op_s, op_s)) return out end end -function string_op( - ::Val{2}, op::F, tree::Node, args...; bracketed, 
kws... -)::String where {F} - op_name = get_op_name(op) - if op_name in ["+", "-", "*", "/", "^", "×"] - l = string_tree(tree.l, args...; bracketed=false, kws...) - r = string_tree(tree.r, args...; bracketed=false, kws...) - if bracketed - return l * " " * op_name * " " * r - else - return "(" * l * " " * op_name * " " * r * ")" - end +@inline function strip_brackets(s::Vector{Char})::Vector{Char} + if first(s) == '(' && last(s) == ')' + return s[(begin + 1):(end - 1)] else - l = string_tree(tree.l, args...; bracketed=true, kws...) - r = string_tree(tree.r, args...; bracketed=true, kws...) - # return "$op_name($l, $r)" - return op_name * "(" * l * ", " * r * ")" + return s end end -function string_op( - ::Val{1}, op::F, tree::Node, args...; bracketed, kws... -)::String where {F} - op_name = get_op_name(op) - l = string_tree(tree.l, args...; bracketed=true, kws...) - return op_name * "(" * l * ")" -end -function string_constant(val, bracketed::Bool) - does_not_need_brackets = (typeof(val) <: Union{Real,AbstractArray}) - if does_not_need_brackets || bracketed - string(val) +# Can overload these for custom behavior: +needs_brackets(val::Real) = false +needs_brackets(val::AbstractArray) = false +needs_brackets(val::Complex) = true +needs_brackets(val) = true + +function string_constant(val) + if needs_brackets(val) + '(' * string(val) * ')' else - "(" * string(val) * ")" + string(val) end end function string_variable(feature, variable_names) if variable_names === nothing || feature > lastindex(variable_names) - return "x" * string(feature) + return 'x' * string(feature) else return variable_names[feature] end end +# Vector of chars is faster than strings, so we use that. 
+function combine_op_with_inputs(op, l, r)::Vector{Char} + if first(op) in ('+', '-', '*', '/', '^') + # "(l op r)" + out = ['('] + append!(out, l) + push!(out, ' ') + append!(out, op) + push!(out, ' ') + append!(out, r) + push!(out, ')') + else + # "op(l, r)" + out = copy(op) + push!(out, '(') + append!(out, strip_brackets(l)) + push!(out, ',') + push!(out, ' ') + append!(out, strip_brackets(r)) + push!(out, ')') + return out + end +end +function combine_op_with_inputs(op, l) + # "op(l)" + out = copy(op) + push!(out, '(') + append!(out, strip_brackets(l)) + push!(out, ')') + return out +end + """ - string_tree(tree::Node, operators::AbstractOperatorEnum[; bracketed, variable_names, f_variable, f_constant]) + string_tree( + tree::AbstractExpressionNode{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + f_variable::F1=string_variable, + f_constant::F2=string_constant, + variable_names::Union{Array{String,1},Nothing}=nothing, + # Deprecated + varMap=nothing, + )::String where {T,F1<:Function,F2<:Function} Convert an equation to a string. @@ -275,15 +442,13 @@ Convert an equation to a string. - `operators`: the operators used to define the tree # Keyword Arguments -- `bracketed`: (optional) whether to put brackets around the outside. -- `f_variable`: (optional) function to convert a variable to a string, of the form `(feature::UInt8, variable_names)`. -- `f_constant`: (optional) function to convert a constant to a string, of the form `(val, bracketed::Bool)` +- `f_variable`: (optional) function to convert a variable to a string, with arguments `(feature::UInt8, variable_names)`. +- `f_constant`: (optional) function to convert a constant to a string, with arguments `(val,)` - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature. 
""" function string_tree( - tree::Node{T}, + tree::AbstractExpressionNode{T}, operators::Union{AbstractOperatorEnum,Nothing}=nothing; - bracketed::Bool=false, f_variable::F1=string_variable, f_constant::F2=string_constant, variable_names::Union{Array{String,1},Nothing}=nothing, @@ -291,71 +456,49 @@ function string_tree( varMap=nothing, )::String where {T,F1<:Function,F2<:Function} variable_names = deprecate_varmap(variable_names, varMap, :string_tree) - if tree.degree == 0 - if !tree.constant - return f_variable(tree.feature, variable_names) + raw_output = tree_mapreduce( + leaf -> if leaf.constant + collect(f_constant(leaf.val::T)) else - return f_constant(tree.val::T, bracketed) - end - elseif tree.degree == 1 - return string_op( - Val(1), - if operators === nothing - "unary_operator[" * string(tree.op) * "]" - else - operators.unaops[tree.op] - end, - tree, - operators; - bracketed, - f_variable, - f_constant, - variable_names, - ) - else - return string_op( - Val(2), - if operators === nothing - "binary_operator[" * string(tree.op) * "]" - else - operators.binops[tree.op] - end, - tree, - operators; - bracketed, - f_variable, - f_constant, - variable_names, - ) - end + collect(f_variable(leaf.feature, variable_names)) + end, + branch -> if branch.degree == 1 + dispatch_op_name(Val(1), operators, branch.op) + else + dispatch_op_name(Val(2), operators, branch.op) + end, + combine_op_with_inputs, + tree, + Vector{Char}; + f_on_shared=(c, is_shared) -> if is_shared + out = ['{'] + append!(out, c) + push!(out, '}') + out + else + c + end, + ) + return String(strip_brackets(raw_output)) end # Print an equation -function print_tree( - io::IO, - tree::Node, - operators::AbstractOperatorEnum; - f_variable::F1=string_variable, - f_constant::F2=string_constant, - variable_names::Union{Array{String,1},Nothing}=nothing, - # Deprecated - varMap=nothing, -) where {F1<:Function,F2<:Function} - variable_names = deprecate_varmap(variable_names, varMap, :print_tree) - return 
println(io, string_tree(tree, operators; f_variable, f_constant, variable_names)) -end - -function print_tree( - tree::Node, - operators::AbstractOperatorEnum; - f_variable::F1=string_variable, - f_constant::F2=string_constant, - variable_names::Union{Array{String,1},Nothing}=nothing, - # Deprecated - varMap=nothing, -) where {F1<:Function,F2<:Function} - variable_names = deprecate_varmap(variable_names, varMap, :print_tree) - return println(string_tree(tree, operators; f_variable, f_constant, variable_names)) +for io in ((), (:(io::IO),)) + @eval function print_tree( + $(io...), + tree::AbstractExpressionNode, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + f_variable::F1=string_variable, + f_constant::F2=string_constant, + variable_names::Union{Array{String,1},Nothing}=nothing, + # Deprecated + varMap=nothing, + ) where {F1<:Function,F2<:Function} + variable_names = deprecate_varmap(variable_names, varMap, :print_tree) + return println( + $(io...), string_tree(tree, operators; f_variable, f_constant, variable_names) + ) + end end end diff --git a/src/EquationUtils.jl b/src/EquationUtils.jl index 58e09117..218d6dce 100644 --- a/src/EquationUtils.jl +++ b/src/EquationUtils.jl @@ -1,15 +1,17 @@ module EquationUtilsModule import Compat: Returns -import ..EquationModule: AbstractNode, Node, copy_node, tree_mapreduce, any, filter_map - -""" - count_nodes(tree::AbstractNode)::Int - -Count the number of nodes in the tree. -""" -count_nodes(tree::AbstractNode) = tree_mapreduce(_ -> 1, +, tree) -# This code is given as an example. Normally we could just use sum(Returns(1), tree). +import ..EquationModule: + AbstractNode, + AbstractExpressionNode, + Node, + preserve_sharing, + constructorof, + copy_node, + count_nodes, + tree_mapreduce, + any, + filter_map """ count_depth(tree::AbstractNode)::Int @@ -17,73 +19,81 @@ count_nodes(tree::AbstractNode) = tree_mapreduce(_ -> 1, +, tree) Compute the max depth of the tree. 
""" function count_depth(tree::AbstractNode) - return tree_mapreduce(Returns(1), (p, child...) -> p + max(child...), tree) + return tree_mapreduce( + Returns(1), (p, child...) -> p + max(child...), tree, Int64; break_sharing=Val(true) + ) end """ - is_node_constant(tree::Node)::Bool + is_node_constant(tree::AbstractExpressionNode)::Bool Check if the current node in a tree is constant. """ -@inline is_node_constant(tree::Node) = tree.degree == 0 && tree.constant +@inline is_node_constant(tree::AbstractExpressionNode) = tree.degree == 0 && tree.constant """ - count_constants(tree::Node)::Int + count_constants(tree::AbstractExpressionNode)::Int Count the number of constants in a tree. """ -count_constants(tree::Node) = count(is_node_constant, tree) +function count_constants(tree::AbstractExpressionNode) + return tree_mapreduce( + node -> is_node_constant(node) ? 1 : 0, + +, + tree, + Int64; + f_on_shared=(c, is_shared) -> is_shared ? 0 : c, + ) +end """ - has_constants(tree::Node)::Bool + has_constants(tree::AbstractExpressionNode)::Bool Check if a tree has any constants. """ -has_constants(tree::Node) = any(is_node_constant, tree) +has_constants(tree::AbstractExpressionNode) = any(is_node_constant, tree) """ - has_operators(tree::Node)::Bool + has_operators(tree::AbstractExpressionNode)::Bool Check if a tree has any operators. """ -has_operators(tree::Node) = tree.degree != 0 +has_operators(tree::AbstractExpressionNode) = tree.degree != 0 """ - is_constant(tree::Node)::Bool + is_constant(tree::AbstractExpressionNode)::Bool Check if an expression is a constant numerical value, or whether it depends on input features. """ -is_constant(tree::Node) = all(t -> t.degree != 0 || t.constant, tree) +is_constant(tree::AbstractExpressionNode) = all(t -> t.degree != 0 || t.constant, tree) """ - get_constants(tree::Node{T})::Vector{T} where {T} + get_constants(tree::AbstractExpressionNode{T})::Vector{T} where {T} Get all the constants inside a tree, in depth-first order. 
The function `set_constants!` sets them in the same order, given the output of this function. """ -function get_constants(tree::Node{T}) where {T} +function get_constants(tree::AbstractExpressionNode{T}) where {T} return filter_map(is_node_constant, t -> (t.val::T), tree, T) end """ - set_constants!(tree::Node{T}, constants::AbstractVector{T}) where {T} + set_constants!(tree::AbstractExpressionNode{T}, constants::AbstractVector{T}) where {T} -Set the constants in a tree, in depth-first order. -The function `get_constants` gets them in the same order, +Set the constants in a tree, in depth-first order. The function +`get_constants` gets them in the same order. """ -function set_constants!(tree::Node{T}, constants::AbstractVector{T}) where {T} - if tree.degree == 0 - if tree.constant - tree.val = constants[1] +function set_constants!( + tree::AbstractExpressionNode{T}, constants::AbstractVector{T} +) where {T} + Base.require_one_based_indexing(constants) + i = Ref(0) + foreach(tree) do node + if node.degree == 0 && node.constant + @inbounds node.val = constants[i[] += 1] end - elseif tree.degree == 1 - set_constants!(tree.l, constants) - else - numberLeft = count_constants(tree.l) - set_constants!(tree.l, constants) - set_constants!(tree.r, @view constants[(numberLeft + 1):end]) end return nothing end @@ -91,43 +101,39 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. 
-mutable struct NodeIndex - constant_index::UInt16 # Index of this constant (if a constant exists here) - l::NodeIndex - r::NodeIndex - - NodeIndex() = new() -end - -function index_constants(tree::Node)::NodeIndex - return index_constants(tree, UInt16(0)) -end - -function index_constants(tree::Node, left_index)::NodeIndex - index_tree = NodeIndex() - index_constants!(tree, index_tree, left_index) - return index_tree -end - -# Count how many constants to the left of this node, and put them in a tree -function index_constants!(tree::Node, index_tree::NodeIndex, left_index) - if tree.degree == 0 - if tree.constant - index_tree.constant_index = left_index + 1 - end - elseif tree.degree == 1 - index_tree.constant_index = count_constants(tree.l) - index_tree.l = NodeIndex() - index_constants!(tree.l, index_tree.l, left_index) - else - index_tree.l = NodeIndex() - index_tree.r = NodeIndex() - index_constants!(tree.l, index_tree.l, left_index) - index_tree.constant_index = count_constants(tree.l) - left_index_here = left_index + index_tree.constant_index - index_constants!(tree.r, index_tree.r, left_index_here) +struct NodeIndex{T} <: AbstractNode + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + val::T # If is a constant, this stores the actual value + # ------------------- (possibly undefined below) + l::NodeIndex{T} # Left child node. Only defined for degree=1 or degree=2. + r::NodeIndex{T} # Right child node. Only defined for degree=2. + + NodeIndex(::Type{_T}) where {_T} = new{_T}(0, zero(_T)) + NodeIndex(::Type{_T}, val) where {_T} = new{_T}(0, convert(_T, val)) + NodeIndex(::Type{_T}, l::NodeIndex) where {_T} = new{_T}(1, zero(_T), l) + function NodeIndex(::Type{_T}, l::NodeIndex, r::NodeIndex) where {_T} + return new{_T}(2, zero(_T), l, r) end - return nothing +end +# Sharing is never needed for NodeIndex, +# as we trace over the node we are indexing on. 
+preserve_sharing(::Type{<:NodeIndex}) = false + +function index_constants(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T} + # Essentially we copy the tree, replacing the values + # with indices + constant_index = Ref(T(0)) + return tree_mapreduce( + t -> if t.constant + NodeIndex(T, (constant_index[] += T(1))) + else + NodeIndex(T) + end, + t -> nothing, + (_, c...) -> NodeIndex(T, c...), + tree, + NodeIndex{T}; + ) end end diff --git a/src/EvaluateEquation.jl b/src/EvaluateEquation.jl index e594dfdb..3c7b9e49 100644 --- a/src/EvaluateEquation.jl +++ b/src/EvaluateEquation.jl @@ -1,7 +1,7 @@ module EvaluateEquationModule import LoopVectorization: @turbo, indices -import ..EquationModule: Node, string_tree +import ..EquationModule: AbstractExpressionNode, constructorof, string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array, fill_similar import ..EquationUtilsModule: is_constant @@ -23,7 +23,7 @@ macro return_on_nonfinite_array(array) end """ - eval_tree_array(tree::Node, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false) + eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false) Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -44,7 +44,7 @@ The bulk of the code is for optimizations and pre-emptive NaN/Inf checks, which speed up evaluation significantly. # Arguments -- `tree::Node`: The root node of the tree to evaluate. +- `tree::AbstractExpressionNode`: The root node of the tree to evaluate. - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. - `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation. @@ -57,7 +57,10 @@ which speed up evaluation significantly. to the equation. 
""" function eval_tree_array( - tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum; + turbo::Bool=false, )::Tuple{AbstractVector{T},Bool} where {T<:Number} if turbo @assert T in (Float32, Float64) @@ -70,17 +73,23 @@ function eval_tree_array( return result, finished end function eval_tree_array( - tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; turbo::Bool=false + tree::AbstractExpressionNode{T1}, + cX::AbstractMatrix{T2}, + operators::OperatorEnum; + turbo::Bool=false, ) where {T1<:Number,T2<:Number} T = promote_type(T1, T2) @warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)." - tree = convert(Node{T}, tree) + tree = convert(constructorof(typeof(tree)){T}, tree) cX = T.(cX) return eval_tree_array(tree, cX, operators; turbo=turbo) end function _eval_tree_array( - tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, + ::Val{turbo}, )::Tuple{AbstractVector{T},Bool} where {T<:Number,turbo} # First, we see if there are only constants in the tree - meaning # we can just return the constant result. 
@@ -160,7 +169,7 @@ function deg1_eval( end function deg0_eval( - tree::Node{T}, cX::AbstractMatrix{T} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T} )::Tuple{AbstractVector{T},Bool} where {T<:Number} if tree.constant return (fill_similar(tree.val::T, cX, axes(cX, 2)), true) @@ -170,7 +179,7 @@ function deg0_eval( end function deg1_l2_ll0_lr0_eval( - tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo} )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val::T @@ -219,7 +228,7 @@ end # op(op2(x)) for x variable or constant function deg1_l1_ll0_eval( - tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo} )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo} if tree.l.l.constant val_ll = tree.l.l.val::T @@ -243,7 +252,7 @@ end # op(x, y) for x and y variable/constant function deg2_l0_r0_eval( - tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo} )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo} if tree.l.constant && tree.r.constant val_l = tree.l.val::T @@ -285,7 +294,11 @@ end # op(x, y) for x variable/constant, y arbitrary function deg2_l0_eval( - tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo} + tree::AbstractExpressionNode{T}, + cumulator::AbstractVector{T}, + cX::AbstractArray{T}, + op::F, + ::Val{turbo}, )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo} if tree.l.constant val = tree.l.val::T @@ -306,7 +319,11 @@ end # op(x, y) for x arbitrary, y variable/constant function deg2_r0_eval( - tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo} + tree::AbstractExpressionNode{T}, + 
cumulator::AbstractVector{T}, + cX::AbstractArray{T}, + op::F, + ::Val{turbo}, )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo} if tree.r.constant val = tree.r.val::T @@ -326,14 +343,14 @@ function deg2_r0_eval( end """ - _eval_constant_tree(tree::Node{T}, operators::OperatorEnum)::Tuple{T,Bool} where {T<:Number} + _eval_constant_tree(tree::AbstractExpressionNode{T}, operators::OperatorEnum)::Tuple{T,Bool} where {T<:Number} Evaluate a tree which is assumed to not contain any variable nodes. This gives better performance, as we do not need to perform computation over an entire array when the values are all the same. """ function _eval_constant_tree( - tree::Node{T}, operators::OperatorEnum + tree::AbstractExpressionNode{T}, operators::OperatorEnum )::Tuple{T,Bool} where {T<:Number} if tree.degree == 0 return deg0_eval_constant(tree) @@ -344,12 +361,14 @@ function _eval_constant_tree( end end -@inline function deg0_eval_constant(tree::Node{T})::Tuple{T,Bool} where {T<:Number} +@inline function deg0_eval_constant( + tree::AbstractExpressionNode{T} +)::Tuple{T,Bool} where {T<:Number} return tree.val::T, true end function deg1_eval_constant( - tree::Node{T}, op::F, operators::OperatorEnum + tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum )::Tuple{T,Bool} where {T<:Number,F} (cumulator, complete) = _eval_constant_tree(tree.l, operators) !complete && return zero(T), false @@ -358,7 +377,7 @@ function deg1_eval_constant( end function deg2_eval_constant( - tree::Node{T}, op::F, operators::OperatorEnum + tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum )::Tuple{T,Bool} where {T<:Number,F} (cumulator, complete) = _eval_constant_tree(tree.l, operators) !complete && return zero(T), false @@ -369,12 +388,12 @@ function deg2_eval_constant( end """ - differentiable_eval_tree_array(tree::Node, cX::AbstractMatrix, operators::OperatorEnum) + differentiable_eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix, 
operators::OperatorEnum) Evaluate an expression tree in a way that can be auto-differentiated. """ function differentiable_eval_tree_array( - tree::Node{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum + tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum )::Tuple{AbstractVector{T},Bool} where {T<:Number,T1} if tree.degree == 0 if tree.constant @@ -390,7 +409,7 @@ function differentiable_eval_tree_array( end function deg1_diff_eval( - tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum + tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,T1} (left, complete) = differentiable_eval_tree_array(tree.l, cX, operators) @return_on_false complete left @@ -400,7 +419,7 @@ function deg1_diff_eval( end function deg2_diff_eval( - tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum + tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum )::Tuple{AbstractVector{T},Bool} where {T<:Number,F,T1} (left, complete) = differentiable_eval_tree_array(tree.l, cX, operators) @return_on_false complete left @@ -412,7 +431,7 @@ function deg2_diff_eval( end """ - eval_tree_array(tree::Node, cX::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true) + eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true) Evaluate a generic binary tree (equation) over a given input data, whatever that input data may be. The `operators` enum contains all @@ -441,7 +460,7 @@ function eval(current_node) ``` # Arguments -- `tree::Node`: The root node of the tree to evaluate. +- `tree::AbstractExpressionNode`: The root node of the tree to evaluate. - `cX::AbstractArray`: The input data to evaluate the tree on. - `operators::GenericOperatorEnum`: The operators used in the tree. 
- `throw_errors::Bool=true`: Whether to throw errors @@ -460,7 +479,10 @@ function eval(current_node) that it was not defined for. """ function eval_tree_array( - tree::Node, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true + tree::AbstractExpressionNode, + cX::AbstractArray, + operators::GenericOperatorEnum; + throw_errors::Bool=true, ) !throw_errors && return _eval_tree_array_generic(tree, cX, operators, Val(false)) try @@ -480,7 +502,7 @@ function eval_tree_array( end function _eval_tree_array_generic( - tree::Node{T1}, + tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum, ::Val{throw_errors}, diff --git a/src/EvaluateEquationDerivative.jl b/src/EvaluateEquationDerivative.jl index f01b3693..6bc6b944 100644 --- a/src/EvaluateEquationDerivative.jl +++ b/src/EvaluateEquationDerivative.jl @@ -1,7 +1,7 @@ module EvaluateEquationDerivativeModule import LoopVectorization: indices, @turbo -import ..EquationModule: Node +import ..EquationModule: AbstractExpressionNode, constructorof import ..OperatorEnumModule: OperatorEnum import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array, fill_similar import ..EquationUtilsModule: count_constants, index_constants, NodeIndex @@ -18,7 +18,7 @@ function assert_autodiff_enabled(operators::OperatorEnum) end """ - eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool=false) + eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool=false) Compute the forward derivative of an expression, using a similar structure and optimization to eval_tree_array. `direction` is the index of a particular @@ -27,7 +27,7 @@ respect to `x1`. # Arguments -- `tree::Node`: The expression tree to evaluate. +- `tree::AbstractExpressionNode`: The expression tree to evaluate. 
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point. - `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff` must be `true`. This is needed to create the derivative operations. @@ -40,7 +40,7 @@ respect to `x1`. the derivative, and whether the evaluation completed as normal (or encountered a nan or inf). """ function eval_diff_tree_array( - tree::Node{T}, + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; @@ -54,7 +54,7 @@ function eval_diff_tree_array( ) end function eval_diff_tree_array( - tree::Node{T1}, + tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum, direction::Integer; @@ -62,13 +62,13 @@ function eval_diff_tree_array( ) where {T1<:Number,T2<:Number} T = promote_type(T1, T2) @warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)." - tree = convert(Node{T}, tree) + tree = convert(constructorof(typeof(tree)){T}, tree) cX = T.(cX) return eval_diff_tree_array(tree, cX, operators, direction; turbo=turbo) end function _eval_diff_tree_array( - tree::Node{T}, + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer, @@ -102,7 +102,7 @@ function _eval_diff_tree_array( end function diff_deg0_eval( - tree::Node{T}, cX::AbstractMatrix{T}, direction::Integer + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, direction::Integer )::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number} const_part = deg0_eval(tree, cX)[1] derivative_part = if ((!tree.constant) && tree.feature == direction) @@ -114,7 +114,7 @@ function diff_deg0_eval( end function diff_deg1_eval( - tree::Node{T}, + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, diff_op::dF, @@ -139,7 +139,7 @@ function diff_deg1_eval( end function diff_deg2_eval( - tree::Node{T}, + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, 
diff_op::dF, @@ -169,7 +169,7 @@ function diff_deg2_eval( end """ - eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false, turbo::Bool=false) + eval_grad_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false, turbo::Bool=false) Compute the forward-mode derivative of an expression, using a similar structure and optimization to eval_tree_array. `variable` specifies whether @@ -178,7 +178,7 @@ to every constant in the expression. # Arguments -- `tree::Node{T}`: The expression tree to evaluate. +- `tree::AbstractExpressionNode{T}`: The expression tree to evaluate. - `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point. - `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff` must be `true`. This is needed to create the derivative operations. @@ -192,7 +192,7 @@ to every constant in the expression. the gradient, and whether the evaluation completed as normal (or encountered a nan or inf). """ function eval_grad_tree_array( - tree::Node{T}, + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false, @@ -200,22 +200,34 @@ function eval_grad_tree_array( )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number} assert_autodiff_enabled(operators) n_gradients = variable ? size(cX, 1) : count_constants(tree) - index_tree = index_constants(tree, UInt16(0)) - return eval_grad_tree_array( - tree, - Val(n_gradients), - index_tree, - cX, - operators, - (variable ? Val(true) : Val(false)), - (turbo ? Val(true) : Val(false)), - ) + if variable + return eval_grad_tree_array( + tree, + Val(n_gradients), + nothing, + cX, + operators, + Val(true), + (turbo ? Val(true) : Val(false)), + ) + else + index_tree = index_constants(tree) + return eval_grad_tree_array( + tree, + Val(n_gradients), + index_tree, + cX, + operators, + Val(false), + (turbo ? 
Val(true) : Val(false)), + ) + end end function eval_grad_tree_array( - tree::Node{T}, + tree::AbstractExpressionNode{T}, ::Val{n_gradients}, - index_tree::NodeIndex, + index_tree::Union{NodeIndex,Nothing}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{variable}, @@ -231,7 +243,7 @@ function eval_grad_tree_array( end function eval_grad_tree_array( - tree::Node{T1}, + tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; variable::Bool=false, @@ -239,7 +251,7 @@ function eval_grad_tree_array( ) where {T1<:Number,T2<:Number} T = promote_type(T1, T2) return eval_grad_tree_array( - convert(Node{T}, tree), + convert(constructorof(typeof(tree)){T}, tree), convert(AbstractMatrix{T}, cX), operators; variable=variable, @@ -248,9 +260,9 @@ function eval_grad_tree_array( end function _eval_grad_tree_array( - tree::Node{T}, + tree::AbstractExpressionNode{T}, ::Val{n_gradients}, - index_tree::NodeIndex, + index_tree::Union{NodeIndex,Nothing}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{variable}, @@ -288,9 +300,9 @@ function _eval_grad_tree_array( end function grad_deg0_eval( - tree::Node{T}, + tree::AbstractExpressionNode{T}, ::Val{n_gradients}, - index_tree::NodeIndex, + index_tree::Union{NodeIndex,Nothing}, cX::AbstractMatrix{T}, ::Val{variable}, )::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,n_gradients} @@ -306,16 +318,20 @@ function grad_deg0_eval( return (const_part, zero_mat, true) end - index = variable ? tree.feature : index_tree.constant_index + index = if variable + tree.feature + else + (index_tree === nothing ? 
zero(UInt16) : index_tree.val::UInt16) + end derivative_part = zero_mat derivative_part[index, :] .= one(T) return (const_part, derivative_part, true) end function grad_deg1_eval( - tree::Node{T}, + tree::AbstractExpressionNode{T}, ::Val{n_gradients}, - index_tree::NodeIndex, + index_tree::Union{NodeIndex,Nothing}, cX::AbstractMatrix{T}, op::F, diff_op::dF, @@ -326,7 +342,13 @@ function grad_deg1_eval( AbstractVector{T},AbstractMatrix{T},Bool } where {T<:Number,F,dF,variable,turbo,n_gradients} (cumulator, dcumulator, complete) = eval_grad_tree_array( - tree.l, Val(n_gradients), index_tree.l, cX, operators, Val(variable), Val(turbo) + tree.l, + Val(n_gradients), + (index_tree === nothing ? index_tree : index_tree.l), + cX, + operators, + Val(variable), + Val(turbo), ) @return_on_false2 complete cumulator dcumulator @@ -343,9 +365,9 @@ function grad_deg1_eval( end function grad_deg2_eval( - tree::Node{T}, + tree::AbstractExpressionNode{T}, ::Val{n_gradients}, - index_tree::NodeIndex, + index_tree::Union{NodeIndex,Nothing}, cX::AbstractMatrix{T}, op::F, diff_op::dF, @@ -356,11 +378,23 @@ function grad_deg2_eval( AbstractVector{T},AbstractMatrix{T},Bool } where {T<:Number,F,dF,variable,turbo,n_gradients} (cumulator1, dcumulator1, complete) = eval_grad_tree_array( - tree.l, Val(n_gradients), index_tree.l, cX, operators, Val(variable), Val(turbo) + tree.l, + Val(n_gradients), + (index_tree === nothing ? index_tree : index_tree.l), + cX, + operators, + Val(variable), + Val(turbo), ) @return_on_false2 complete cumulator1 dcumulator1 (cumulator2, dcumulator2, complete2) = eval_grad_tree_array( - tree.r, Val(n_gradients), index_tree.r, cX, operators, Val(variable), Val(turbo) + tree.r, + Val(n_gradients), + (index_tree === nothing ? 
index_tree : index_tree.r), + cX, + operators, + Val(variable), + Val(turbo), ) @return_on_false2 complete2 cumulator1 dcumulator1 diff --git a/src/EvaluationHelpers.jl b/src/EvaluationHelpers.jl index 1b2f81da..bbde3b81 100644 --- a/src/EvaluationHelpers.jl +++ b/src/EvaluationHelpers.jl @@ -2,13 +2,13 @@ module EvaluationHelpersModule import Base: adjoint import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum -import ..EquationModule: Node +import ..EquationModule: AbstractExpressionNode import ..EvaluateEquationModule: eval_tree_array import ..EvaluateEquationDerivativeModule: eval_grad_tree_array # Evaluation: """ - (tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false) + (tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false) Evaluate a binary tree (equation) over a given data matrix. The operators contain all of the operators used in the tree. @@ -23,13 +23,13 @@ operators contain all of the operators used in the tree. Any NaN, Inf, or other failure during the evaluation will result in the entire output array being set to NaN. """ -function (tree::Node)(X, operators::OperatorEnum; kws...) +function (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...) out, did_finish = eval_tree_array(tree, X, operators; kws...) !did_finish && (out .= convert(eltype(out), NaN)) return out end """ - (tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true) + (tree::AbstractExpressionNode)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true) # Arguments - `X::AbstractArray`: The input data to evaluate the tree on. @@ -49,26 +49,30 @@ end that it was not defined for. You can change this behavior by setting `throw_errors=false`. """ -function (tree::Node)(X, operators::GenericOperatorEnum; kws...) +function (tree::AbstractExpressionNode)(X, operators::GenericOperatorEnum; kws...) 
out, did_finish = eval_tree_array(tree, X, operators; kws...) !did_finish && return nothing return out end # Gradients: -function _grad_evaluator(tree::Node, X, operators::OperatorEnum; variable=true, kws...) +function _grad_evaluator( + tree::AbstractExpressionNode, X, operators::OperatorEnum; variable=true, kws... +) _, grad, did_complete = eval_grad_tree_array( tree, X, operators; variable=variable, kws... ) !did_complete && (grad .= convert(eltype(grad), NaN)) return grad end -function _grad_evaluator(tree::Node, X, operators::GenericOperatorEnum; kws...) +function _grad_evaluator( + tree::AbstractExpressionNode, X, operators::GenericOperatorEnum; kws... +) return error("Gradients are not implemented for `GenericOperatorEnum`.") end """ - (tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true) + (tree::AbstractExpressionNode{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true) Compute the forward-mode derivative of an expression, using a similar structure and optimization to eval_tree_array. `variable` specifies whether @@ -88,6 +92,7 @@ to every constant in the expression. - `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation, the gradient, and whether the evaluation completed as normal (or encountered a nan or inf). """ -Base.adjoint(tree::Node) = ((args...; kws...) -> _grad_evaluator(tree, args...; kws...)) +Base.adjoint(tree::AbstractExpressionNode) = + ((args...; kws...) 
-> _grad_evaluator(tree, args...; kws...)) end diff --git a/src/ExtensionInterface.jl b/src/ExtensionInterface.jl index a2d16cb9..63cb7418 100644 --- a/src/ExtensionInterface.jl +++ b/src/ExtensionInterface.jl @@ -1,12 +1,8 @@ module ExtensionInterfaceModule -import ..EquationModule: Node, DEFAULT_NODE_TYPE -import ..OperatorEnumModule: AbstractOperatorEnum -import ..UtilsModule: isgood, isbad, @return_on_false - function node_to_symbolic(args...; kws...) return error( - "Please load the `SymbolicUtils` package to use `node_to_symbolic(::Node, ::AbstractOperatorEnum; kws...)`.", + "Please load the `SymbolicUtils` package to use `node_to_symbolic(::AbstractExpressionNode, ::AbstractOperatorEnum; kws...)`.", ) end function symbolic_to_node(args...; kws...) diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 8d18ef4b..5fd4394a 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -1,7 +1,7 @@ module OperatorEnumConstructionModule import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum -import ..EquationModule: string_tree, Node +import ..EquationModule: string_tree, Node, GraphNode, AbstractExpressionNode, constructorof import ..EvaluateEquationModule: eval_tree_array import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient import ..EvaluationHelpersModule: _grad_evaluator @@ -29,31 +29,30 @@ const ALREADY_DEFINED_BINARY_OPERATORS = (; ) const LATEST_VARIABLE_NAMES = Ref{Vector{String}}(String[]) -function Base.show(io::IO, tree::Node) +function Base.show(io::IO, tree::AbstractExpressionNode) latest_operators_type = LATEST_OPERATORS_TYPE.x + kwargs = (variable_names=LATEST_VARIABLE_NAMES.x,) if latest_operators_type == IsNothing - return print(io, string_tree(tree; variable_names=LATEST_VARIABLE_NAMES.x)) + return print(io, string_tree(tree; kwargs...)) elseif latest_operators_type == IsOperatorEnum latest_operators = LATEST_OPERATORS.x::OperatorEnum 
- return print( - io, string_tree(tree, latest_operators; variable_names=LATEST_VARIABLE_NAMES.x) - ) + return print(io, string_tree(tree, latest_operators; kwargs...)) else latest_operators = LATEST_OPERATORS.x::GenericOperatorEnum - return print( - io, string_tree(tree, latest_operators; variable_names=LATEST_VARIABLE_NAMES.x) - ) + return print(io, string_tree(tree, latest_operators; kwargs...)) end end -function (tree::Node)(X; kws...) +function (tree::AbstractExpressionNode)(X; kws...) Base.depwarn( "The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.", - :Node, + :AbstractExpressionNode, ) latest_operators_type = LATEST_OPERATORS_TYPE.x - if latest_operators_type == IsNothing + + latest_operators_type == IsNothing && error("Please use the `tree(X, operators; kws...)` syntax instead.") - elseif latest_operators_type == IsOperatorEnum + + if latest_operators_type == IsOperatorEnum latest_operators = LATEST_OPERATORS.x::OperatorEnum return tree(X, latest_operators; kws...) else @@ -62,32 +61,33 @@ function (tree::Node)(X; kws...) end end -function _grad_evaluator(tree::Node, X; kws...) +function _grad_evaluator(tree::AbstractExpressionNode, X; kws...) Base.depwarn( "The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.", - :Node, + :AbstractExpressionNode, ) latest_operators_type = LATEST_OPERATORS_TYPE.x # return _grad_evaluator(tree, X, $operators; kws...) - if latest_operators_type == IsNothing + latest_operators_type == IsNothing && error("Please use the `tree'(X, operators; kws...)` syntax instead.") - elseif latest_operators_type == IsOperatorEnum - latest_operators = LATEST_OPERATORS.x::OperatorEnum - return _grad_evaluator(tree, X, latest_operators; kws...) 
- else + latest_operators_type == IsGenericOperatorEnum && error("Gradients are not implemented for `GenericOperatorEnum`.") - end + + latest_operators = LATEST_OPERATORS.x::OperatorEnum + return _grad_evaluator(tree, X, latest_operators; kws...) end function set_default_variable_names!(variable_names::Vector{String}) - return LATEST_VARIABLE_NAMES.x = variable_names + return LATEST_VARIABLE_NAMES.x = copy(variable_names) end -function create_evaluation_helpers!(operators::OperatorEnum) +Base.@deprecate create_evaluation_helpers! set_default_operators! + +function set_default_operators!(operators::OperatorEnum) LATEST_OPERATORS.x = operators return LATEST_OPERATORS_TYPE.x = IsOperatorEnum end -function create_evaluation_helpers!(operators::GenericOperatorEnum) +function set_default_operators!(operators::GenericOperatorEnum) LATEST_OPERATORS.x = operators return LATEST_OPERATORS_TYPE.x = IsGenericOperatorEnum end @@ -96,81 +96,104 @@ function lookup_op(@nospecialize(f), ::Val{degree}) where {degree} mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING if !haskey(mapping, f) error( - "Convenience constructor for `Node` using operator `$(f)` is out-of-date. " * - "Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " * - "`define_helper_functions=true` and pass `$(f)`.", + "Convenience constructor for operator `$(f)` is out-of-date. 
" * + "Please create an `OperatorEnum` (or `GenericOperatorEnum`) containing " * + "the operator `$(f)` which will define the `$(f)` -> `Int` mapping.", ) end return mapping[f] end -function _extend_unary_operator(f::Symbol, type_requirements) +function _extend_unary_operator(f::Symbol, type_requirements, internal) quote + @gensym _constructorof _AbstractExpressionNode quote - function $($f)(l::Node{T})::Node{T} where {T<:$($type_requirements)} + if $$internal + import ..EquationModule.constructorof as $_constructorof + import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode + else + using DynamicExpressions: + constructorof as $_constructorof, + AbstractExpressionNode as $_AbstractExpressionNode + end + + function $($f)( + l::N + ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} return if (l.degree == 0 && l.constant) - Node(T; val=$($f)(l.val::T)) + $_constructorof(N)(T; val=$($f)(l.val::T)) else latest_op_idx = $($lookup_op)($($f), Val(1)) - Node(latest_op_idx, l) + $_constructorof(N)(latest_op_idx, l) end end end end end -function _extend_binary_operator(f::Symbol, type_requirements, build_converters) +function _extend_binary_operator(f::Symbol, type_requirements, build_converters, internal) quote + @gensym _constructorof _AbstractExpressionNode quote - function $($f)(l::Node{T}, r::Node{T}) where {T<:$($type_requirements)} + if $$internal + import ..EquationModule.constructorof as $_constructorof + import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode + else + using DynamicExpressions: + constructorof as $_constructorof, + AbstractExpressionNode as $_AbstractExpressionNode + end + + function $($f)( + l::N, r::N + ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} if (l.degree == 0 && l.constant && r.degree == 0 && r.constant) - Node(T; val=$($f)(l.val::T, r.val::T)) + $_constructorof(N)(T; val=$($f)(l.val::T, r.val::T)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - Node(latest_op_idx, 
l, r) + $_constructorof(N)(latest_op_idx, l, r) end end - function $($f)(l::Node{T}, r::T) where {T<:$($type_requirements)} + function $($f)( + l::N, r::T + ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} if l.degree == 0 && l.constant - Node(T; val=$($f)(l.val::T, r)) + $_constructorof(N)(T; val=$($f)(l.val::T, r)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - Node(latest_op_idx, l, Node(T; val=r)) + $_constructorof(N)(latest_op_idx, l, $_constructorof(N)(T; val=r)) end end - function $($f)(l::T, r::Node{T}) where {T<:$($type_requirements)} + function $($f)( + l::T, r::N + ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} if r.degree == 0 && r.constant - Node(T; val=$($f)(l, r.val::T)) + $_constructorof(N)(T; val=$($f)(l, r.val::T)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - Node(latest_op_idx, Node(T; val=l), r) + $_constructorof(N)(latest_op_idx, $_constructorof(N)(T; val=l), r) end end if $($build_converters) # Converters: - function $($f)( - l::Node{T1}, r::Node{T2} - ) where {T1<:$($type_requirements),T2<:$($type_requirements)} - T = promote_type(T1, T2) - l = convert(Node{T}, l) - r = convert(Node{T}, r) - return $($f)(l, r) + function $($f)(l::$_AbstractExpressionNode, r::$_AbstractExpressionNode) + if l isa GraphNode || r isa GraphNode + error( + "Refusing to promote `GraphNode` as it would break the graph structure. " * + "Please convert to a common type first.", + ) + end + return $($f)(promote(l, r)...) 
end function $($f)( - l::Node{T1}, r::T2 + l::$_AbstractExpressionNode{T1}, r::T2 ) where {T1<:$($type_requirements),T2<:$($type_requirements)} - T = promote_type(T1, T2) - l = convert(Node{T}, l) - r = convert(T, r) - return $($f)(l, r) + return $($f)(l, convert(T1, r)) end function $($f)( - l::T1, r::Node{T2} + l::T1, r::$_AbstractExpressionNode{T2} ) where {T1<:$($type_requirements),T2<:$($type_requirements)} - T = promote_type(T1, T2) - l = convert(T, l) - r = convert(Node{T}, r) - return $($f)(l, r) + return $($f)(convert(T2, l), r) end end end @@ -178,34 +201,47 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters) end function _extend_operators(operators, skip_user_operators, kws, __module__::Module) - empty_old_operators = - if length(kws) == 1 && :empty_old_operators in map(x -> x.args[1], kws) - @assert kws[1].head == :(=) - kws[1].args[2] - elseif length(kws) > 0 - error( - "You passed the keywords $(kws), but only `empty_old_operators` is supported.", - ) - else - true - end - binary_ex = _extend_binary_operator(:f, :type_requirements, :build_converters) - unary_ex = _extend_unary_operator(:f, :type_requirements) + if !all(x -> first(x.args) ∈ (:empty_old_operators, :internal), kws) + error( + "You passed the keywords $(kws), but only `empty_old_operators`, `internal` are supported.", + ) + end + + empty_old_operators_idx = findfirst(x -> first(x.args) == :empty_old_operators, kws) + internal_idx = findfirst(x -> first(x.args) == :internal, kws) + + empty_old_operators = if empty_old_operators_idx !== nothing + @assert kws[empty_old_operators_idx].head == :(=) + kws[empty_old_operators_idx].args[2] + else + true + end + + internal = if internal_idx !== nothing + @assert kws[internal_idx].head == :(=) + kws[internal_idx].args[2]::Bool + else + false + end + + @gensym f skip type_requirements build_converters binary_exists unary_exists + binary_ex = _extend_binary_operator(f, type_requirements, build_converters, internal) + 
unary_ex = _extend_unary_operator(f, type_requirements, internal) return quote - local type_requirements - local build_converters - local binary_exists - local unary_exists + local $type_requirements + local $build_converters + local $binary_exists + local $unary_exists if isa($operators, $OperatorEnum) - type_requirements = Number - build_converters = true - binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum - unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum + $type_requirements = Number + $build_converters = true + $binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum + $unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum else - type_requirements = Any - build_converters = false - binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum - unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum + $type_requirements = Any + $build_converters = false + $binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum + $unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum end if $(empty_old_operators) # Trigger errors if operators are not yet defined: @@ -213,39 +249,39 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu empty!($(LATEST_UNARY_OPERATOR_MAPPING)) end for (op, func) in enumerate($(operators).binops) - local f = Symbol(func) - local skip = false - if isdefined(Base, f) - f = :(Base.$(f)) + local $f = Symbol(func) + local $skip = false + if isdefined(Base, $f) + $f = :(Base.$($f)) elseif $(skip_user_operators) - skip = true + $skip = true else - f = :($($__module__).$(f)) + $f = :($($__module__).$($f)) end $(LATEST_BINARY_OPERATOR_MAPPING)[func] = op - skip && continue + $skip && continue # Avoid redefining methods: - if !haskey(unary_exists, func) + if !haskey($unary_exists, func) eval($binary_ex) - unary_exists[func] = true + $(unary_exists)[func] = true end end for (op, func) in 
enumerate($(operators).unaops) - local f = Symbol(func) - local skip = false - if isdefined(Base, f) - f = :(Base.$(f)) + local $f = Symbol(func) + local $skip = false + if isdefined(Base, $f) + $f = :(Base.$($f)) elseif $(skip_user_operators) - skip = true + $skip = true else - f = :($($__module__).$(f)) + $f = :($($__module__).$($f)) end $(LATEST_UNARY_OPERATOR_MAPPING)[func] = op - skip && continue + $skip && continue # Avoid redefining methods: - if !haskey(binary_exists, func) + if !haskey($binary_exists, func) eval($unary_ex) - binary_exists[func] = true + $(binary_exists)[func] = true end end end @@ -279,6 +315,8 @@ end Similar to `@extend_operators`, but only extends operators already defined in `Base`. +`kws` can include `empty_old_operators` which is default `true`, +and `internal` which is default `false`. """ macro extend_operators_base(operators, kws...) ex = _extend_operators(operators, true, kws, __module__) @@ -299,8 +337,8 @@ end empty_old_operators::Bool=true) Construct an `OperatorEnum` object, defining the possible expressions. This will also -redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`. -It will automatically compute derivatives with `Zygote.jl`. +redefine operators for `AbstractExpressionNode` types, as well as `show`, `print`, and +`(::AbstractExpressionNode)(X)`. It will automatically compute derivatives with `Zygote.jl`. # Arguments - `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary @@ -343,7 +381,7 @@ function OperatorEnum(; if define_helper_functions @extend_operators_base operators empty_old_operators = empty_old_operators - create_evaluation_helpers!(operators) + set_default_operators!(operators) end return operators @@ -355,8 +393,8 @@ end Construct a `GenericOperatorEnum` object, defining possible expressions. Unlike `OperatorEnum`, this enum one will work arbitrary operators and data types. 
-This will also redefine operators for `Node` types, as well as `show`, `print`, -and `(::Node)(X)`. +This will also redefine operators for `AbstractExpressionNode` types, as well as `show`, `print`, +and `(::AbstractExpressionNode)(X)`. # Arguments - `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary @@ -383,10 +421,31 @@ function GenericOperatorEnum(; if define_helper_functions @extend_operators_base operators empty_old_operators = empty_old_operators - create_evaluation_helpers!(operators) + set_default_operators!(operators) end return operators end +# Predefine the most common operators so the errors +# are more informative +function _overload_common_operators() + #! format: off + operators = OperatorEnum( + Function[+, -, *, /, ^, max, min, mod], + Function[ + sin, cos, tan, exp, log, log1p, log2, log10, sqrt, cbrt, abs, sinh, + cosh, tanh, atan, asinh, acosh, round, sign, floor, ceil, + ], + Function[], + Function[], + ) + #! format: on + @extend_operators(operators, empty_old_operators = false, internal = true) + empty!(LATEST_UNARY_OPERATOR_MAPPING) + empty!(LATEST_BINARY_OPERATOR_MAPPING) + return nothing +end +_overload_common_operators() + end diff --git a/src/SimplifyEquation.jl b/src/SimplifyEquation.jl index 709c1939..ccd6bd80 100644 --- a/src/SimplifyEquation.jl +++ b/src/SimplifyEquation.jl @@ -1,15 +1,24 @@ module SimplifyEquationModule -import ..EquationModule: Node, copy_node +import ..EquationModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node! 
+import ..EquationUtilsModule: tree_mapreduce, is_node_constant import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: isbad, isgood _una_op_kernel(f::F, l::T) where {F,T} = f(l) _bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r) -# Simplify tree +is_commutative(::typeof(*)) = true +is_commutative(::typeof(+)) = true +is_commutative(_) = false + +is_subtraction(::typeof(-)) = true +is_subtraction(_) = false + +# This is only defined for `Node` as it is not possible for +# `GraphNode`. function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T} - # NOTE: (const (+*-) const) already accounted for. Call simplify_tree before. + # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before. # ((const + var) + const) => (const + var) # ((const * var) * const) => (const * var) # ((const - var) - const) => (const - var) @@ -25,12 +34,8 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where end top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant) - if tree.degree == 2 && - (operators.binops[tree.op] == (*) || operators.binops[tree.op] == (+)) && - top_level_constant - + if tree.degree == 2 && is_commutative(operators.binops[tree.op]) && top_level_constant # TODO: Does this break SymbolicRegression.jl due to the different names of operators? - op = tree.op # Put the constant in r. Need to assume var in left for simplification assumption. if tree.l.constant @@ -56,16 +61,17 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where end end - if tree.degree == 2 && operators.binops[tree.op] == (-) && top_level_constant + if tree.degree == 2 && is_subtraction(operators.binops[tree.op]) && top_level_constant + # Currently just simplifies subtraction. (can't assume both plus and sub are operators) # Not commutative, so use different op. 
if tree.l.constant - if tree.r.degree == 2 && operators.binops[tree.r.op] == (-) + if tree.r.degree == 2 && tree.op == tree.r.op if tree.r.l.constant #(const - (const - var)) => (var - const) l = tree.l r = tree.r - simplified_const = -(l.val::T - r.l.val::T) #neg(sub(l.val, r.l.val)) + simplified_const = (r.l.val::T - l.val::T) #neg(sub(l.val, r.l.val)) tree.l = tree.r.r tree.r = l tree.r.val = simplified_const @@ -79,7 +85,7 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where end end else #tree.r.constant is true - if tree.l.degree == 2 && operators.binops[tree.l.op] == (-) + if tree.l.degree == 2 && tree.op == tree.l.op if tree.l.l.constant #((const - var) - const) => (const - var) l = tree.l @@ -102,43 +108,29 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where return tree end -# Simplify tree -# TODO: This will get much more powerful with the tree-map functions. -function simplify_tree(tree::Node{T}, operators::AbstractOperatorEnum) where {T} - if tree.degree == 1 - tree.l = simplify_tree(tree.l, operators) - if tree.l.degree == 0 && tree.l.constant - l = tree.l.val::T - if isgood(l) - out = _una_op_kernel(operators.unaops[tree.op], l) - if isbad(out) - return tree - end - return Node(T; val=convert(T, out)) - end - end - elseif tree.degree == 2 - tree.l = simplify_tree(tree.l, operators) - tree.r = simplify_tree(tree.r, operators) - constantsBelow = ( - tree.l.degree == 0 && tree.l.constant && tree.r.degree == 0 && tree.r.constant - ) - if constantsBelow - # NaN checks: - l = tree.l.val::T - r = tree.r.val::T - if isbad(l) || isbad(r) - return tree - end - - # Actually compute: - out = _bin_op_kernel(operators.binops[tree.op], l, r) - if isbad(out) - return tree - end - return Node(T; val=convert(T, out)) - end +function combine_children!(operators, p::N, c::N...) 
where {T,N<:AbstractExpressionNode{T}} + all(is_node_constant, c) || return p + vals = map(n -> n.val::T, c) + all(isgood, vals) || return p + out = if length(c) == 1 + _una_op_kernel(operators.unaops[p.op], vals...) + else + _bin_op_kernel(operators.binops[p.op], vals...) end + isgood(out) || return p + new_node = constructorof(N)(T; val=convert(T, out)) + set_node!(p, new_node) + return p +end + +# Simplify tree +function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum) + tree = tree_mapreduce( + identity, + (p, c...) -> combine_children!(operators, p, c...), + tree, + constructorof(typeof(tree)); + ) return tree end diff --git a/src/Utils.jl b/src/Utils.jl index 320227fa..12e47cc3 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -81,7 +81,7 @@ isgood(x) = true isbad(x) = !isgood(x) """ - @memoize_on tree function my_function_on_tree(tree::Node) + @memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode) ... end @@ -90,23 +90,36 @@ function with an additional `id_map` argument. When passed this argument (an IdDict()), it will use use the `id_map` to avoid recomputing the same value for the same node in a tree. Use this to automatically create functions that work with trees that have shared child nodes. + +Can optionally take a `postprocess` function, which will be applied to the +result of the function before returning it, taking the result as the +first argument and a boolean for whether the result was memoized as the +second argument. This is useful for functions that need to count the number +of unique nodes in a tree, for example. """ -macro memoize_on(tree, def) - idmap_def = _memoize_on(tree, def) +macro memoize_on(tree, args...) + if length(args) ∉ (1, 2) + error("Expected 2 or 3 arguments to @memoize_on") + end + postprocess = length(args) == 1 ? :((r, _) -> r) : args[1] + def = length(args) == 1 ? 
args[1] : args[2] + idmap_def = _memoize_on(tree, postprocess, def) + return quote $(esc(def)) # The normal function $(esc(idmap_def)) # The function with an id_map argument end end -function _memoize_on(tree::Symbol, def::Expr) +function _memoize_on(tree::Symbol, postprocess, def) sdef = splitdef(def) # Add an id_map argument - push!(sdef[:args], :(id_map::IdDict)) + push!(sdef[:args], :(id_map::AbstractDict)) f_name = sdef[:name] - # Add id_map argument to all calls within the function: + # Forward id_map argument to all calls of the same function + # within the function body: sdef[:body] = postwalk(sdef[:body]) do ex if @capture(ex, f_(args__)) if f == f_name @@ -117,10 +130,19 @@ function _memoize_on(tree::Symbol, def::Expr) end # Wrap the function body in a get!(id_map, tree) do ... end block: + @gensym key is_memoized result body sdef[:body] = quote - get!(id_map, $(tree)) do - $(sdef[:body]) + $key = objectid($tree) + $is_memoized = haskey(id_map, $key) + function $body() + return $(sdef[:body]) + end + $result = if $is_memoized + @inbounds(id_map[$key]) + else + id_map[$key] = $body() end + return $postprocess($result, $is_memoized) end return combinedef(sdef) @@ -150,7 +172,7 @@ macro with_memoize(def, id_map) end end -function _add_idmap_to_call(def::Expr, id_map::Expr) +function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr}) @assert def.head == :call return Expr(:call, def.args[1], def.args[2:end]..., id_map) end @@ -170,4 +192,11 @@ function deprecate_varmap(variable_names, varMap, func_name) return variable_names end +""" + Undefined + +Just a type like `Nothing` to differentiate from a literal `Nothing`. 
+""" +struct Undefined end + end diff --git a/src/base.jl b/src/base.jl index adafeb3e..f825e4f5 100644 --- a/src/base.jl +++ b/src/base.jl @@ -23,11 +23,17 @@ import Base: reduce, sum import Compat: @inline, Returns -import ..UtilsModule: @memoize_on, @with_memoize +import ..UtilsModule: @memoize_on, @with_memoize, Undefined """ - tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing) - tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing) + tree_mapreduce( + f::Function, + [f_branch::Function,] + op::Function, + tree::AbstractNode, + f_on_shared::Function=(result, is_shared) -> result, + break_sharing::Val=Val(false), + ) Map a function over a tree and aggregate the result using an operator `op`. `op` should be defined with inputs `(parent, child...) ->` so that it can aggregate @@ -35,6 +41,7 @@ both unary and binary operators. `op` will not be called for leafs of the tree. This differs from a normal `mapreduce` in that it allows different treatment for parent nodes than children nodes. If this is not necessary, you may use the regular `mapreduce` instead. +The argument `break_sharing` can be used to break connections in a `GraphNode`. You can also provide separate functions for leaf (variable/constant) nodes and branch (operator) nodes. 
@@ -69,24 +76,26 @@ function tree_mapreduce( f::F, op::G, tree::AbstractNode, - result_type::Type{RT}=Nothing; - preserve_sharing::Bool=false, -) where {F<:Function,G<:Function,RT} - return tree_mapreduce(f, f, op, tree, result_type; preserve_sharing) + result_type::Type=Undefined; + f_on_shared::H=(result, is_shared) -> result, + break_sharing=Val(false), +) where {F<:Function,G<:Function,H<:Function} + return tree_mapreduce(f, f, op, tree, result_type; f_on_shared, break_sharing) end function tree_mapreduce( f_leaf::F1, f_branch::F2, op::G, tree::AbstractNode, - result_type::Type{RT}=Nothing; - preserve_sharing::Bool=false, -) where {F1<:Function,F2<:Function,G<:Function,RT} + result_type::Type{RT}=Undefined; + f_on_shared::H=(result, is_shared) -> result, + break_sharing::Val=Val(false), +) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT} # Trick taken from here: # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5 # to speed up recursive closure - @memoize_on t function inner(inner, t) + @memoize_on t f_on_shared function inner(inner, t) if t.degree == 0 return @inline(f_leaf(t)) elseif t.degree == 1 @@ -96,16 +105,30 @@ function tree_mapreduce( end end - RT == Nothing && - preserve_sharing && - throw(ArgumentError("Need to specify `result_type` if you use `preserve_sharing`.")) + sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false) + + RT == Undefined && + sharing && + throw(ArgumentError("Need to specify `result_type` if nodes are shared..")) - if preserve_sharing && RT != Nothing - return @with_memoize inner(inner, tree) IdDict{typeof(tree),RT}() + if sharing && RT != Undefined + d = allocate_id_map(tree, RT) + return @with_memoize inner(inner, tree) d else return inner(inner, tree) end end +function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} + d = Dict{UInt,RT}() + # Preallocate maximum storage (counting with duplicates is fast) + N = length(tree; 
break_sharing=Val(true)) + sizehint!(d, N) + return d +end + +# TODO: Raise Julia issue for this. +# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here! +# I think it's because `setindex!` is declared with `@nospecialize` in IdDict. """ any(f::Function, tree::AbstractNode) @@ -123,20 +146,61 @@ function any(f::F, tree::AbstractNode) where {F<:Function} end end -function Base.:(==)(a::AbstractNode, b::AbstractNode)::Bool +function Base.:(==)(a::AbstractExpressionNode, b::AbstractExpressionNode) + return Base.:(==)(promote(a, b)...) +end +function Base.:(==)(a::N, b::N)::Bool where {N<:AbstractExpressionNode} + if preserve_sharing(N) + return inner_is_equal_shared(a, b, Dict{UInt,Nothing}(), Dict{UInt,Nothing}()) + else + return inner_is_equal(a, b) + end +end +function inner_is_equal(a, b) (degree = a.degree) != b.degree && return false if degree == 0 return isequal_deg0(a, b) elseif degree == 1 - return isequal_deg1(a, b) && a.l == b.l + return isequal_deg1(a, b) && inner_is_equal(a.l, b.l) + else + return isequal_deg2(a, b) && inner_is_equal(a.l, b.l) && inner_is_equal(a.r, b.r) + end +end +function inner_is_equal_shared(a, b, id_map_a, id_map_b) + id_a = objectid(a) + id_b = objectid(b) + has_a = haskey(id_map_a, id_a) + has_b = haskey(id_map_b, id_b) + + if has_a && has_b + return true + elseif has_a ⊻ has_b + return false + end + + (degree = a.degree) != b.degree && return false + + result = if degree == 0 + isequal_deg0(a, b) + elseif degree == 1 + isequal_deg1(a, b) && inner_is_equal_shared(a.l, b.l, id_map_a, id_map_b) else - return isequal_deg2(a, b) && a.l == b.l && a.r == b.r + isequal_deg2(a, b) && + inner_is_equal_shared(a.l, b.l, id_map_a, id_map_b) && + inner_is_equal_shared(a.r, b.r, id_map_a, id_map_b) end + + id_map_a[id_a] = nothing + id_map_b[id_b] = nothing + + return result end -@inline isequal_deg1(a::Node, b::Node) = a.op == b.op -@inline isequal_deg2(a::Node, b::Node) = a.op == b.op -@inline function isequal_deg0(a::Node{T1}, 
b::Node{T2}) where {T1,T2} +@inline isequal_deg1(a::AbstractExpressionNode, b::AbstractExpressionNode) = a.op == b.op +@inline isequal_deg2(a::AbstractExpressionNode, b::AbstractExpressionNode) = a.op == b.op +@inline function isequal_deg0( + a::AbstractExpressionNode{T1}, b::AbstractExpressionNode{T2} +) where {T1,T2} (constant = a.constant) != b.constant && return false if constant return a.val::T1 == b.val::T2 @@ -150,12 +214,33 @@ end ############################################################################### """ - foreach(f::Function, tree::Node) + count_nodes(tree::AbstractNode)::Int + +Count the number of nodes in the tree. +""" +function count_nodes(tree::AbstractNode; break_sharing=Val(false)) + return tree_mapreduce( + _ -> 1, + +, + tree, + Int64; + f_on_shared=(c, is_shared) -> is_shared ? 0 : c, + break_sharing, + ) +end + +""" + foreach(f::Function, tree::AbstractNode) Apply a function to each node in a tree. """ -function foreach(f::Function, tree::AbstractNode) - return tree_mapreduce(t -> (@inline(f(t)); nothing), Returns(nothing), tree) +function foreach( + f::F, tree::AbstractNode; break_sharing::Val=Val(false) +) where {F<:Function} + tree_mapreduce( + t -> (@inline(f(t)); nothing), Returns(nothing), tree, Nothing; break_sharing + ) + return nothing end """ @@ -167,10 +252,14 @@ specifying the `result_type` of `map_fnc` so the resultant array can be preallocated. 
""" function filter_map( - filter_fnc::F, map_fnc::G, tree::AbstractNode, result_type::Type{GT} + filter_fnc::F, + map_fnc::G, + tree::AbstractNode, + result_type::Type{GT}; + break_sharing::Val=Val(false), ) where {F<:Function,G<:Function,GT} - stack = Array{GT}(undef, count(filter_fnc, tree)) - filter_map!(filter_fnc, map_fnc, stack, tree) + stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing)) + filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing) return stack::Vector{GT} end @@ -180,10 +269,14 @@ end Equivalent to `filter_map`, but stores the results in a preallocated array. """ function filter_map!( - filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::AbstractNode -) where {GT} + filter_fnc::F, + map_fnc::G, + destination::Vector{GT}, + tree::AbstractNode; + break_sharing::Val=Val(false), +) where {GT,F<:Function,G<:Function} pointer = Ref(0) - foreach(tree) do t + foreach(tree; break_sharing) do t if @inline(filter_fnc(t)) map_result = @inline(map_fnc(t))::GT @inbounds destination[pointer.x += 1] = map_result @@ -197,115 +290,166 @@ end Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`. """ -function filter(f::F, tree::AbstractNode) where {F<:Function} - return filter_map(f, identity, tree, typeof(tree)) +function filter(f::F, tree::AbstractNode; break_sharing::Val=Val(false)) where {F<:Function} + return filter_map(f, identity, tree, typeof(tree); break_sharing) end -collect(tree::AbstractNode) = filter(Returns(true), tree) +function collect(tree::AbstractNode; break_sharing::Val=Val(false)) + return filter(Returns(true), tree; break_sharing) +end """ map(f::Function, tree::AbstractNode, result_type::Type{RT}=Nothing) Map a function over a tree and return a flat array of the results in depth-first order. 
-Pre-specifying the `result_type` of the function can be used to avoid extra allocations, +Pre-specifying the `result_type` of the function can be used to avoid extra allocations. """ -function map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing) where {F<:Function,RT} +function map( + f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false) +) where {F<:Function,RT} if RT == Nothing - return f.(collect(tree)) + return f.(collect(tree; break_sharing)) else - return filter_map(Returns(true), f, tree, result_type) + return filter_map(Returns(true), f, tree, result_type; break_sharing) end end -function count(f::F, tree::AbstractNode; init=0) where {F<:Function} - return tree_mapreduce(t -> @inline(f(t)) ? 1 : 0, +, tree) + init +function count( + f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false) +) where {F<:Function} + return tree_mapreduce( + t -> @inline(f(t)) ? 1 : 0, + +, + tree, + Int64; + f_on_shared=(c, is_shared) -> is_shared ? 0 : c, + break_sharing, + ) + init end -function sum(f::F, tree::AbstractNode; init=0) where {F<:Function} - return tree_mapreduce(f, +, tree) + init +function sum( + f::F, + tree::AbstractNode; + init=0, + return_type=Undefined, + f_on_shared=(c, is_shared) -> is_shared ? (false * c) : c, + break_sharing::Val=Val(false), +) where {F<:Function} + if preserve_sharing(typeof(tree)) + @assert typeof(return_type) !== Undefined "Must specify `return_type` as a keyword argument to `sum` if `preserve_sharing` is true." + end + return tree_mapreduce(f, +, tree, return_type; f_on_shared, break_sharing) + init end all(f::F, tree::AbstractNode) where {F<:Function} = !any(t -> !@inline(f(t)), tree) -function mapreduce(f::F, op::G, tree::AbstractNode) where {F<:Function,G<:Function} - return tree_mapreduce(f, (n...) -> reduce(op, n), tree) +function mapreduce( + f::F, + op::G, + tree::AbstractNode; + return_type=Undefined, + f_on_shared=(c, is_shared) -> is_shared ? 
(false * c) : c, + break_sharing::Val=Val(false), +) where {F<:Function,G<:Function} + if preserve_sharing(typeof(tree)) + @assert typeof(return_type) !== Undefined "Must specify `return_type` as a keyword argument to `mapreduce` if `preserve_sharing` is true." + end + return tree_mapreduce( + f, (n...) -> reduce(op, n), tree, return_type; f_on_shared, break_sharing + ) end isempty(::AbstractNode) = false -iterate(root::AbstractNode) = (root, collect(root)[(begin + 1):end]) +function iterate(root::AbstractNode) + return (root, collect(root; break_sharing=Val(true))[(begin + 1):end]) +end iterate(::AbstractNode, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack) in(item, tree::AbstractNode) = any(t -> t == item, tree) -length(tree::AbstractNode) = sum(Returns(1), tree) -function hash(tree::Node{T}) where {T} +function length(tree::AbstractNode; break_sharing::Val=Val(false)) + return count_nodes(tree; break_sharing) +end + +function hash(tree::AbstractExpressionNode{T}) where {T} return tree_mapreduce( t -> t.constant ? hash((0, t.val::T)) : hash((1, t.feature)), t -> hash((t.degree + 1, t.op)), (n...) -> hash(n), tree, + UInt64; + f_on_shared=(cur_hash, is_shared) -> + is_shared ? hash((:shared, cur_hash)) : cur_hash, ) end """ - copy_node(tree::Node; preserve_sharing::Bool=false) + copy_node(tree::AbstractExpressionNode) Copy a node, recursively copying all children nodes. This is more efficient than the built-in copy. -With `preserve_sharing=true`, this will also -preserve linkage between a node and -multiple parents, whereas without, this would create -duplicate child node copies. id_map is a map from `objectid(tree)` to `copy(tree)`. We check against the map before making a new copy; otherwise we can simply reference the existing copy. [Thanks to Ted Hopp.](https://stackoverflow.com/questions/49285475/how-to-copy-a-full-non-binary-tree-including-loops) - -Note that this will *not* preserve loops in graphs. 
""" -function copy_node(tree::N; preserve_sharing::Bool=false) where {T,N<:Node{T}} +function copy_node( + tree::N; break_sharing::Val=Val(false) +) where {T,N<:AbstractExpressionNode{T}} return tree_mapreduce( - t -> t.constant ? Node(; val=t.val::T) : Node(T; feature=t.feature), + t -> if t.constant + constructorof(N)(; val=t.val::T) + else + constructorof(N)(T; feature=t.feature) + end, identity, - (p, c...) -> Node(p.op, c...), + (p, c...) -> constructorof(N)(p.op, c...), tree, N; - preserve_sharing, + break_sharing, ) end -copy(tree::Node; kws...) = copy_node(tree; kws...) +function copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false)) + return copy_node(tree; break_sharing) +end """ - convert(::Type{Node{T1}}, n::Node{T2}) where {T1,T2} + convert(::Type{AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2} -Convert a `Node{T2}` to a `Node{T1}`. -This will recursively convert all children nodes to `Node{T1}`, +Convert a `AbstractExpressionNode{T2}` to a `AbstractExpressionNode{T1}`. +This will recursively convert all children nodes to `AbstractExpressionNode{T1}`, using `convert(T1, tree.val)` at constant nodes. # Arguments -- `::Type{Node{T1}}`: Type to convert to. -- `tree::Node{T2}`: Node to convert. +- `::Type{AbstractExpressionNode{T1}}`: Type to convert to. +- `tree::AbstractExpressionNode{T2}`: AbstractExpressionNode to convert. """ function convert( - ::Type{Node{T1}}, tree::Node{T2}; preserve_sharing::Bool=false -) where {T1,T2} - if T1 == T2 + ::Type{N1}, tree::N2 +) where {T1,T2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2}} + if N1 === N2 return tree end return tree_mapreduce( t -> if t.constant - Node(T1, 0, true, convert(T1, t.val::T2)) + constructorof(N1)(T1, 0, true, convert(T1, t.val::T2)) else - Node(T1, 0, false, nothing, t.feature) + constructorof(N1)(T1, 0, false, nothing, t.feature) end, identity, - (p, c...) -> Node(p.degree, false, nothing, 0, p.op, c...), + (p, c...) 
-> constructorof(N1)(p.degree, false, nothing, 0, p.op, c...), tree, - Node{T1}; - preserve_sharing, + N1, ) end -(::Type{Node{T}})(tree::Node; kws...) where {T} = convert(Node{T}, tree; kws...) +function convert( + ::Type{N1}, tree::N2 +) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}} + return convert(constructorof(N1){T2}, tree) +end +function (::Type{N})(tree::AbstractExpressionNode) where {N<:AbstractExpressionNode} + return convert(N, tree) +end for func in (:reduce, :foldl, :foldr, :mapfoldl, :mapfoldr) @eval begin diff --git a/src/deprecated.jl b/src/deprecated.jl index 2f52568c..de03e91a 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,3 +1,4 @@ import Base: @deprecate @deprecate set_constants set_constants! +@deprecate simplify_tree simplify_tree! diff --git a/src/precompile.jl b/src/precompile.jl index 9e4e9bb1..3101e0ae 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -111,10 +111,8 @@ function test_functions_on_trees(::Type{T}, operators) where {T} tree = Node(i_bin, a8, a7) end tree = convert(Node{T}, tree) - for preserve_sharing in [true, false] - tree = copy_node(tree; preserve_sharing) - set_node!(tree, copy_node(tree; preserve_sharing)) - end + tree = copy_node(tree) + set_node!(tree, tree) string_tree(tree, operators) count_nodes(tree) @@ -126,7 +124,7 @@ function test_functions_on_trees(::Type{T}, operators) where {T} get_constants(tree) set_constants!(tree, get_constants(tree)) combine_operators(tree, operators) - simplify_tree(tree, operators) + simplify_tree!(tree, operators) return nothing end diff --git a/test/test_base.jl b/test/test_base.jl index 73506362..3c779634 100644 --- a/test/test_base.jl +++ b/test/test_base.jl @@ -40,7 +40,7 @@ end @test objectid(first(collect(ctree))) == objectid(ctree) @test typeof(collect(ctree)) == Vector{Node{Float64}} @test length(collect(ctree)) == 24 - @test sum((t -> (t.degree == 0 && t.constant) ? 
t.val : 0.0).(collect(ctree))) == 11.6 + @test sum((t -> (t.degree == 0 && t.constant) ? t.val : 0.0).(collect(ctree))) ≈ 11.6 end @testset "count" begin @@ -111,7 +111,6 @@ end @test sum(map(_ -> 2, ctree)) == 24 * 2 @test sum(map(t -> t.degree == 1, ctree)) == 1 @test length(unique(map(objectid, copy_node(tree)))) == 24 - @test length(unique(map(objectid, copy_node(tree; preserve_sharing=true)))) == 24 - 3 map(t -> (t.degree == 0 && t.constant) ? (t.val *= 2) : nothing, ctree) @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 local T = fieldtype(typeof(ctree), :degree) diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl index f5922fda..f72cd471 100644 --- a/test/test_derivatives.jl +++ b/test/test_derivatives.jl @@ -158,7 +158,7 @@ tree = equation3(nx1, nx2, nx3) """Check whether the ordering of constant_list is the same as the ordering of node_index.""" function check_tree(tree::Node, node_index::NodeIndex, constant_list::AbstractVector) if tree.degree == 0 - (!tree.constant) || tree.val == constant_list[node_index.constant_index] + (!tree.constant) || tree.val == constant_list[node_index.val::UInt16] elseif tree.degree == 1 check_tree(tree.l, node_index.l, constant_list) else diff --git a/test/test_equality.jl b/test/test_equality.jl index ee6fb70d..220e63c3 100644 --- a/test/test_equality.jl +++ b/test/test_equality.jl @@ -14,11 +14,21 @@ tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1) same_tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1) @test tree == same_tree -copied_tree = copy_node(tree; preserve_sharing=true) +x1 = GraphNode(; feature=1) +x2 = GraphNode(; feature=2) +x3 = GraphNode(; feature=3) +tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1) +copied_tree = copy_node(tree) @test tree == copied_tree -copied_tree2 = copy_node(tree; preserve_sharing=false) -@test tree == copied_tree2 +copied_tree2 = copy_node(tree; break_sharing=Val(true)) +@test tree != copied_tree2 + +# Another way to 
break shared nodes is by converting +# to `Node` and back: +copied_tree3 = GraphNode(Node(tree)) +@test copied_tree2 == copied_tree3 +@test tree != copied_tree3 modifed_tree = x1 + x2 * x1 - log(x2 * 3.2) + 1.5 * cos(x2 / x1) @test tree != modifed_tree @@ -33,12 +43,13 @@ modified_tree4 = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 * x1) modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2) @test tree != modified_tree5 -# Type should not matter if equivalent in the promoted type: -f64_tree = x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1) -f32_tree = x1 + x2 * x3 - log(x2 * 3.0f0) + 1.5f0 * cos(x2 / x1) -@test typeof(f64_tree) == Node{Float64} -@test typeof(f32_tree) == Node{Float32} +f64_tree = GraphNode{Float64}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) +f32_tree = GraphNode{Float32}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) +@test typeof(f64_tree) == GraphNode{Float64} +@test typeof(f32_tree) == GraphNode{Float32} -@test convert(Node{Float64}, f32_tree) == f64_tree +@test convert(GraphNode{Float64}, f32_tree) == f64_tree @test f64_tree == f32_tree + +@test Node(f64_tree) == Node(f32_tree) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index e0a39269..7c44791c 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -92,7 +92,7 @@ end # op(, ) tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0)) - @test repr(tree) == "(3.0 + 4.0)" + @test repr(tree) == "3.0 + 4.0" tree = convert(Node{T}, tree) truth = T(3.0f0) + T(4.0f0) @test DynamicExpressions.EvaluateEquationModule.deg2_l0_r0_eval( diff --git a/test/test_graphs.jl b/test/test_graphs.jl new file mode 100644 index 00000000..ec4e0dcd --- /dev/null +++ b/test/test_graphs.jl @@ -0,0 +1,396 @@ +using DynamicExpressions +using DynamicExpressions: NodeIndex +using Test +include("test_params.jl") + +@testset "Constructing trees with shared nodes" begin + operators = OperatorEnum(; + binary_operators=(+, -, *, ^, /, greater), unary_operators=(cos, exp, sin) + 
) + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] + + base_tree = cos(x1 - 3.2 * x2) - x1^3.2 + tree = sin(base_tree) + base_tree + + # The base tree is exactly the same: + @test tree.l.l === tree.r + @test hash(tree.l.l) == hash(tree.r) + + # Now, let's change something in the base tree: + old_tree = deepcopy(tree) + base_tree.l.l = x3 * x2 - 1.5 + + # Should change: + @test string_tree(tree, operators) != string_tree(old_tree, operators) + + # But the linkage should be preserved: + @test tree.l.l === tree.r + @test hash(tree.l.l) == hash(tree.r) + + # When we copy with the normal copy, the sharing breaks: + copy_without_sharing = copy_node(tree; break_sharing=Val(true)) + @test !(copy_without_sharing.l.l === copy_without_sharing.r) + + # But with the sharing preserved in the copy, it should be the same: + copy_with_sharing = copy_node(tree) + @test copy_with_sharing.l.l === copy_with_sharing.r + + # We can also tweak the new tree, and the edits should be propagated: + copied_base_tree = copy_with_sharing.l.l + # (First, assert that it is the same as the old base tree) + @test string_tree(copied_base_tree, operators) == string_tree(base_tree, operators) + + # Now, let's tweak the new tree's base tree: + copied_base_tree.l.l = x1 * x2 * 5.2 - exp(x3) + # "exp" should appear *twice* now: + copy_with_sharing + @test length(collect(eachmatch(r"exp", string_tree(copy_with_sharing, operators)))) == 2 + @test copy_with_sharing.l.l === copy_with_sharing.r + @test hash(copy_with_sharing.l.l) == hash(copy_with_sharing.r) + @test string_tree(copy_with_sharing.l.l, operators) != string_tree(base_tree, operators) + + # We also test whether `convert` breaks shared children. + # The node type here should be Float64. 
+ @test typeof(tree).parameters[1] == Float64 + # Let's convert to Float32: + float32_tree = convert(GraphNode{Float32}, tree) + @test typeof(float32_tree).parameters[1] == Float32 + # The linkage should be kept: + @test float32_tree.l.l === float32_tree.r +end + +@testset "Macro tests" begin + # We also do tests of the macros related to generating functions that preserve + # sharing: + @eval begin + expr_eql(x::LineNumberNode, y::LineNumberNode) = true # Ignore line numbers + expr_eql(x::QuoteNode, y::QuoteNode) = + x == y ? true : (println(x, " and ", y, " are not equal"); false) + expr_eql(x::Number, y::Number) = + x == y ? true : (println(x, " and ", y, " are not equal"); false) + function expr_eql(x::Symbol, y::Symbol) + if x == y + return true + else + sx = string(x) + sy = string(y) + result = if startswith(sx, r"#") + occursin(sy, sx) + elseif startswith(sy, r"#") + occursin(sx, sy) + else + false + end + !result && println(x, " and ", y, " are not equal") + return result + end + end + function expr_eql(x::Expr, y::Expr) + # Remove line numbers from the arguments: + x.args = filter(c -> !isa(c, LineNumberNode), x.args) + y.args = filter(c -> !isa(c, LineNumberNode), y.args) + + if expr_eql(x.head, y.head) && + length(x.args) == length(y.args) && + all(expr_eql.(x.args, y.args)) + return true + else + println(x, " and ", y, " are not equal") + return false + end + end + expr_eql(x, y) = error("Unexpected type: $(typeof(x)) or $(typeof(y))") + end + + @testset "Macro testing utils" begin + # First, assert this test actually works: + @test !expr_eql( + :(_convert(Node{T1}, tree)), + :(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())), + ) + end + + @testset "@with_memoize" begin + ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize( + _convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}() + ) + true_ex = quote + _convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}()) + end + + @test expr_eql(ex, true_ex) + end + + @testset "@memoize_on" begin + ex 
= @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree ((x, _) -> x) function _copy_node( + tree::Node{T} + )::Node{T} where {T} + if tree.degree == 0 + if tree.constant + Node(; val=copy(tree.val::T)) + else + Node(T; feature=copy(tree.feature)) + end + elseif tree.degree == 1 + Node(copy(tree.op), _copy_node(tree.l)) + else + Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) + end + end + true_ex = quote + function _copy_node(tree::Node{T})::Node{T} where {T} + if tree.degree == 0 + if tree.constant + Node(; val=copy(tree.val::T)) + else + Node(T; feature=copy(tree.feature)) + end + elseif tree.degree == 1 + Node(copy(tree.op), _copy_node(tree.l)) + else + Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) + end + end + function _copy_node(tree::Node{T}, id_map::AbstractDict;)::Node{T} where {T} + key = objectid(tree) + is_memoized = haskey(id_map, key) + function body() + return begin + if tree.degree == 0 + if tree.constant + Node(; val=copy(tree.val::T)) + else + Node(T; feature=copy(tree.feature)) + end + elseif tree.degree == 1 + Node(copy(tree.op), _copy_node(tree.l, id_map)) + else + Node( + copy(tree.op), + _copy_node(tree.l, id_map), + _copy_node(tree.r, id_map), + ) + end + end + end + result = if is_memoized + begin + $(Expr(:inbounds, true)) + local val = id_map[key] + $(Expr(:inbounds, :pop)) + val + end + else + id_map[key] = body() + end + return (((x, _) -> begin + x + end)(result, is_memoized)) + end + end + @test expr_eql(ex, true_ex) + end +end + +@testset "Operations on graphs" begin + operators = OperatorEnum(; + binary_operators=(+, -, *, ^, /), unary_operators=(cos, exp, sin) + ) + function make_tree() + x1, x2 = GraphNode(Float64; feature=1), GraphNode(Float64; feature=2) + base_tree = + cos(x1 - 3.2 * x2) - x1^3.5 + + GraphNode(3, GraphNode(; val=0.3), GraphNode(; val=0.9)) + tree = sin(base_tree) + base_tree + return base_tree, tree + end + + @testset "Strings" begin + x1 = GraphNode(Float64; feature=1) + n = 
x1 + x1 + @test string_tree(copy_node(n; break_sharing=Val(true)), operators) == "x1 + x1" + @test string_tree(n, operators) == "x1 + {x1}" + + # Copying the node explicitly changes the behavior: + x1 = GraphNode(Float64; feature=1) + n = x1 + copy(x1) + @test string_tree(n, operators) == "x1 + x1" + + # But, note that if we do a type conversion, the connection is also lost: + x1 = GraphNode(Float64; feature=1) + n = x1 + 3.5 * x1 + @test string_tree(n, operators) == "x1 + (3.5 * {x1})" + @test string_tree(copy_node(n; break_sharing=Val(true)), operators) == + "x1 + (3.5 * x1)" + + base_tree, tree = make_tree() + + s = string_tree(copy_node(base_tree; break_sharing=Val(true)), operators) + @test s == "(cos(x1 - (3.2 * x2)) - (x1 ^ 3.5)) + (0.3 * 0.9)" + s = string_tree(base_tree, operators) + @test s == "(cos(x1 - (3.2 * x2)) - ({x1} ^ 3.5)) + (0.3 * 0.9)" + s = string_tree(tree, operators) + @test s == + "sin((cos(x1 - (3.2 * x2)) - ({x1} ^ 3.5)) + (0.3 * 0.9)) + {((cos(x1 - (3.2 * x2)) - ({x1} ^ 3.5)) + (0.3 * 0.9))}" + # ^ Note the {} indicating shared subexpression + end + + @testset "Counting nodes" begin + base_tree, tree = make_tree() + + @test count_nodes(base_tree; break_sharing=Val(true)) == 14 + @test count_nodes(tree; break_sharing=Val(true)) == 30 + + # One shared node, so -1: + @test count_nodes(base_tree) == 13 + + # sin and the +, so +2 from above: + @test count_nodes(tree) == 15 + + @test count_depth(tree) == 8 + @test count_depth(base_tree) == 6 + end + + @testset "Simplification" begin + base_tree, tree = make_tree() + simplify_tree!(base_tree, operators) + # Simplifies both sides without error: + @test string_tree(tree, operators) == + "sin((cos(x1 - (3.2 * x2)) - ({x1} ^ 3.5)) + 0.27) + {((cos(x1 - (3.2 * x2)) - ({x1} ^ 3.5)) + 0.27)}" + end + + @testset "Hashing" begin + x = GraphNode(; feature=1) + x2 = GraphNode(; feature=1) + tree = GraphNode(1, x, x) + tree2 = GraphNode(1, x2, x2) + @test hash(tree) == hash(tree2) + @test hash(tree) != 
hash(copy_node(tree; break_sharing=Val(true))) + @test hash(copy_node(tree; break_sharing=Val(true))) == + hash(copy_node(tree; break_sharing=Val(true))) + @test hash(Node(tree)) == hash(copy_node(tree; break_sharing=Val(true))) + end + + @testset "Constants" begin + base_tree, tree = make_tree() + @test count_constants(tree) == 4 + @test count_constants(copy_node(tree; break_sharing=Val(true))) == 8 + @test count_constants(copy_node(tree)) == 4 + @test get_constants(tree) == [3.2, 3.5, 0.3, 0.9] + @test get_constants(copy_node(tree; break_sharing=Val(true))) == + [3.2, 3.5, 0.3, 0.9, 3.2, 3.5, 0.3, 0.9] + + c = get_constants(tree) + c .+= 1.2 + set_constants!(tree, c) + @test get_constants(tree) == [4.4, 4.7, 1.5, 2.1] + # Note that this means all constants in the shared expression are set the same way: + @test get_constants(copy_node(tree; break_sharing=Val(true))) == + [4.4, 4.7, 1.5, 2.1, 4.4, 4.7, 1.5, 2.1] + + # What about a single constant? + f1 = GraphNode(; val=1.0) + @test get_constants(f1) == [1.0] + f2 = GraphNode(1, f1, f1) + @test get_constants(f2) == [1.0] + @test string_tree(f2, operators) == "1.0 + {1.0}" + + # Now, we can test indexing: + base_tree, tree = make_tree() + node_index = index_constants(tree) + @eval function get_indices(n::NodeIndex{T}) where {T} + return filter_map(t -> t.degree == 0 && !iszero(t.val), t -> t.val, n, T) + end + # Note that the node index does not use shared nodes, + # as this would be redundant (since we are already + # tracing the original expression when using a node index): + @test get_indices(node_index) == [1, 2, 3, 4, 1, 2, 3, 4] + @test tree.r.l.l.l.r.l == GraphNode(Float64; val=3.2) + @test node_index.r.l.l.l.r.l.val == 1 + end + + @testset "Various base utils" begin + x = GraphNode(; feature=1) + tree = GraphNode(1, x, x) + @test collect(tree) == [tree, x] + @test collect(tree; break_sharing=Val(true)) == [tree, x, x] + @test filter(node -> node.degree == 0, tree) == [x] + @test filter(node -> node.degree == 
0, tree; break_sharing=Val(true)) == [x, x] + @test count(_ -> true, tree) == 2 + @test count(_ -> true, tree; break_sharing=Val(true)) == 3 + c = typeof(x)[] + foreach(tree) do n + push!(c, n) + end + @test c == [tree, x] + c = typeof(x)[] + foreach(tree; break_sharing=Val(true)) do n + push!(c, n) + end + @test c == [tree, x, x] + + # Note that iterate always turns on `break_sharing`! + c = typeof(x)[] + for n in tree + push!(c, n) + end + @test c == [tree, x, x] + @test length(tree) == 2 + @test length(tree; break_sharing=Val(true)) == 3 + @test map(t -> t.degree == 0 ? 1 : 0, tree) == [0, 1] + @test map(t -> t.degree == 0 ? 1 : 0, tree; break_sharing=Val(true)) == [0, 1, 1] + @test mapreduce(t -> t.degree == 0 ? 1 : 0, +, tree; return_type=Int) == 1 + @test mapreduce(t -> t.degree == 0 ? 1 : 0, +, tree; break_sharing=Val(true)) == 2 + @test sum(t -> t.degree == 0 ? 1 : 0, tree; return_type=Int) == 1 + @test sum(t -> t.degree == 0 ? 1 : 0, tree; break_sharing=Val(true)) == 2 + @test filter_map(t -> t.degree == 0, t -> 1, tree, Int) == [1] + @test filter_map(t -> t.degree == 0, t -> 1, tree, Int; break_sharing=Val(true)) == + [1, 1] + end + + @testset "(lack of) automatic conversion" begin + operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos, exp]) + x = GraphNode(Float32; feature=1) + tree = x + 1.0 + @test tree.l === x + @test typeof(tree) === GraphNode{Float32} + + # Detect error from Float32(1im) + @test_throws InexactError x + 1im + @test x + 0im isa GraphNode{Float32} + + # Detect error from Int(1.5) + x = GraphNode(Int; feature=1) + @test_throws InexactError x + 1.5 + + x = GraphNode(ComplexF64; feature=1) + @test hash(x + 1) == hash(GraphNode(1, x, GraphNode(; val=1.0 + 0.0im))) + @test (x + 1).l === x + end +end + +@testset "Joint operations" begin + operators = OperatorEnum(; + binary_operators=(+, -, *, ^, /), unary_operators=(cos, exp, sin) + ) + x = GraphNode(Float64; feature=1) + y = Node(Float64; feature=1) + + @test x == y 
+ + @test promote(x, y) isa Tuple{typeof(x),typeof(x)} + + # Node with GraphNode - will convert both + tree1 = sin(x) * x + tree2 = sin(y) * y + @test tree1 != tree2 + + # GraphNode against GraphNode + tree1 = sin(x) * x + tree2 = sin(x) * x + @test tree1 == tree2 + + # Is aware of different shared structure + tree2 = sin(x) * GraphNode(Float64; feature=1) + @test tree1 != tree2 +end diff --git a/test/test_preserve_multiple_parents.jl b/test/test_preserve_multiple_parents.jl deleted file mode 100644 index 2065e694..00000000 --- a/test/test_preserve_multiple_parents.jl +++ /dev/null @@ -1,150 +0,0 @@ -using DynamicExpressions -using Test -include("test_params.jl") - -@testset "Trees with shared nodes" begin - operators = OperatorEnum(; - binary_operators=(+, -, *, ^, /, greater), unary_operators=(cos, exp, sin) - ) - x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") - - base_tree = cos(x1 - 3.2 * x2) - x1^3.2 - tree = sin(base_tree) + base_tree - - # The base tree is exactly the same: - @test tree.l.l === tree.r - @test hash(tree.l.l) == hash(tree.r) - - # Now, let's change something in the base tree: - old_tree = deepcopy(tree) - base_tree.l.l = x3 * x2 - 1.5 - - # Should change: - @test string_tree(tree, operators) != string_tree(old_tree, operators) - - # But the linkage should be preserved: - @test tree.l.l === tree.r - @test hash(tree.l.l) == hash(tree.r) - - # When we copy with the normal copy, the sharing breaks: - copy_without_sharing = copy_node(tree) - @test !(copy_without_sharing.l.l === copy_without_sharing.r) - - # But with the sharing preserved in the copy, it should be the same: - copy_with_sharing = copy_node(tree; preserve_sharing=true) - @test copy_with_sharing.l.l === copy_with_sharing.r - - # We can also tweak the new tree, and the edits should be propagated: - copied_base_tree = copy_with_sharing.l.l - # (First, assert that it is the same as the old base tree) - @test string_tree(copied_base_tree, operators) == string_tree(base_tree, operators) - 
- # Now, let's tweak the new tree's base tree: - copied_base_tree.l.l = x1 * x2 * 5.2 - exp(x3) - # "exp" should appear *twice* now: - copy_with_sharing - @test length(collect(eachmatch(r"exp", string_tree(copy_with_sharing, operators)))) == 2 - @test copy_with_sharing.l.l === copy_with_sharing.r - @test hash(copy_with_sharing.l.l) == hash(copy_with_sharing.r) - @test string_tree(copy_with_sharing.l.l, operators) != string_tree(base_tree, operators) - - # We also test whether `convert` breaks shared children. - # The node type here should be Float64. - @test typeof(tree).parameters[1] == Float64 - # Let's convert to Float32: - float32_tree = convert(Node{Float32}, tree; preserve_sharing=true) - @test typeof(float32_tree).parameters[1] == Float32 - # The linkage should be kept: - @test float32_tree.l.l === float32_tree.r -end - -# We also do tests of the macros related to generating functions that preserve -# sharing: -expr_eql(x::LineNumberNode, y::LineNumberNode) = true # Ignore line numbers -expr_eql(x::QuoteNode, y::QuoteNode) = x == y -expr_eql(x::Number, y::Number) = x == y -expr_eql(x::Symbol, y::Symbol) = x == y -function expr_eql(x::Expr, y::Expr) - # Remove line numbers from the arguments: - x.args = filter(c -> !isa(c, LineNumberNode), x.args) - y.args = filter(c -> !isa(c, LineNumberNode), y.args) - - return expr_eql(x.head, y.head) && - length(x.args) == length(y.args) && - all(expr_eql.(x.args, y.args)) -end -expr_eql(x, y) = error("Unexpected type: $(typeof(x)) or $(typeof(y))") - -@testset "Macro testing utils" begin - # First, assert this test actually works: - @test !expr_eql( - :(_convert(Node{T1}, tree)), - :(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())), - ) -end - -@testset "@with_memoize" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize( - _convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}() - ) - true_ex = quote - _convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}()) - end - - @test expr_eql(ex, true_ex) -end - 
-@testset "@memoize_on" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree function _copy_node( - tree::Node{T} - )::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val::T)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - true_ex = quote - function _copy_node(tree::Node{T})::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val::T)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - function _copy_node(tree::Node{T}, id_map::IdDict;)::Node{T} where {T} - get!(id_map, tree) do - begin - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val::T)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l, id_map)) - else - Node( - copy(tree.op), - _copy_node(tree.l, id_map), - _copy_node(tree.r, id_map), - ) - end - end - end - end - end - @test expr_eql(ex, true_ex) -end diff --git a/test/test_print.jl b/test/test_print.jl index 566ce462..c946789f 100644 --- a/test/test_print.jl +++ b/test/test_print.jl @@ -14,14 +14,14 @@ f = (x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0 tree = f(Node("x1"), Node("x2"), Node("x3")) s = repr(tree) -true_s = "((sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0)" +true_s = "(sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0" @test s == true_s # TODO: Next, we test that custom varMaps work: s = string_tree(tree, operators; variable_names=["v1", "v2", "v3"]) -true_s = "((sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0)" +true_s = "(sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0" @test s == true_s for unaop in [safe_log, safe_log2, 
safe_log10, safe_log1p, safe_sqrt, safe_acosh] @@ -39,7 +39,7 @@ for binop in [safe_pow, ^] default_params..., binary_operators=(+, *, /, -, binop), unary_operators=(cos,) ) minitree = Node(5, Node("x1"), Node("x2")) - @test string_tree(minitree, opts) == "(x1 ^ x2)" + @test string_tree(minitree, opts) == "x1 ^ x2" end @testset "Test print_tree function" begin @@ -56,7 +56,7 @@ end end close(pipe.in) s = read(pipe.out, String) - @test s == "((x1 * x1) + 0.5)\n" + @test s == "(x1 * x1) + 0.5\n" end end @@ -71,6 +71,7 @@ end x1, x2, x3 = [Node(; feature=i) for i in 1:3] tree = sin(x1 * 1.0) @test string_tree(tree, operators) == "sin(x1 * 1.0)" + x1 = convert(Node{ComplexF64}, x1) tree = sin(x1 * (1.0 + 2.0im)) @test string_tree(tree, operators) == "sin(x1 * (1.0 + 2.0im))" tree = my_custom_op(x1, 1.0 + 2.0im) @@ -84,19 +85,18 @@ end @extend_operators operators x1, x2, x3 = [Node(Float64; feature=i) for i in 1:3] tree = x1 * x1 + 0.5 - @test string_tree(tree, operators; f_constant=Returns("TEST")) == "((x1 * x1) + TEST)" - @test string_tree(tree, operators; f_variable=Returns("TEST")) == - "((TEST * TEST) + 0.5)" + @test string_tree(tree, operators; f_constant=Returns("TEST")) == "(x1 * x1) + TEST" + @test string_tree(tree, operators; f_variable=Returns("TEST")) == "(TEST * TEST) + 0.5" @test string_tree( tree, operators; f_variable=Returns("TEST"), f_constant=Returns("TEST2") - ) == "((TEST * TEST) + TEST2)" + ) == "(TEST * TEST) + TEST2" # Try printing with a precision: tree = x1 * x1 + π f_constant(val::Float64, args...) = string(round(val; digits=2)) - @test string_tree(tree, operators; f_constant=f_constant) == "((x1 * x1) + 3.14)" + @test string_tree(tree, operators; f_constant=f_constant) == "(x1 * x1) + 3.14" f_constant(val::Float64, args...) 
= string(round(val; digits=4)) - @test string_tree(tree, operators; f_constant=f_constant) == "((x1 * x1) + 3.1416)" + @test string_tree(tree, operators; f_constant=f_constant) == "(x1 * x1) + 3.1416" end @testset "Test variable names" begin @@ -107,11 +107,11 @@ end "k1", "k2", "k3" ] tree = x1 * x2 + x3 - @test string(tree) == "((k1 * k2) + k3)" + @test string(tree) == "(k1 * k2) + k3" empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x) - @test string(tree) == "((x1 * x2) + x3)" + @test string(tree) == "(x1 * x2) + x3" # Check if we can pass the wrong number of variable names: set_default_variable_names!(["k1"]) - @test string(tree) == "((k1 * x2) + x3)" + @test string(tree) == "(k1 * x2) + x3" empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x) end diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 1fdd93c5..c8841458 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -1,16 +1,21 @@ include("test_params.jl") using DynamicExpressions, Test +import DynamicExpressions.EquationModule: strip_brackets import SymbolicUtils: simplify, Symbolic import Random: MersenneTwister import Base: ≈ +strip_brackets(a::String) = String(strip_brackets(collect(a))) + function Base.:≈(a::String, b::String) + a = strip_brackets(a) + b = strip_brackets(b) a = replace(a, r"\s+" => "") b = replace(b, r"\s+" => "") return a == b end -simplify_tree = DynamicExpressions.SimplifyEquationModule.simplify_tree +simplify_tree! = DynamicExpressions.SimplifyEquationModule.simplify_tree! 
combine_operators = DynamicExpressions.SimplifyEquationModule.combine_operators binary_operators = (+, -, /, *) @@ -82,7 +87,7 @@ output3, flag3 = eval_tree_array(tree_copy2, X, operators) @test isapprox(output1, output3, atol=1e-2 * sqrt(N)) ############################################################################### -## Hit other parts of `simplify_tree` and `combine_operators` to increase +## Hit other parts of `simplify_tree!` and `combine_operators` to increase ## code coverage: operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin)) x1, x2, x3 = [Node(; feature=i) for i in 1:3] @@ -90,12 +95,12 @@ x1, x2, x3 = [Node(; feature=i) for i in 1:3] # unary operator applied to constant => constant: tree = Node(1, Node(; val=0.0)) @test repr(tree) ≈ "cos(0.0)" -@test repr(simplify_tree(tree, operators)) ≈ "1.0" +@test repr(simplify_tree!(tree, operators)) ≈ "1.0" # except when the result is a NaN, then we don't change it: tree = Node(1, Node(; val=NaN)) @test repr(tree) ≈ "cos(NaN)" -@test repr(simplify_tree(tree, operators)) ≈ "cos(NaN)" +@test repr(simplify_tree!(tree, operators)) ≈ "cos(NaN)" # the same as above, but inside a binary tree. tree = diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl index ec7b8a19..e1210cee 100644 --- a/test/test_tree_construction.jl +++ b/test/test_tree_construction.jl @@ -90,17 +90,32 @@ for unaop in [cos, exp, safe_log, safe_log2, safe_log10, safe_sqrt, relu, gamma, end end -# We also test whether we can set a node equal to another node: -operators = OperatorEnum(; default_params...) 
-tree = Node(Float64; feature=1) -tree2 = exp(Node(; feature=2) / 3.2) + Node(; feature=1) * 2.0 - -# Test printing works: -io = IOBuffer() -print(io, tree2) -s = String(take!(io)) -@test s == "(exp(x2 / 3.2) + (x1 * 2.0))" - -set_node!(tree, tree2) -@test tree !== tree2 -@test repr(tree) == repr(tree2) +@testset "Set a node equal to another node" begin + operators = OperatorEnum(; default_params...) + tree = Node(Float64; feature=1) + tree2 = exp(Node(Float64; feature=2) / 3.2) + Node(Float64; feature=1) * 2.0 + + # Test printing works: + io = IOBuffer() + print(io, tree2) + s = String(take!(io)) + @test s == "exp(x2 / 3.2) + (x1 * 2.0)" + + set_node!(tree, tree2) + @test tree !== tree2 + @test repr(tree) == repr(tree2) +end + +@testset "Miscellaneous" begin + operators = OperatorEnum(; default_params...) + for N in (Node, GraphNode) + tree = N{ComplexF64}(; val=1) + @test typeof(tree.val) === ComplexF64 + + x = N{BigFloat}(; feature=1) + @test_throws AssertionError N{Float32}(1, x) + @test N{BigFloat}(1, x) == N(1, x) + @test typeof(N(1, x, N{Float32}(; val=1))) === N{BigFloat} + @test typeof(N(1, N{Float32}(; val=1), x)) === N{BigFloat} + end +end diff --git a/test/test_utils.jl b/test/test_utils.jl index 8523ad56..0b5409f5 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -34,7 +34,7 @@ set_constants!(tree, [1.0]) tree = x1 + Node(; val=0.0) - sin(x2 - Node(; val=0.5)) @test get_constants(tree) == [0.0, 0.5] set_constants!(tree, [1.0, 2.0]) -@test repr(tree) == "((x1 + 1.0) - sin(x2 - 2.0))" +@test repr(tree) == "(x1 + 1.0) - sin(x2 - 2.0)" # Ensure that fill_similar is type stable x = randn(Float32, 3, 10) diff --git a/test/unittest.jl b/test/unittest.jl index 94382530..c10bcbaf 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -49,7 +49,7 @@ end end @safetestset "Test sharing-preserving copy" begin - include("test_preserve_multiple_parents.jl") + include("test_graphs.jl") end @safetestset "Test equation utils" begin