diff --git a/Project.toml b/Project.toml index 56d883a9..63906584 100644 --- a/Project.toml +++ b/Project.toml @@ -17,17 +17,20 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" [extensions] DynamicExpressionsBumperExt = "Bumper" -DynamicExpressionsLoopVectorizationExt = "LoopVectorization" DynamicExpressionsOptimExt = "Optim" DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils" DynamicExpressionsZygoteExt = "Zygote" +DynamicExpressionsLoopVectorizationExt = "LoopVectorization" +DynamicExpressionsVisualizeExt = ["Plots","GraphRecipes"] [compat] Bumper = "0.6" @@ -35,7 +38,6 @@ ChainRulesCore = "1" Compat = "3.37, 4" DispatchDoctor = "0.4" Interfaces = "0.3" -LoopVectorization = "0.12" MacroTools = "0.4, 0.5" Optim = "0.19, 1" PackageExtensionCompat = "1" @@ -47,7 +49,6 @@ julia = "1.6" [extras] Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 7edbd704..76e45927 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -1,7 +1,9 @@ module DynamicExpressionsLoopVectorizationExt -using LoopVectorization: @turbo -using DynamicExpressions: AbstractExpressionNode +using DynamicExpressions + +using LoopVectorization: @turbo, vmapnt +using DynamicExpressions: AbstractExpressionNode, GraphNode, OperatorEnum using DynamicExpressions.UtilsModule: ResultOk, fill_similar using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions import DynamicExpressions.EvaluateModule: @@ -14,6 +16,7 @@ import DynamicExpressions.EvaluateModule: deg2_r0_eval import DynamicExpressions.ExtensionInterfaceModule: _is_loopvectorization_loaded, bumper_kern1!, bumper_kern2! +import DynamicExpressions.ValueInterfaceModule: is_valid, is_valid_array _is_loopvectorization_loaded(::Int) = true @@ -230,4 +233,63 @@ function bumper_kern2!( return cumulator1 end + + +# graph eval + +function DynamicExpressions.EvaluateModule._eval_graph_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, + loopVectorization::Val{true} +) where {T} + + # vmap is faster with small cX sizes + # vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?) 
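+    #
+    # Evaluation below walks a topological ordering of the (possibly shared)
+    # graph: children are always evaluated before their parents, per-node
+    # results are cached in `node.cache`, and constant-only subexpressions
+    # are folded into `node.val`, so each shared subexpression is computed
+    # only once per call.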
+ + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + node.cache = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = vmapnt(operators.unaops[node.op], node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + elseif node.degree == 2 + if node.l.constant + if node.r.constant + node.constant = true + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = vmapnt(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + else + if node.r.constant + node.constant = false + node.cache = vmapnt(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + else + node.constant = false + node.cache = vmapnt(operators.binops[node.op], node.l.cache, node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + end + end + end + if root.constant + return ResultOk(fill(root.val, size(cX, 2)), true) + else + return ResultOk(root.cache, true) + end +end + end diff --git a/ext/DynamicExpressionsVisualizeExt.jl b/ext/DynamicExpressionsVisualizeExt.jl new file mode 100644 index 00000000..e422b618 --- /dev/null +++ b/ext/DynamicExpressionsVisualizeExt.jl @@ -0,0 +1,72 @@ +module DynamicExpressionsVisualizeExt + +using Plots, GraphRecipes, DynamicExpressions +using DynamicExpressions: GraphNode, Node, topological_sort, AbstractOperatorEnum, get_op_name + +function DynamicExpressions.visualize( + graph::Union{GraphNode,Node}, # types accepted by topological_sort + operators::AbstractOperatorEnum, + show = true +) + @info "Generating graph visualization" + + order = reverse(topological_sort(graph)) + + # multigraph adjacency list + g = map( + node -> convert(Vector{Int64}, map( + cindex -> findfirst(x -> x === node.children[cindex][], order), + 1:node.degree + )), + order + ) + + # node labels + n = map(x -> + if x.degree == 0 + x.constant ? x.val : 'x' * string(x.feature) + elseif x.degree == 1 + join(get_op_name(operators.unaops[x.op])) + elseif x.degree == 2 + join(get_op_name(operators.binops[x.op])) + else + @warn "Can't label operator node with degree > 2" + end, + order + ) + + # edge labels (specifies parameter no.) + e = Dict{Tuple{Int64, Int64, Int64}, String}() + for (index, node) in enumerate(order) + edge_count = Dict{Int64, Int64}() # count number of edges to each child node + for cindex in 1:node.degree + order_cindex = findfirst(x -> x === node.children[cindex][], order) + get!( + e, + ( + index, # source + order_cindex, # dest + get!(edge_count, order_cindex, pop!(edge_count, order_cindex, 0)+1) # edge no. + ), + string(cindex) + ) + end + end + + # node colours + c = map(x -> x == 1 ? 
2 : 1, eachindex(order)) + + return graphplot( + g, + names = n, + edgelabel = e, + nodecolor = c, + show = show, + nodeshape=:circle, + edge_label_box = false, + edgelabel_offset = 0.015, + nodesize=0.15 + ) +end + +end \ No newline at end of file diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index e8699cb1..8a955628 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -21,6 +21,7 @@ using DispatchDoctor: @stable, @unstable include("Random.jl") include("Parse.jl") include("ParametricExpression.jl") + include("Visualize.jl") include("StructuredExpression.jl") end @@ -44,11 +45,14 @@ import .ValueInterfaceModule: set_node!, tree_mapreduce, filter_map, - filter_map! + filter_map!, + topological_sort, + randomised_topological_sort import .NodeModule: constructorof, with_type_parameters, preserve_sharing, + max_degree, leaf_copy, branch_copy, leaf_hash, @@ -66,8 +70,7 @@ import .NodeModule: count_scalar_constants, get_scalar_constants, set_scalar_constants! -@reexport import .StringsModule: string_tree, print_tree -import .StringsModule: get_op_name +@reexport import .StringsModule: string_tree, print_tree, get_op_name @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! @@ -93,6 +96,7 @@ import .ExpressionModule: @reexport import .ParseModule: @parse_expression, parse_expression import .ParseModule: parse_leaf @reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode +@reexport import .VisualizeModule: visualize @reexport import .StructuredExpressionModule: StructuredExpression @stable default_mode = "disable" begin @@ -104,6 +108,7 @@ end import .InterfacesModule: ExpressionInterface, NodeInterface, all_ei_methods_except, all_ni_methods_except + function __init__() @require_extensions end diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 2a00f047..2ff7f420 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -2,7 +2,7 @@ module EvaluateModule using DispatchDoctor: @stable, @unstable -import ..NodeModule: AbstractExpressionNode, constructorof +import ..NodeModule: AbstractExpressionNode, constructorof, GraphNode, topological_sort import ..StringsModule: string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk @@ -854,4 +854,149 @@ end end end +# Parametric arguments don't use dynamic dispatch, calls with turbo/bumper won't resolve properly + +# overwritten in ext/DynamicExpressionsLoopVectorizationExt.jl +function _eval_graph_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, + loopVectorization::Val{true} +) where {T} + error("_is_loopvectorization_loaded(0) is true but _eval_graph_array has not been overwritten") +end + +function _eval_graph_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, + loopVectorization::Val{false} +) where {T} +order = topological_sort(root) +for node in order + if node.degree == 0 && !node.constant + node.cache = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = map(operators.unaops[node.op], node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + elseif node.degree 
== 2 + if node.l.constant + if node.r.constant + node.constant = true + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = map(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + else + if node.r.constant + node.constant = false + node.cache = map(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + else + node.constant = false + node.cache = map(operators.binops[node.op], node.l.cache, node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + end + end +end +if root.constant + return ResultOk(fill(root.val, size(cX, 2)), true) +else + return ResultOk(root.cache, true) +end +end + +function eval_tree_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, +) where {T} + return _eval_graph_array(root, cX, operators, Val(_is_loopvectorization_loaded(0))) end + +function eval_graph_array_diff( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, +) where {T} + + # vmap is faster with small cX sizes + # vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?) + dp = Dict{GraphNode, AbstractArray{T}}() + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + dp[node] = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return false end + else + node.constant = false + dp[node] = map(operators.unaops[node.op], dp[node.l]) + if !is_valid_array(dp[node]) return false end + end + elseif node.degree == 2 + if node.l.constant + if node.r.constant + node.constant = true + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return false end + else + node.constant = false + dp[node] = map(Base.Fix1(operators.binops[node.op], node.l.val), dp[node.r]) + if !is_valid_array(dp[node]) return false end + end + else + if node.r.constant + node.constant = false + dp[node] = map(Base.Fix2(operators.binops[node.op], node.r.val), dp[node.l]) + if !is_valid_array(dp[node]) return false end + else + node.constant = false + dp[node] = map(operators.binops[node.op], dp[node.l], dp[node.r]) + if !is_valid_array(dp[node]) return false end + end + end + end + end + if root.constant + return fill(root.val, size(cX, 2)) + else + return dp[root] + end +end + +function eval_graph_single( + root::GraphNode{T}, + cX::AbstractArray{T}, + operators::OperatorEnum +) where {T} + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + node.val = cX[node.feature] + elseif node.degree == 1 + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return false end + elseif node.degree == 2 + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return false end + end + end + return root.val +end + +end \ No newline at end of file diff --git a/src/Node.jl b/src/Node.jl index 4355aea5..6f319d0b 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -3,30 +3,35 @@ module NodeModule using DispatchDoctor: @unstable import ..OperatorEnumModule: AbstractOperatorEnum -import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined +import 
..UtilsModule: deprecate_varmap, Undefined +using Random: default_rng, AbstractRNG const DEFAULT_NODE_TYPE = Float32 """ - AbstractNode + AbstractNode{D} -Abstract type for binary trees. Must have the following fields: +Abstract type for D-arity trees. Must have the following fields: - `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1, then `l` needs to be defined as the left child. If 2, then `r` also needs to be defined as the right child. -- `l::AbstractNode`: Left child of the current node. Should only be +- `children`: A collection of D references to children nodes. + +# Deprecated fields + +- `l::AbstractNode{D}`: Left child of the current node. Should only be defined if `degree >= 1`; otherwise, leave it undefined (see the the constructors of [`Node{T}`](@ref) for an example). Don't use `nothing` to represent an undefined value as it will incur a large performance penalty. -- `r::AbstractNode`: Right child of the current node. Should only +- `r::AbstractNode{D}`: Right child of the current node. Should only be defined if `degree == 2`. """ -abstract type AbstractNode end +abstract type AbstractNode{D} end """ - AbstractExpressionNode{T} <: AbstractNode + AbstractExpressionNode{T,D} <: AbstractNode{D} Abstract type for nodes that represent an expression. Along with the fields required for `AbstractNode`, @@ -63,11 +68,42 @@ to your type. - `leaf_hash` and `branch_hash` - `preserve_sharing` """ -abstract type AbstractExpressionNode{T} <: AbstractNode end +abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end + +mutable struct Node{T,D} <: AbstractExpressionNode{T,D} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum + children::NTuple{D,Base.RefValue{Node{T,D}}} # Children nodes + + ################# + ## Constructors: + ################# + #Node{_T,_D}() where {_T,_D} = new{_T,_D::Int}() + Node{_T,_D}() where {_T,_D} = (x = new{_T,_D::Int}(); x.children = ntuple(i -> Ref{Node{_T,_D}}(), Val(max_degree(Node))); x) +end + +mutable struct GraphNode{T,D} <: AbstractExpressionNode{T,D} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum + children::NTuple{D,Base.RefValue{GraphNode{T,D}}} # Children nodes + visited::Bool # search accounting, initialised to false + cache::AbstractArray{T} + + ################# + ## Constructors: + ################# + GraphNode{_T,_D}() where {_T,_D} = (x = new{_T,_D::Int}(); x.visited = false; x.children = ntuple(i -> Ref{GraphNode{_T,_D}}(), Val(max_degree(GraphNode))); x) +end #! format: off """ - Node{T} <: AbstractExpressionNode{T} + Node{T,D} <: AbstractExpressionNode{T,D} Node defines a symbolic expression stored in a binary tree. A single `Node` instance is one "node" of this tree, and @@ -77,63 +113,42 @@ nodes, you can evaluate or print a given expression. # Fields - `degree::UInt8`: Degree of the node. 
0 for constants, 1 for - unary operators, 2 for binary operators. + unary operators, 2 for binary operators, etc. Maximum of `D`. - `constant::Bool`: Whether the node is a constant. - `val::T`: Value of the node. If `degree==0`, and `constant==true`, this is the value of the constant. It has a type specified by the overall type of the `Node` (e.g., `Float64`). - `feature::UInt16`: Index of the feature to use in the - case of a feature node. Only used if `degree==0` and `constant==false`. - Only defined if `degree == 0 && constant == false`. + case of a feature node. Only defined if `degree == 0 && constant == false`. - `op::UInt8`: If `degree==1`, this is the index of the operator in `operators.unaops`. If `degree==2`, this is the index of the operator in `operators.binops`. In other words, this is an enum of the operators, and is dependent on the specific `OperatorEnum` object. Only defined if `degree >= 1` -- `l::Node{T}`: Left child of the node. Only defined if `degree >= 1`. - Same type as the parent node. -- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`. - Same type as the parent node. This is to be passed as the right - argument to the binary operator. +- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree` # Constructors - Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) - Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) + Node([T]; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator) + Node{T}(; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator) Create a new node in an expression tree. If `T` is not specified in either the type or the -first argument, it will be inferred from the value of `val` passed or `l` and/or `r`. -If it cannot be inferred from these, it will default to `Float32`. - -The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This -is to permit the use of splatting in constructors. +first argument, it will be inferred from the value of `val` passed or the children. +The `children` keyword is used to pass in a collection of children nodes. You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`. You may also choose to specify a default memory allocator for the node other than simply `Node{T}()` in the `allocator` keyword argument. """ -mutable struct Node{T} <: AbstractExpressionNode{T} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - # ------------------- (possibly undefined below) - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops - l::Node{T} # Left child node. Only defined for degree=1 or degree=2. - r::Node{T} # Right child node. Only defined for degree=2. +Node - ################# - ## Constructors: - ################# - Node{_T}() where {_T} = new{_T}() -end """ - GraphNode{T} <: AbstractExpressionNode{T} + GraphNode{T,D} <: AbstractExpressionNode{T,D} -Exactly the same as [`Node{T}`](@ref), but with the assumption that some +Exactly the same as [`Node{T,D}`](@ref), but with the assumption that some nodes will be shared. 
All copies of this graph-like structure will be performed with this assumption, to preserve structure of the graph. @@ -161,17 +176,42 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes are created simply by using the same node in multiple places when constructing or setting properties. """ -mutable struct GraphNode{T} <: AbstractExpressionNode{T} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - # ------------------- (possibly undefined below) - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops - l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2. - r::GraphNode{T} # Right child node. Only defined for degree=2. - - GraphNode{_T}() where {_T} = new{_T}() +GraphNode + +@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol) + if k == :l + # TODO: Should a depwarn be raised here? Or too slow? + return getfield(n, :children)[1][] + elseif k == :r + return getfield(n, :children)[2][] + else + return getfield(n, k) + end +end +@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v) + if k == :l + getfield(n, :children)[1][] = v + elseif k == :r + getfield(n, :children)[2][] = v + elseif k == :degree + setfield!(n, :degree, convert(UInt8, v)) + elseif k == :constant + setfield!(n, :constant, convert(Bool, v)) + elseif k == :feature + setfield!(n, :feature, convert(UInt16, v)) + elseif k == :op + setfield!(n, :op, convert(UInt8, v)) + elseif k == :val + setfield!(n, :val, convert(eltype(n), v)) + elseif k == :children + setfield!(n, :children, v) + elseif k == :visited && typeof(n) <: GraphNode + setfield!(n, :visited, v) + elseif k == :cache && typeof(n) <: GraphNode + setfield!(n, :cache, v) + else + error("Invalid property: $k") + end end ################################################################################ @@ -180,25 +220,27 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -@unstable constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper -@unstable constructorof(::Type{<:Node}) = Node -@unstable constructorof(::Type{<:GraphNode}) = GraphNode +max_degree(::Type{<:AbstractNode}) = 2 # Default +max_degree(::Type{<:AbstractNode{D}}) where {D} = D + +@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T} +@unstable constructorof(::Type{N}) where {N<:GraphNode} = + GraphNode{T,max_degree(N)} where {T} -function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} - return constructorof(N){T} +with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)} +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T} + return GraphNode{T,max_degree(N)} end -with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} -with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} + +# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} +# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} return with_type_parameters(N, T)() end -default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}() 
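+# NOTE: the generic method above covers both `Node` and `GraphNode`,
+# since `with_type_parameters` is specialized for each node type.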
-default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}() """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false -preserve_sharing(::Union{Type{<:Node},Node}) = false preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true include("base.jl") @@ -206,33 +248,34 @@ include("base.jl") #! format: off @inline function (::Type{N})( ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, -) where {T1,N<:AbstractExpressionNode,F} - validate_not_all_defaults(N, val, feature, op, l, r, children) - if children !== nothing - @assert l === nothing && r === nothing - if length(children) == 1 - return node_factory(N, T1, val, feature, op, only(children), nothing, allocator) - else - return node_factory(N, T1, val, feature, op, children..., allocator) - end +) where {T1,N<:AbstractExpressionNode{T} where T,F} + _children = if l !== nothing && r === nothing + @assert children === nothing + (l,) + elseif l !== nothing && r !== nothing + @assert children === nothing + (l, r) + else + children end - return node_factory(N, T1, val, feature, op, l, r, allocator) + validate_not_all_defaults(N, val, feature, op, _children) + return node_factory(N, T1, val, feature, op, _children, allocator) end -function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode} +function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode} return nothing end -function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}} - if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing +function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {T,N<:AbstractExpressionNode{T}} + if val === nothing && feature === nothing && op === nothing && children === nothing error( "Encountered the call for $N() inside the generic constructor. " - * "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?" + * "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?" 
) end return nothing end """Create a constant leaf.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator::F, + ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) n = allocator(N, T) @@ -243,7 +286,7 @@ end end """Create a variable leaf, to store data.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, ::Nothing, allocator::F, + ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) n = allocator(N, T) @@ -252,28 +295,18 @@ end n.feature = feature return n end -"""Create a unary operator node.""" +"""Create an operator node.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, ::Nothing, allocator::F, -) where {N,T1,T2,F} - @assert l isa N - T = T2 # Always prefer existing nodes, so we don't mess up references from conversion + ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F, +) where {N<:AbstractExpressionNode,F} + T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion + D2 = length(children) + @assert D2 <= max_degree(N) + NT = with_type_parameters(N, T) n = allocator(N, T) - n.degree = 1 + n.degree = D2 n.op = op - n.l = l - return n -end -"""Create a binary operator node.""" -@inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, r::AbstractExpressionNode{T3}, allocator::F, -) where {N,T1,T2,T3,F} - T = promote_type(T2, T3) - n = allocator(N, T) - n.degree = 2 - n.op = op - n.l = T2 === T ? l : convert(with_type_parameters(N, T), l) - n.r = T3 === T ? r : convert(with_type_parameters(N, T), r) + n.children = ntuple(i -> i <= D2 ? 
Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N))) return n end @@ -314,14 +347,14 @@ function (::Type{N})( return N(; feature=i) end -function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2} - return Node{promote_type(T1, T2)} +function Base.promote_rule(::Type{Node{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D} + return Node{promote_type(T1, T2),D} end -function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2} - return GraphNode{promote_type(T1, T2)} +function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D} + return GraphNode{promote_type(T1, T2),D} end -function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2} - return GraphNode{promote_type(T1, T2)} +function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{GraphNode{T2,D}}) where {T1,T2,D} + return GraphNode{promote_type(T1, T2),D} end # TODO: Verify using this helps with garbage collection @@ -359,4 +392,93 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod return nothing end +"""Topological sort of the graph following a depth-first search""" +function topological_sort(graph::GraphNode) + order = Vector{GraphNode}() + _rec_toposort(graph, order) + for node in order + node.visited = false + end + return order +end + +"""Topological sort of the graph following a randomised depth-first search""" +function randomised_topological_sort(graph::GraphNode, rng::AbstractRNG=default_rng()) + order = Vector{GraphNode}() + _rec_randomised_toposort(graph, order, rng) + for node in order + node.visited = false + end + return order +end + +function _rec_toposort(gnode::GraphNode, order::Vector{GraphNode}) + if gnode.visited return end + gnode.visited = true + if gnode.degree == 1 + _rec_toposort(gnode.l, order) + elseif gnode.degree == 2 + _rec_toposort(gnode.l, order) + _rec_toposort(gnode.r, order) + end + push!(order, gnode) +end + +function _rec_randomised_toposort(gnode::GraphNode, order::Vector{GraphNode}, rng::AbstractRNG) + if gnode.visited return end + gnode.visited = true + if gnode.degree == 1 + _rec_randomised_toposort(gnode.l, order, rng) + elseif gnode.degree == 2 + if rand(rng, Bool) + _rec_randomised_toposort(gnode.l, order, rng) + _rec_randomised_toposort(gnode.r, order, rng) + else + _rec_randomised_toposort(gnode.r, order, rng) + _rec_randomised_toposort(gnode.l, order, rng) + end + end + push!(order, gnode) +end + + +"""Topological sort of the tree following a depth-first search""" +function topological_sort(tree::Node) + order = Vector{Node}() + _rec_toposort(tree, order) + return order +end + +"""Topological sort of the tree following a randomised depth-first search""" +function randomised_topological_sort(tree::Node, rng::AbstractRNG=default_rng()) + order = Vector{Node}() + _rec_randomised_toposort(tree, order, rng) + return order end + +function _rec_toposort(tnode::Node, order::Vector{Node}) + if tnode.degree == 1 + _rec_toposort(tnode.l, order) + elseif tnode.degree == 2 + _rec_toposort(tnode.l, order) + _rec_toposort(tnode.r, order) + end + push!(order, tnode) +end + +function _rec_randomised_toposort(tnode::Node, order::Vector{Node}, rng::AbstractRNG) + if tnode.degree == 1 + _rec_randomised_toposort(tnode.l, order, rng) + elseif tnode.degree == 2 + if rand(rng, Bool) + _rec_randomised_toposort(tnode.l, order, rng) + _rec_randomised_toposort(tnode.r, order, rng) + else + _rec_randomised_toposort(tnode.r, order, rng) + _rec_randomised_toposort(tnode.l, order, rng) + end + end + 
    push!(order, tnode)
+end
+
+end
\ No newline at end of file
diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl
index cef19b2c..392df3b7 100644
--- a/src/NodeUtils.jl
+++ b/src/NodeUtils.jl
@@ -143,38 +143,59 @@ end
 ## Assign index to nodes of a tree
 # This will mirror a Node struct, rather
 # than adding a new attribute to Node.
-struct NodeIndex{T} <: AbstractNode
+mutable struct NodeIndex{T,D} <: AbstractNode{D}
     degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
     val::T # If is a constant, this stores the actual value
     # ------------------- (possibly undefined below)
-    l::NodeIndex{T} # Left child node. Only defined for degree=1 or degree=2.
-    r::NodeIndex{T} # Right child node. Only defined for degree=2.
-
-    NodeIndex(::Type{_T}) where {_T} = new{_T}(0, zero(_T))
-    NodeIndex(::Type{_T}, val) where {_T} = new{_T}(0, convert(_T, val))
-    NodeIndex(::Type{_T}, l::NodeIndex) where {_T} = new{_T}(1, zero(_T), l)
-    function NodeIndex(::Type{_T}, l::NodeIndex, r::NodeIndex) where {_T}
-        return new{_T}(2, zero(_T), l, r)
+    children::NTuple{D,Base.RefValue{NodeIndex{T,D}}}
+
+    function NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D}
+        return new{_T,_D}(
+            0, convert(_T, val), ntuple(_ -> Ref{NodeIndex{_T,_D}}(), Val(_D))
+        )
+    end
+    function NodeIndex(
+        ::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2}
+    ) where {_T,_D,_D2}
+        _children = ntuple(
+            i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D)
+        )
+        return new{_T,_D}(convert(UInt8, _D2), zero(_T), _children)
+    end
+end
+NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T))
+
+@inline function Base.getproperty(n::NodeIndex, k::Symbol)
+    if k == :l
+        # TODO: Should a depwarn be raised here? Or too slow?
+        return getfield(n, :children)[1][]
+    elseif k == :r
+        return getfield(n, :children)[2][]
+    else
+        return getfield(n, k)
     end
 end
+
 # Sharing is never needed for NodeIndex,
 # as we trace over the node we are indexing on.
 preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false
 
-function index_constant_nodes(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T}
+function index_constant_nodes(
+    tree::AbstractExpressionNode{Ti,D} where {Ti}, ::Type{T}=UInt16
+) where {D,T}
     # Essentially we copy the tree, replacing the values
     # with indices
     constant_index = Ref(T(0))
     return tree_mapreduce(
         t -> if t.constant
-            NodeIndex(T, (constant_index[] += T(1)))
+            NodeIndex(T, Val(D), (constant_index[] += T(1)))
         else
-            NodeIndex(T)
+            NodeIndex(T, Val(D))
         end,
         t -> nothing,
-        (_, c...)
-> NodeIndex(T, Val(D), c...), tree, - NodeIndex{T}; + NodeIndex{T,D}; ) end diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 96b84d00..c5c557af 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -167,7 +167,7 @@ function _extend_unary_operator( $_constructorof(N)(T; val=$($f_inside)(l.val)) else latest_op_idx = $($lookup_op)($($f_inside), Val(1)) - $_constructorof(N)(; op=latest_op_idx, l) + $_constructorof(N)(; op=latest_op_idx, children=(l,)) end end end @@ -196,7 +196,7 @@ function _extend_binary_operator( $_constructorof(N)(T; val=$($f_inside)(l.val, r.val)) else latest_op_idx = $($lookup_op)($($f_inside), Val(2)) - $_constructorof(N)(; op=latest_op_idx, l, r) + $_constructorof(N)(; op=latest_op_idx, children=(l, r)) end end function $($f_outside)( @@ -207,7 +207,7 @@ function _extend_binary_operator( else latest_op_idx = $($lookup_op)($($f_inside), Val(2)) $_constructorof(N)(; - op=latest_op_idx, l, r=$_constructorof(N)(T; val=r) + op=latest_op_idx, children=(l, $_constructorof(N)(T; val=r)) ) end end @@ -219,7 +219,7 @@ function _extend_binary_operator( else latest_op_idx = $($lookup_op)($($f_inside), Val(2)) $_constructorof(N)(; - op=latest_op_idx, l=$_constructorof(N)(T; val=l), r + op=latest_op_idx, children=($_constructorof(N)(T; val=l), r) ) end end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 0eb5db04..8c75aa7b 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -41,7 +41,7 @@ import ..ValueInterfaceModule: count_scalar_constants, pack_scalar_constants!, unpack_scalar_constants """A type of expression node that also stores a parameter index""" -mutable struct ParametricNode{T} <: AbstractExpressionNode{T} +mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D} degree::UInt8 constant::Bool # if true => constant; if false, then check `is_parameter` val::T @@ -51,11 +51,10 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T} parameter::UInt16 # Stores index of per-class parameter op::UInt8 - l::ParametricNode{T} - r::ParametricNode{T} + children::NTuple{D,Base.RefValue{ParametricNode{T,D}}} # Children nodes - function ParametricNode{_T}() where {_T} - n = new{_T}() + function ParametricNode{_T,_D}() where {_T,_D} + n = new{_T,_D}() n.is_parameter = false n.parameter = UInt16(0) return n diff --git a/src/Utils.jl b/src/Utils.jl index 3211908c..6db149ff 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -13,103 +13,6 @@ macro return_on_false2(flag, retval, retval2) ) end -""" - @memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode) - ... - end - -This macro takes a function definition and creates a second version of the -function with an additional `id_map` argument. When passed this argument (an -IdDict()), it will use use the `id_map` to avoid recomputing the same value -for the same node in a tree. Use this to automatically create functions that -work with trees that have shared child nodes. - -Can optionally take a `postprocess` function, which will be applied to the -result of the function before returning it, taking the result as the -first argument and a boolean for whether the result was memoized as the -second argument. This is useful for functions that need to count the number -of unique nodes in a tree, for example. -""" -macro memoize_on(tree, args...) - if length(args) ∉ (1, 2) - error("Expected 2 or 3 arguments to @memoize_on") - end - postprocess = length(args) == 1 ? 
:((r, _) -> r) : args[1] - def = length(args) == 1 ? args[1] : args[2] - idmap_def = _memoize_on(tree, postprocess, def) - - return quote - $(esc(def)) # The normal function - $(esc(idmap_def)) # The function with an id_map argument - end -end -function _memoize_on(tree::Symbol, postprocess, def) - sdef = splitdef(def) - - # Add an id_map argument - push!(sdef[:args], :(id_map::AbstractDict)) - - f_name = sdef[:name] - - # Forward id_map argument to all calls of the same function - # within the function body: - sdef[:body] = postwalk(sdef[:body]) do ex - if @capture(ex, f_(args__)) - if f == f_name - return Expr(:call, f, args..., :id_map) - end - end - return ex - end - - # Wrap the function body in a get!(id_map, tree) do ... end block: - @gensym key is_memoized result body - sdef[:body] = quote - $key = objectid($tree) - $is_memoized = haskey(id_map, $key) - function $body() - return $(sdef[:body]) - end - $result = if $is_memoized - @inbounds(id_map[$key]) - else - id_map[$key] = $body() - end - return $postprocess($result, $is_memoized) - end - - return combinedef(sdef) -end - -""" - @with_memoize(call, id_map) - -This simple macro simply puts the `id_map` -into the call, to be consistent with the `@memoize_on` macro. - -``` -@with_memoize(_copy_node(tree), IdDict{Any,Any}()) -```` - -is converted to - -``` -_copy_node(tree, IdDict{Any,Any}()) -``` - -""" -macro with_memoize(def, id_map) - idmap_def = _add_idmap_to_call(def, id_map) - return quote - $(esc(idmap_def)) - end -end - -function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr}) - @assert def.head == :call - return Expr(:call, def.args[1], def.args[2:end]..., id_map) -end - @inline function fill_similar(value::T, array, args...) where {T} out_array = similar(array, args...) fill!(out_array, value) diff --git a/src/Visualize.jl b/src/Visualize.jl new file mode 100644 index 00000000..1b3f65ce --- /dev/null +++ b/src/Visualize.jl @@ -0,0 +1,14 @@ +module VisualizeModule + +using ..NodeModule: GraphNode, Node +using ..OperatorEnumModule: AbstractOperatorEnum + +function visualize( + graph::Union{GraphNode,Node}, # types accepted by topological_sort + operators::AbstractOperatorEnum, + show = true +) + error("Please load the Plots.jl and GraphRecipes.jl packages to use this feature.") +end + +end \ No newline at end of file diff --git a/src/base.jl b/src/base.jl index b245a4ba..6519c57e 100644 --- a/src/base.jl +++ b/src/base.jl @@ -25,7 +25,7 @@ import Base: using DispatchDoctor: @unstable using Compat: @inline, Returns -using ..UtilsModule: @memoize_on, @with_memoize, Undefined +using ..UtilsModule: Undefined """ tree_mapreduce( @@ -89,38 +89,76 @@ function tree_mapreduce( f_leaf::F1, f_branch::F2, op::G, - tree::AbstractNode, + tree::AbstractNode{D}, result_type::Type{RT}=Undefined; f_on_shared::H=(result, is_shared) -> result, - break_sharing::Val=Val(false), -) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT} - - # Trick taken from here: - # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5 - # to speed up recursive closure - @memoize_on t f_on_shared function inner(inner, t) - if t.degree == 0 - return @inline(f_leaf(t)) - elseif t.degree == 1 - return @inline(op(@inline(f_branch(t)), inner(inner, t.l))) - else - return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r))) - end - end - - sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false) + break_sharing::Val{BS}=Val(false), +) where 
{F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS} + sharing = preserve_sharing(typeof(tree)) && !BS RT == Undefined && sharing && throw(ArgumentError("Need to specify `result_type` if nodes are shared..")) if sharing && RT != Undefined - d = allocate_id_map(tree, RT) - return @with_memoize inner(inner, tree) d + id_map = allocate_id_map(tree, RT) + reducer = TreeMapreducer(Val(D), id_map, f_leaf, f_branch, op, f_on_shared) + return reducer(tree) else - return inner(inner, tree) + reducer = TreeMapreducer(Val(D), nothing, f_leaf, f_branch, op, f_on_shared) + return reducer(tree) end end + +struct TreeMapreducer{D,ID,F1<:Function,F2<:Function,G<:Function,H<:Function} + max_degree::Val{D} + id_map::ID + f_leaf::F1 + f_branch::F2 + op::G + f_on_shared::H +end + +@generated function (mapreducer::TreeMapreducer{MAX_DEGREE,ID})( + tree::AbstractNode +) where {MAX_DEGREE,ID} + base_expr = quote + d = tree.degree + Base.Cartesian.@nif( + $(MAX_DEGREE + 1), + d_p_one -> (d_p_one - 1) == d, + d_p_one -> if d_p_one == 1 + mapreducer.f_leaf(tree) + else + mapreducer.op( + mapreducer.f_branch(tree), + Base.Cartesian.@ntuple( + d_p_one - 1, i -> mapreducer(tree.children[i][]) + )..., + ) + end + ) + end + if ID <: Nothing + # No sharing of nodes (is a tree, not a graph) + return base_expr + else + # Otherwise, we need to cache results in `id_map` + # according to `objectid` of the node + return quote + key = objectid(tree) + is_cached = haskey(mapreducer.id_map, key) + if is_cached + return mapreducer.f_on_shared(@inbounds(mapreducer.id_map[key]), true) + else + res = $base_expr + mapreducer.id_map[key] = res + return mapreducer.f_on_shared(res, false) + end + end + end +end + function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} d = Dict{UInt,RT}() # Preallocate maximum storage (counting with duplicates is fast) @@ -128,7 +166,6 @@ function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} sizehint!(d, N) return d end - # TODO: Raise Julia issue for this. # Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here! # I think it's because `setindex!` is declared with `@nospecialize` in IdDict. diff --git a/test/test_base.jl b/test/test_base.jl index b14894b1..f7e7a483 100644 --- a/test/test_base.jl +++ b/test/test_base.jl @@ -32,11 +32,11 @@ end @testset "collect" begin ctree = copy(tree) - @test typeof(first(collect(ctree))) == Node{Float64} + @test typeof(first(collect(ctree))) <: Node{Float64} @test objectid(first(collect(ctree))) == objectid(ctree) @test objectid(first(collect(ctree))) == objectid(ctree) @test objectid(first(collect(ctree))) == objectid(ctree) - @test typeof(collect(ctree)) == Vector{Node{Float64}} + @test typeof(collect(ctree)) <: Vector{<:Node{Float64}} @test length(collect(ctree)) == 24 @test sum((t -> (t.degree == 0 && t.constant) ? 
t.val : 0.0).(collect(ctree))) ≈ 11.6 end diff --git a/test/test_custom_node_type.jl b/test/test_custom_node_type.jl index 3fc333bc..57a3706c 100644 --- a/test/test_custom_node_type.jl +++ b/test/test_custom_node_type.jl @@ -1,16 +1,21 @@ using DynamicExpressions using Test -mutable struct MyCustomNode{A,B} <: AbstractNode +mutable struct MyCustomNode{A,B} <: AbstractNode{2} degree::Int val1::A val2::B - l::MyCustomNode{A,B} - r::MyCustomNode{A,B} + children::NTuple{2,Base.RefValue{MyCustomNode{A,B}}} MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2) - MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l) - MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r) + function MyCustomNode(val1, val2, l) + return new{typeof(val1),typeof(val2)}( + 1, val1, val2, (Ref(l), Ref{MyCustomNode{typeof(val1),typeof(val2)}}()) + ) + end + function MyCustomNode(val1, val2, l, r) + return new{typeof(val1),typeof(val2)}(2, val1, val2, (Ref(l), Ref(r))) + end end node1 = MyCustomNode(1.0, 2) @@ -24,7 +29,7 @@ node2 = MyCustomNode(1.5, 3, node1) @test typeof(node2) == MyCustomNode{Float64,Int} @test node2.degree == 1 -@test node2.l.degree == 0 +@test node2.children[1][].degree == 0 @test count_depth(node2) == 2 @test count_nodes(node2) == 2 @@ -37,14 +42,13 @@ node2 = MyCustomNode(1.5, 3, node1, node1) @test count(t -> t.degree == 0, node2) == 2 # If we have a bad definition, it should get caught with a helpful message -mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T} +mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T,2} degree::UInt8 constant::Bool val::T feature::UInt16 op::UInt8 - l::MyCustomNode2{T} - r::MyCustomNode2{T} + children::NTuple{2,Base.RefValue{MyCustomNode2{T}}} end @test_throws ErrorException MyCustomNode2() diff --git a/test/test_equality.jl b/test/test_equality.jl index 220e63c3..7e9b845b 100644 --- a/test/test_equality.jl +++ b/test/test_equality.jl @@ -45,8 +45,8 @@ modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2) f64_tree = GraphNode{Float64}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) f32_tree = GraphNode{Float32}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) -@test typeof(f64_tree) == GraphNode{Float64} -@test typeof(f32_tree) == GraphNode{Float32} +@test typeof(f64_tree) <: GraphNode{Float64} +@test typeof(f32_tree) <: GraphNode{Float32} @test convert(GraphNode{Float64}, f32_tree) == f64_tree diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 467c6226..60b35595 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -2,24 +2,31 @@ using Test using DynamicExpressions -using DynamicExpressions: constructorof +using DynamicExpressions: constructorof, max_degree -mutable struct FrozenNode{T} <: AbstractExpressionNode{T} +mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D} degree::UInt8 constant::Bool val::T frozen::Bool # Extra field! 
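+    # (the fields below mirror the standard `AbstractExpressionNode` layout)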
feature::UInt16 op::UInt8 - l::FrozenNode{T} - r::FrozenNode{T} + children::NTuple{D,Base.RefValue{FrozenNode{T,D}}} - function FrozenNode{_T}() where {_T} - n = new{_T}() + function FrozenNode{_T,_D}() where {_T,_D} + n = new{_T,_D}() n.frozen = false return n end end +function DynamicExpressions.constructorof(::Type{N}) where {N<:FrozenNode} + return FrozenNode{T,max_degree(N)} where {T} +end +function DynamicExpressions.with_type_parameters( + ::Type{N}, ::Type{T} +) where {T,N<:FrozenNode} + return FrozenNode{T,max_degree(N)} +end function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T} out = if t.constant constructorof(typeof(t))(; val=t.val) @@ -56,7 +63,7 @@ function DynamicExpressions.leaf_equal(a::FrozenNode, b::FrozenNode) end end -n = let n = FrozenNode{Float64}() +n = let n = FrozenNode{Float64,2}() n.degree = 0 n.constant = true n.val = 0.0 @@ -92,5 +99,5 @@ ex = parse_expression( @test string_tree(ex) == "x + sin(y + 2.1)" @test ex.tree.frozen == false -@test ex.tree.r.frozen == true -@test ex.tree.r.l.frozen == false +@test ex.tree.children[2][].frozen == true +@test ex.tree.children[2][].children[1][].frozen == false diff --git a/test/test_graphs.jl b/test/test_graphs.jl index 2f31c4ed..55ab4d79 100644 --- a/test/test_graphs.jl +++ b/test/test_graphs.jl @@ -109,87 +109,6 @@ end :(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())), ) end - - @testset "@with_memoize" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize( - _convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}() - ) - true_ex = quote - _convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}()) - end - - @test expr_eql(ex, true_ex) - end - - @testset "@memoize_on" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree ((x, _) -> x) function _copy_node( - tree::Node{T} - )::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - true_ex = quote - function _copy_node(tree::Node{T})::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - function _copy_node(tree::Node{T}, id_map::AbstractDict;)::Node{T} where {T} - key = objectid(tree) - is_memoized = haskey(id_map, key) - function body() - return begin - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l, id_map)) - else - Node( - copy(tree.op), - _copy_node(tree.l, id_map), - _copy_node(tree.r, id_map), - ) - end - end - end - result = if is_memoized - begin - $(Expr(:inbounds, true)) - local val = id_map[key] - $(Expr(:inbounds, :pop)) - val - end - else - id_map[key] = body() - end - return (((x, _) -> begin - x - end)(result, is_memoized)) - end - end - @test expr_eql(ex, true_ex) - end end @testset "Operations on graphs" begin @@ -353,7 +272,7 @@ end x = GraphNode(Float32; feature=1) tree = x + 1.0 @test tree.l === x - @test typeof(tree) === GraphNode{Float32} + @test typeof(tree) <: GraphNode{Float32} # Detect error from Float32(1im) @test_throws InexactError x + 1im diff --git a/test/test_parse.jl b/test/test_parse.jl index 
c9b40d0c..8d9c351d 100644 --- a/test/test_parse.jl +++ b/test/test_parse.jl @@ -108,7 +108,7 @@ end variable_names = ["x"], ) - @test typeof(ex.tree) === Node{Any} + @test typeof(ex.tree) <: Node{Any} @test typeof(ex.metadata.operators) <: GenericOperatorEnum s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "[1, 2, 3] * tan(cos(5.0 + x))" @@ -184,7 +184,7 @@ end s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "(x * 2.5) - cos(y)" end - @test contains(logged_out, "Node{Float32}") + @test contains(logged_out, "Node{Float32") end @testitem "Helpful errors for missing operator" begin