diff --git a/Project.toml b/Project.toml index 57a11c54..06bd0133 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,13 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "0.15.0" +version = "0.16.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" @@ -34,8 +32,8 @@ Bumper = "0.6" Compat = "3.37, 4" Enzyme = "^0.11.12" LoopVectorization = "0.12" -Optim = "0.19, 1" MacroTools = "0.4, 0.5" +Optim = "0.19, 1" PackageExtensionCompat = "1" PrecompileTools = "1" Reexport = "1" @@ -48,6 +46,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -58,4 +57,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"] +test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"] diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index bc6a8df5..98934133 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -23,7 +23,7 @@ function bumper_eval_tree_array( leaf_node -> begin ar = @alloc(T, n) ok = if leaf_node.constant - v = leaf_node.val::T + v = leaf_node.val ar .= v isfinite(v) else diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 91c79dca..c9212e7c 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -41,8 +41,8 @@ function deg1_l2_ll0_lr0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true} ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant - val_ll = tree.l.l.val::T - val_lr = tree.l.r.val::T + val_ll = tree.l.l.val + val_lr = tree.l.r.val @return_on_check val_ll cX @return_on_check val_lr cX x_l = op_l(val_ll, val_lr)::T @@ -51,7 +51,7 @@ function deg1_l2_ll0_lr0_eval( @return_on_check x cX return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.l.constant - val_ll = tree.l.l.val::T + val_ll = tree.l.l.val @return_on_check val_ll cX feature_lr = tree.l.r.feature cumulator = similar(cX, axes(cX, 2)) @@ -63,7 +63,7 @@ function deg1_l2_ll0_lr0_eval( return ResultOk(cumulator, true) elseif tree.l.r.constant feature_ll = tree.l.l.feature - val_lr = tree.l.r.val::T + val_lr = tree.l.r.val @return_on_check val_lr cX cumulator = similar(cX, axes(cX, 2)) @turbo for j in axes(cX, 2) @@ -89,7 +89,7 @@ function deg1_l1_ll0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true} ) where {T<:Number,F,F2} if tree.l.l.constant - val_ll = tree.l.l.val::T + val_ll = tree.l.l.val @return_on_check val_ll cX x_l = op_l(val_ll)::T @return_on_check x_l cX @@ -112,16 +112,16 @@ function deg2_l0_r0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{true} ) where {T<:Number,F} if tree.l.constant && tree.r.constant - val_l = tree.l.val::T + val_l = tree.l.val @return_on_check val_l cX - val_r = tree.r.val::T + val_r = tree.r.val @return_on_check val_r cX x = op(val_l, val_r)::T @return_on_check x cX return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.constant cumulator = similar(cX, axes(cX, 2)) - val_l = tree.l.val::T + val_l = tree.l.val @return_on_check val_l cX feature_r = tree.r.feature @turbo for j in axes(cX, 2) @@ -132,7 +132,7 @@ function deg2_l0_r0_eval( elseif tree.r.constant cumulator = similar(cX, axes(cX, 2)) feature_l = tree.l.feature - val_r = tree.r.val::T + val_r = tree.r.val @return_on_check val_r cX @turbo for j in axes(cX, 2) x = op(cX[feature_l, j], val_r) @@ -160,7 +160,7 @@ function deg2_l0_eval( ::Val{true}, ) where {T<:Number,F} if tree.l.constant - val = tree.l.val::T + val = tree.l.val @return_on_check val cX @turbo for j in eachindex(cumulator) x = op(val, cumulator[j]) @@ -185,7 +185,7 @@ function deg2_r0_eval( ::Val{true}, ) where {T<:Number,F} if tree.r.constant - val = tree.r.val::T + val = tree.r.val @return_on_check val cX @turbo for j in eachindex(cumulator) x = op(cumulator[j], val) diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index 6773cb61..ff25aecf 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -34,7 +34,7 @@ function parse_tree_to_eqs( ) where {T} if tree.degree == 0 # Return constant if needed - tree.constant && return subs_bad(tree.val::T) + tree.constant && return subs_bad(tree.val) return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)")) end # Collect the next children @@ -93,10 +93,10 @@ function split_eq( else ind = findoperation(op, operators.binops) end - return constructorof(N)( - ind, - convert(N, args[1], operators; variable_names=variable_names), - convert(N, op(args[2:end]...), operators; variable_names=variable_names), + return constructorof(N)(; + op=ind, + l=convert(N, args[1], operators; variable_names=variable_names), + r=convert(N, op(args[2:end]...), operators; variable_names=variable_names), ) end @@ -157,8 +157,9 @@ function Base.convert( findoperation(op, operators.unaops) end - return constructorof(N)( - ind, map(x -> convert(N, x, operators; variable_names=variable_names), args)... + return constructorof(N)(; + op=ind, + children=map(x -> convert(N, x, operators; variable_names=variable_names), args), ) end diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index d46e8273..56a27bf0 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -5,6 +5,7 @@ include("ExtensionInterface.jl") include("OperatorEnum.jl") include("Equation.jl") include("EquationUtils.jl") +include("Strings.jl") include("EvaluateEquation.jl") include("EvaluateEquationDerivative.jl") include("EvaluationHelpers.jl") @@ -19,8 +20,6 @@ import Reexport: @reexport AbstractExpressionNode, GraphNode, Node, - string_tree, - print_tree, copy_node, set_node!, tree_mapreduce, @@ -39,6 +38,7 @@ import .EquationModule: constructorof, preserve_sharing set_constants!, get_constant_refs, set_constant_refs! +@reexport import .StringsModule: string_tree, print_tree @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! diff --git a/src/Equation.jl b/src/Equation.jl index d0bd4832..3c124cf5 100644 --- a/src/Equation.jl +++ b/src/Equation.jl @@ -79,44 +79,26 @@ nodes, you can evaluate or print a given expression. # Constructors -## Leafs - Node(; val=nothing, feature::Union{Integer,Nothing}=nothing) - Node{T}(; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T} + Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) + Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) -Create a leaf node: either a constant, or a variable. +Create a new node in an expression tree. If `T` is not specified in either the type or the +first argument, it will be inferred from the value of `val` passed or `l` and/or `r`. +If it cannot be inferred from these, it will default to `Float32`. -- `::Type{T}`, optionally specify the type of the - node, if not already given by the type of - `val`. -- `val`, if you are specifying a constant, pass - the value of the constant here. -- `feature::Integer`, if you are specifying a variable, - pass the index of the variable here. +The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This +is to permit the use of splatting in constructors. -You can also create a leaf node from variable names: +You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`. - Node(; var_string::String, variable_names::Array{String,1}) - Node{T}(; var_string::String, variable_names::Array{String,1}) where {T} - -## Unary operator - - Node(op::Integer, l::Node) - -Apply unary operator `op` (enumerating over the order given in `OperatorEnum`) -to `Node` `l`. - -## Binary operator - - Node(op::Integer, l::Node, r::Node) - -Apply binary operator `op` (enumerating over the order given in `OperatorEnum`) -to `Node`s `l` and `r`. +You may also choose to specify a default memory allocator for the node other than simply `Node{T}()` +in the `allocator` keyword argument. """ mutable struct Node{T} <: AbstractExpressionNode{T} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. constant::Bool # false if variable - val::Union{T,Nothing} # If is a constant, this stores the actual value + val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops @@ -126,12 +108,7 @@ mutable struct Node{T} <: AbstractExpressionNode{T} ################# ## Constructors: ################# - Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v) - Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v) - Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f)) - Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l) - Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r) - + Node{_T}() where {_T} = new{_T}() end """ @@ -168,26 +145,22 @@ when constructing or setting properties. mutable struct GraphNode{T} <: AbstractExpressionNode{T} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. constant::Bool # false if variable - val::Union{T,Nothing} # If is a constant, this stores the actual value + val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2. r::GraphNode{T} # Right child node. Only defined for degree=2. - ################# - ## Constructors: - ################# - GraphNode(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v) - GraphNode(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v) - GraphNode(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f)) - GraphNode(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::GraphNode{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l) - GraphNode(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::GraphNode{_T}, r::GraphNode{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r) + GraphNode{_T}() where {_T} = new{_T}() end ################################################################################ #! format: on +Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T +Base.eltype(::AbstractExpressionNode{T}) where {T} = T + constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper constructorof(::Type{<:Node}) = Node constructorof(::Type{<:GraphNode}) = GraphNode @@ -198,6 +171,12 @@ end with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} +function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} + return with_type_parameters(N, T)() +end +default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}() +default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}() + """Trait declaring whether nodes share children or not.""" preserve_sharing(::Type{<:AbstractNode}) = false preserve_sharing(::Type{<:Node}) = false @@ -205,41 +184,89 @@ preserve_sharing(::Type{<:GraphNode}) = true include("base.jl") -function (::Type{N})( - ::Type{T}=Undefined; val::T1=nothing, feature::T2=nothing -) where {T,T1,T2<:Union{Integer,Nothing},N<:AbstractExpressionNode} - ((T1 <: Nothing) ⊻ (T2 <: Nothing)) || error( - "You must specify exactly one of `val` or `feature` when creating a leaf node." - ) - Tout = compute_value_output_type(N, T, T1) - if T2 <: Nothing - if !(T1 <: T) - # Only convert if not already in the type union. - val = convert(Tout, val) +#! format: off +@inline function (::Type{N})( + ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, +) where {T1,N<:AbstractExpressionNode,F} + if children !== nothing + @assert l === nothing && r === nothing + if length(children) == 1 + return node_factory(N, T1, val, feature, op, only(children), nothing, allocator) + else + return node_factory(N, T1, val, feature, op, children..., allocator) end - return constructorof(N)(Tout, 0, true, val) + end + return node_factory(N, T1, val, feature, op, l, r, allocator) +end +"""Create a constant leaf.""" +@inline function node_factory( + ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator::F, +) where {N,T1,T2,F} + T = node_factory_type(N, T1, T2) + n = allocator(N, T) + n.degree = 0 + n.constant = true + n.val = convert(T, val) + return n +end +"""Create a variable leaf, to store data.""" +@inline function node_factory( + ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, ::Nothing, allocator::F, +) where {N,T1,F} + T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) + n = allocator(N, T) + n.degree = 0 + n.constant = false + n.feature = feature + return n +end +"""Create a unary operator node.""" +@inline function node_factory( + ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, ::Nothing, allocator::F, +) where {N,T1,T2,F} + @assert l isa N + T = T2 # Always prefer existing nodes, so we don't mess up references from conversion + n = allocator(N, T) + n.degree = 1 + n.op = op + n.l = l + return n +end +"""Create a binary operator node.""" +@inline function node_factory( + ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, r::AbstractExpressionNode{T3}, allocator::F, +) where {N,T1,T2,T3,F} + T = promote_type(T2, T3) + n = allocator(N, T) + n.degree = 2 + n.op = op + n.l = T2 === T ? l : convert(with_type_parameters(N, T), l) + n.r = T3 === T ? r : convert(with_type_parameters(N, T), r) + return n +end + +@inline function node_factory_type(::Type{N}, ::Type{T1}, ::Type{T2}) where {N,T1,T2} + if T1 === Undefined && N isa UnionAll + T2 + elseif T1 === Undefined + eltype(N) + elseif N isa UnionAll + T1 else - return constructorof(N)(Tout, 0, false, nothing, feature) + eltype(N) end end +#! format: on + function (::Type{N})( - op::Integer, l::AbstractExpressionNode{T} -) where {T,N<:AbstractExpressionNode} - @assert l isa N - return constructorof(N)(1, false, nothing, 0, op, l) + op::Integer, l::AbstractExpressionNode +) where {N<:AbstractExpressionNode} + return N(; op=op, l=l) end function (::Type{N})( - op::Integer, l::AbstractExpressionNode{T1}, r::AbstractExpressionNode{T2} -) where {T1,T2,N<:AbstractExpressionNode} - @assert l isa N && r isa N - # Get highest type: - if T1 != T2 - T = promote_type(T1, T2) - # TODO: This might slow things down - l = convert(with_type_parameters(N, T), l) - r = convert(with_type_parameters(N, T), r) - end - return constructorof(N)(2, false, nothing, 0, op, l, r) + op::Integer, l::AbstractExpressionNode, r::AbstractExpressionNode +) where {N<:AbstractExpressionNode} + return N(; op=op, l=l, r=r) end function (::Type{N})(var_string::String) where {N<:AbstractExpressionNode} Base.depwarn( @@ -255,28 +282,6 @@ function (::Type{N})( return N(; feature=i) end -@inline function compute_value_output_type( - ::Type{N}, ::Type{T}, ::Type{T1} -) where {N<:AbstractExpressionNode,T,T1} - !(N isa UnionAll) && - T !== Undefined && - error( - "Ambiguous type for node. Please either use `Node{T}(; val, feature)` or `Node(T; val, feature)`.", - ) - - if T === Undefined && N isa UnionAll - if T1 <: Nothing - return DEFAULT_NODE_TYPE - else - return T1 - end - elseif T === Undefined - return eltype(N) - else - return T - end -end - function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2} return Node{promote_type(T1, T2)} end @@ -286,11 +291,9 @@ end function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2} return GraphNode{promote_type(T1, T2)} end -Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T -Base.eltype(::AbstractExpressionNode{T}) where {T} = T # TODO: Verify using this helps with garbage collection -create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N(; feature=zero(UInt16)) +create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N() """ set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T} @@ -324,181 +327,4 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod return nothing end -const OP_NAMES = Base.ImmutableDict( - "safe_log" => "log", - "safe_log2" => "log2", - "safe_log10" => "log10", - "safe_log1p" => "log1p", - "safe_acosh" => "acosh", - "safe_sqrt" => "sqrt", - "safe_pow" => "^", -) - -function dispatch_op_name(::Val{2}, ::Nothing, idx)::Vector{Char} - return vcat(collect("binary_operator["), collect(string(idx)), [']']) -end -function dispatch_op_name(::Val{1}, ::Nothing, idx)::Vector{Char} - return vcat(collect("unary_operator["), collect(string(idx)), [']']) -end -function dispatch_op_name(::Val{2}, operators::AbstractOperatorEnum, idx)::Vector{Char} - return get_op_name(operators.binops[idx]) -end -function dispatch_op_name(::Val{1}, operators::AbstractOperatorEnum, idx)::Vector{Char} - return get_op_name(operators.unaops[idx]) -end - -@generated function get_op_name(op::F)::Vector{Char} where {F} - try - # Bit faster to just cache the name of the operator: - op_s = string(F.instance) - out = collect(get(OP_NAMES, op_s, op_s)) - return :($out) - catch - end - return quote - op_s = string(op) - out = collect(get(OP_NAMES, op_s, op_s)) - return out - end -end - -@inline function strip_brackets(s::Vector{Char})::Vector{Char} - if first(s) == '(' && last(s) == ')' - return s[(begin + 1):(end - 1)] - else - return s - end -end - -# Can overload these for custom behavior: -needs_brackets(val::Real) = false -needs_brackets(val::AbstractArray) = false -needs_brackets(val::Complex) = true -needs_brackets(val) = true - -function string_constant(val) - if needs_brackets(val) - '(' * string(val) * ')' - else - string(val) - end -end - -function string_variable(feature, variable_names) - if variable_names === nothing || feature > lastindex(variable_names) - return 'x' * string(feature) - else - return variable_names[feature] - end -end - -# Vector of chars is faster than strings, so we use that. -function combine_op_with_inputs(op, l, r)::Vector{Char} - if first(op) in ('+', '-', '*', '/', '^') - # "(l op r)" - out = ['('] - append!(out, l) - push!(out, ' ') - append!(out, op) - push!(out, ' ') - append!(out, r) - push!(out, ')') - else - # "op(l, r)" - out = copy(op) - push!(out, '(') - append!(out, strip_brackets(l)) - push!(out, ',') - push!(out, ' ') - append!(out, strip_brackets(r)) - push!(out, ')') - return out - end -end -function combine_op_with_inputs(op, l) - # "op(l)" - out = copy(op) - push!(out, '(') - append!(out, strip_brackets(l)) - push!(out, ')') - return out -end - -""" - string_tree( - tree::AbstractExpressionNode{T}, - operators::Union{AbstractOperatorEnum,Nothing}=nothing; - f_variable::F1=string_variable, - f_constant::F2=string_constant, - variable_names::Union{Array{String,1},Nothing}=nothing, - # Deprecated - varMap=nothing, - )::String where {T,F1<:Function,F2<:Function} - -Convert an equation to a string. - -# Arguments -- `tree`: the tree to convert to a string -- `operators`: the operators used to define the tree - -# Keyword Arguments -- `f_variable`: (optional) function to convert a variable to a string, with arguments `(feature::UInt8, variable_names)`. -- `f_constant`: (optional) function to convert a constant to a string, with arguments `(val,)` -- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature. -""" -function string_tree( - tree::AbstractExpressionNode{T}, - operators::Union{AbstractOperatorEnum,Nothing}=nothing; - f_variable::F1=string_variable, - f_constant::F2=string_constant, - variable_names::Union{Array{String,1},Nothing}=nothing, - # Deprecated - varMap=nothing, -)::String where {T,F1<:Function,F2<:Function} - variable_names = deprecate_varmap(variable_names, varMap, :string_tree) - raw_output = tree_mapreduce( - leaf -> if leaf.constant - collect(f_constant(leaf.val::T)) - else - collect(f_variable(leaf.feature, variable_names)) - end, - branch -> if branch.degree == 1 - dispatch_op_name(Val(1), operators, branch.op) - else - dispatch_op_name(Val(2), operators, branch.op) - end, - combine_op_with_inputs, - tree, - Vector{Char}; - f_on_shared=(c, is_shared) -> if is_shared - out = ['{'] - append!(out, c) - push!(out, '}') - out - else - c - end, - ) - return String(strip_brackets(raw_output)) -end - -# Print an equation -for io in ((), (:(io::IO),)) - @eval function print_tree( - $(io...), - tree::AbstractExpressionNode, - operators::Union{AbstractOperatorEnum,Nothing}=nothing; - f_variable::F1=string_variable, - f_constant::F2=string_constant, - variable_names::Union{Array{String,1},Nothing}=nothing, - # Deprecated - varMap=nothing, - ) where {F1<:Function,F2<:Function} - variable_names = deprecate_varmap(variable_names, varMap, :print_tree) - return println( - $(io...), string_tree(tree, operators; f_variable, f_constant, variable_names) - ) - end -end - end diff --git a/src/EquationUtils.jl b/src/EquationUtils.jl index db9990f7..b5c3fe43 100644 --- a/src/EquationUtils.jl +++ b/src/EquationUtils.jl @@ -76,7 +76,7 @@ The function `set_constants!` sets them in the same order, given the output of this function. """ function get_constants(tree::AbstractExpressionNode{T}) where {T} - return filter_map(is_node_constant, t -> (t.val::T), tree, T) + return filter_map(is_node_constant, t -> (t.val), tree, T) end """ @@ -91,7 +91,7 @@ function set_constants!( Base.require_one_based_indexing(constants) i = Ref(0) foreach(tree) do node - if node.degree == 0 && node.constant + if is_node_constant(node) @inbounds node.val = constants[i[] += 1] end end @@ -114,12 +114,12 @@ end function Base.getproperty(cr::NodeConstantRef{T}, s::Symbol) where {T} s != :x && error("Only :x is a valid property for NodeConstantRef") - return getfield(cr, :_node).x.val::T + return getfield(cr, :_node).x.val end function Base.setproperty!(cr::NodeConstantRef{T}, s::Symbol, v) where {T} s != :x && error("Only :x is a valid property for NodeConstantRef") - return getfield(cr, :_node).x.val::T = v::T + return getfield(cr, :_node).x.val = v end Base.propertynames(::NodeConstantRef) = (:x,) diff --git a/src/EvaluateEquation.jl b/src/EvaluateEquation.jl index f85b533b..bccade91 100644 --- a/src/EvaluateEquation.jl +++ b/src/EvaluateEquation.jl @@ -1,6 +1,7 @@ module EvaluateEquationModule -import ..EquationModule: AbstractExpressionNode, constructorof, string_tree +import ..EquationModule: AbstractExpressionNode, constructorof +import ..StringsModule: string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: is_bad_array, fill_similar, counttuple, ResultOk import ..EquationUtilsModule: is_constant @@ -150,7 +151,7 @@ function deg0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T} )::ResultOk where {T<:Number} if tree.constant - return ResultOk(fill_similar(tree.val::T, cX, axes(cX, 2)), true) + return ResultOk(fill_similar(tree.val, cX, axes(cX, 2)), true) else return ResultOk(cX[tree.feature, :], true) end @@ -302,8 +303,8 @@ function deg1_l2_ll0_lr0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{false} ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant - val_ll = tree.l.l.val::T - val_lr = tree.l.r.val::T + val_ll = tree.l.l.val + val_lr = tree.l.r.val @return_on_check val_ll cX @return_on_check val_lr cX x_l = op_l(val_ll, val_lr)::T @@ -312,7 +313,7 @@ function deg1_l2_ll0_lr0_eval( @return_on_check x cX return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.l.constant - val_ll = tree.l.l.val::T + val_ll = tree.l.l.val @return_on_check val_ll cX feature_lr = tree.l.r.feature cumulator = similar(cX, axes(cX, 2)) @@ -324,7 +325,7 @@ function deg1_l2_ll0_lr0_eval( return ResultOk(cumulator, true) elseif tree.l.r.constant feature_ll = tree.l.l.feature - val_lr = tree.l.r.val::T + val_lr = tree.l.r.val @return_on_check val_lr cX cumulator = similar(cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) @@ -351,7 +352,7 @@ function deg1_l1_ll0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{false} ) where {T<:Number,F,F2} if tree.l.l.constant - val_ll = tree.l.l.val::T + val_ll = tree.l.l.val @return_on_check val_ll cX x_l = op_l(val_ll)::T @return_on_check x_l cX @@ -375,16 +376,16 @@ function deg2_l0_r0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{false} ) where {T<:Number,F} if tree.l.constant && tree.r.constant - val_l = tree.l.val::T + val_l = tree.l.val @return_on_check val_l cX - val_r = tree.r.val::T + val_r = tree.r.val @return_on_check val_r cX x = op(val_l, val_r)::T @return_on_check x cX return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.constant cumulator = similar(cX, axes(cX, 2)) - val_l = tree.l.val::T + val_l = tree.l.val @return_on_check val_l cX feature_r = tree.r.feature @inbounds @simd for j in axes(cX, 2) @@ -395,7 +396,7 @@ function deg2_l0_r0_eval( elseif tree.r.constant cumulator = similar(cX, axes(cX, 2)) feature_l = tree.l.feature - val_r = tree.r.val::T + val_r = tree.r.val @return_on_check val_r cX @inbounds @simd for j in axes(cX, 2) x = op(cX[feature_l, j], val_r)::T @@ -423,7 +424,7 @@ function deg2_l0_eval( ::Val{false}, ) where {T<:Number,F} if tree.l.constant - val = tree.l.val::T + val = tree.l.val @return_on_check val cX @inbounds @simd for j in eachindex(cumulator) x = op(val, cumulator[j])::T @@ -449,7 +450,7 @@ function deg2_r0_eval( ::Val{false}, ) where {T<:Number,F} if tree.r.constant - val = tree.r.val::T + val = tree.r.val @return_on_check val cX @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j], val)::T @@ -522,7 +523,7 @@ over an entire array when the values are all the same. end @inline function deg0_eval_constant(tree::AbstractExpressionNode{T}) where {T<:Number} - output = tree.val::T + output = tree.val return ResultOk([output], true)::ResultOk{Vector{T}} end diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 9cf7834e..0f923968 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -1,7 +1,8 @@ module OperatorEnumConstructionModule import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum -import ..EquationModule: string_tree, Node, GraphNode, AbstractExpressionNode, constructorof +import ..EquationModule: Node, GraphNode, AbstractExpressionNode, constructorof +import ..StringsModule: string_tree import ..EvaluateEquationModule: eval_tree_array, OPERATOR_LIMIT_BEFORE_SLOWDOWN import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient import ..EvaluationHelpersModule: _grad_evaluator @@ -121,10 +122,10 @@ function _extend_unary_operator(f::Symbol, type_requirements, internal) l::N ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} return if (l.degree == 0 && l.constant) - $_constructorof(N)(T; val=$($f)(l.val::T)) + $_constructorof(N)(T; val=$($f)(l.val)) else latest_op_idx = $($lookup_op)($($f), Val(1)) - $_constructorof(N)(latest_op_idx, l) + $_constructorof(N)(; op=latest_op_idx, l) end end end @@ -148,30 +149,34 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters, l::N, r::N ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} if (l.degree == 0 && l.constant && r.degree == 0 && r.constant) - $_constructorof(N)(T; val=$($f)(l.val::T, r.val::T)) + $_constructorof(N)(T; val=$($f)(l.val, r.val)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - $_constructorof(N)(latest_op_idx, l, r) + $_constructorof(N)(; op=latest_op_idx, l, r) end end function $($f)( l::N, r::T ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} if l.degree == 0 && l.constant - $_constructorof(N)(T; val=$($f)(l.val::T, r)) + $_constructorof(N)(T; val=$($f)(l.val, r)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - $_constructorof(N)(latest_op_idx, l, $_constructorof(N)(T; val=r)) + $_constructorof(N)(; + op=latest_op_idx, l, r=$_constructorof(N)(T; val=r) + ) end end function $($f)( l::T, r::N ) where {T<:$($type_requirements),N<:$_AbstractExpressionNode{T}} if r.degree == 0 && r.constant - $_constructorof(N)(T; val=$($f)(l, r.val::T)) + $_constructorof(N)(T; val=$($f)(l, r.val)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - $_constructorof(N)(latest_op_idx, $_constructorof(N)(T; val=l), r) + $_constructorof(N)(; + op=latest_op_idx, l=$_constructorof(N)(T; val=l), r + ) end end if $($build_converters) diff --git a/src/SimplifyEquation.jl b/src/SimplifyEquation.jl index ccd6bd80..77753695 100644 --- a/src/SimplifyEquation.jl +++ b/src/SimplifyEquation.jl @@ -33,30 +33,27 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where tree.r = combine_operators(tree.r, operators) end - top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant) + top_level_constant = + tree.degree == 2 && (is_node_constant(tree.l) || is_node_constant(tree.r)) if tree.degree == 2 && is_commutative(operators.binops[tree.op]) && top_level_constant # TODO: Does this break SymbolicRegression.jl due to the different names of operators? op = tree.op # Put the constant in r. Need to assume var in left for simplification assumption. - if tree.l.constant + if is_node_constant(tree.l) tmp = tree.r tree.r = tree.l tree.l = tmp end - topconstant = tree.r.val::T + topconstant = tree.r.val # Simplify down first below = tree.l if below.degree == 2 && below.op == op - if below.l.constant + if is_node_constant(below.l) tree = below - tree.l.val = _bin_op_kernel( - operators.binops[op], tree.l.val::T, topconstant - ) - elseif below.r.constant + tree.l.val = _bin_op_kernel(operators.binops[op], tree.l.val, topconstant) + elseif is_node_constant(below.r) tree = below - tree.r.val = _bin_op_kernel( - operators.binops[op], tree.r.val::T, topconstant - ) + tree.r.val = _bin_op_kernel(operators.binops[op], tree.r.val, topconstant) end end end @@ -65,40 +62,40 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where # Currently just simplifies subtraction. (can't assume both plus and sub are operators) # Not commutative, so use different op. - if tree.l.constant + if is_node_constant(tree.l) if tree.r.degree == 2 && tree.op == tree.r.op - if tree.r.l.constant + if is_node_constant(tree.r.l) #(const - (const - var)) => (var - const) l = tree.l r = tree.r - simplified_const = (r.l.val::T - l.val::T) #neg(sub(l.val, r.l.val)) + simplified_const = (r.l.val - l.val) #neg(sub(l.val, r.l.val)) tree.l = tree.r.r tree.r = l tree.r.val = simplified_const - elseif tree.r.r.constant + elseif is_node_constant(tree.r.r) #(const - (var - const)) => (const - var) l = tree.l r = tree.r - simplified_const = l.val::T + r.r.val::T #plus(l.val, r.r.val) + simplified_const = l.val + r.r.val #plus(l.val, r.r.val) tree.r = tree.r.l tree.l.val = simplified_const end end - else #tree.r.constant is true + else #tree.r is a constant if tree.l.degree == 2 && tree.op == tree.l.op - if tree.l.l.constant + if is_node_constant(tree.l.l) #((const - var) - const) => (const - var) l = tree.l r = tree.r - simplified_const = l.l.val::T - r.val::T#sub(l.l.val, r.val) + simplified_const = l.l.val - r.val#sub(l.l.val, r.val) tree.r = tree.l.r tree.l = r tree.l.val = simplified_const - elseif tree.l.r.constant + elseif is_node_constant(tree.l.r) #((var - const) - const) => (var - const) l = tree.l r = tree.r - simplified_const = r.val::T + l.r.val::T #plus(r.val, l.r.val) + simplified_const = r.val + l.r.val #plus(r.val, l.r.val) tree.l = tree.l.l tree.r.val = simplified_const end @@ -110,7 +107,7 @@ end function combine_children!(operators, p::N, c::N...) where {T,N<:AbstractExpressionNode{T}} all(is_node_constant, c) || return p - vals = map(n -> n.val::T, c) + vals = map(n -> n.val, c) all(isgood, vals) || return p out = if length(c) == 1 _una_op_kernel(operators.unaops[p.op], vals...) diff --git a/src/Strings.jl b/src/Strings.jl new file mode 100644 index 00000000..8fafb105 --- /dev/null +++ b/src/Strings.jl @@ -0,0 +1,184 @@ +module StringsModule + +using ..UtilsModule: deprecate_varmap +using ..OperatorEnumModule: AbstractOperatorEnum +using ..EquationModule: AbstractExpressionNode, tree_mapreduce + +const OP_NAMES = Base.ImmutableDict( + "safe_log" => "log", + "safe_log2" => "log2", + "safe_log10" => "log10", + "safe_log1p" => "log1p", + "safe_acosh" => "acosh", + "safe_sqrt" => "sqrt", + "safe_pow" => "^", +) + +function dispatch_op_name(::Val{2}, ::Nothing, idx)::Vector{Char} + return vcat(collect("binary_operator["), collect(string(idx)), [']']) +end +function dispatch_op_name(::Val{1}, ::Nothing, idx)::Vector{Char} + return vcat(collect("unary_operator["), collect(string(idx)), [']']) +end +function dispatch_op_name(::Val{2}, operators::AbstractOperatorEnum, idx)::Vector{Char} + return get_op_name(operators.binops[idx]) +end +function dispatch_op_name(::Val{1}, operators::AbstractOperatorEnum, idx)::Vector{Char} + return get_op_name(operators.unaops[idx]) +end + +@generated function get_op_name(op::F)::Vector{Char} where {F} + try + # Bit faster to just cache the name of the operator: + op_s = string(F.instance) + out = collect(get(OP_NAMES, op_s, op_s)) + return :($out) + catch + end + return quote + op_s = string(op) + out = collect(get(OP_NAMES, op_s, op_s)) + return out + end +end + +@inline function strip_brackets(s::Vector{Char})::Vector{Char} + if first(s) == '(' && last(s) == ')' + return s[(begin + 1):(end - 1)] + else + return s + end +end + +# Can overload these for custom behavior: +needs_brackets(val::Real) = false +needs_brackets(val::AbstractArray) = false +needs_brackets(val::Complex) = true +needs_brackets(val) = true + +function string_constant(val) + if needs_brackets(val) + '(' * string(val) * ')' + else + string(val) + end +end + +function string_variable(feature, variable_names) + if variable_names === nothing || feature > lastindex(variable_names) + return 'x' * string(feature) + else + return variable_names[feature] + end +end + +# Vector of chars is faster than strings, so we use that. +function combine_op_with_inputs(op, l, r)::Vector{Char} + if first(op) in ('+', '-', '*', '/', '^') + # "(l op r)" + out = ['('] + append!(out, l) + push!(out, ' ') + append!(out, op) + push!(out, ' ') + append!(out, r) + push!(out, ')') + else + # "op(l, r)" + out = copy(op) + push!(out, '(') + append!(out, strip_brackets(l)) + push!(out, ',') + push!(out, ' ') + append!(out, strip_brackets(r)) + push!(out, ')') + return out + end +end +function combine_op_with_inputs(op, l) + # "op(l)" + out = copy(op) + push!(out, '(') + append!(out, strip_brackets(l)) + push!(out, ')') + return out +end + +""" + string_tree( + tree::AbstractExpressionNode{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + f_variable::F1=string_variable, + f_constant::F2=string_constant, + variable_names::Union{Array{String,1},Nothing}=nothing, + # Deprecated + varMap=nothing, + )::String where {T,F1<:Function,F2<:Function} + +Convert an equation to a string. + +# Arguments +- `tree`: the tree to convert to a string +- `operators`: the operators used to define the tree + +# Keyword Arguments +- `f_variable`: (optional) function to convert a variable to a string, with arguments `(feature::UInt8, variable_names)`. +- `f_constant`: (optional) function to convert a constant to a string, with arguments `(val,)` +- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature. +""" +function string_tree( + tree::AbstractExpressionNode{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + f_variable::F1=string_variable, + f_constant::F2=string_constant, + variable_names::Union{Array{String,1},Nothing}=nothing, + # Deprecated + varMap=nothing, +)::String where {T,F1<:Function,F2<:Function} + variable_names = deprecate_varmap(variable_names, varMap, :string_tree) + raw_output = tree_mapreduce( + leaf -> if leaf.constant + collect(f_constant(leaf.val)) + else + collect(f_variable(leaf.feature, variable_names)) + end, + branch -> if branch.degree == 1 + dispatch_op_name(Val(1), operators, branch.op) + else + dispatch_op_name(Val(2), operators, branch.op) + end, + combine_op_with_inputs, + tree, + Vector{Char}; + f_on_shared=(c, is_shared) -> if is_shared + out = ['{'] + append!(out, c) + push!(out, '}') + out + else + c + end, + ) + return String(strip_brackets(raw_output)) +end + +# Print an equation +for io in ((), (:(io::IO),)) + @eval function print_tree( + $(io...), + tree::AbstractExpressionNode, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + f_variable::F1=string_variable, + f_constant::F2=string_constant, + variable_names::Union{Array{String,1},Nothing}=nothing, + # Deprecated + varMap=nothing, + ) where {F1<:Function,F2<:Function} + variable_names = deprecate_varmap(variable_names, varMap, :print_tree) + return println( + $(io...), string_tree(tree, operators; f_variable, f_constant, variable_names) + ) + end +end + +end diff --git a/src/base.jl b/src/base.jl index 53da7d3f..2d315ba7 100644 --- a/src/base.jl +++ b/src/base.jl @@ -411,7 +411,7 @@ function hash( tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val=Val(false) ) where {T} return tree_mapreduce( - t -> t.constant ? hash((0, t.val::T), h) : hash((1, t.feature), h), + t -> t.constant ? hash((0, t.val), h) : hash((1, t.feature), h), t -> hash((t.degree + 1, t.op), h), (n...) -> hash(n, h), tree, @@ -435,12 +435,12 @@ function copy_node( ) where {T,N<:AbstractExpressionNode{T}} return tree_mapreduce( t -> if t.constant - constructorof(N)(; val=t.val::T) + constructorof(N)(; val=t.val) else constructorof(N)(T; feature=t.feature) end, identity, - (p, c...) -> constructorof(N)(p.op, c...), + (p, children...) -> constructorof(N)(; op=p.op, children), tree, N; break_sharing, @@ -478,12 +478,12 @@ function convert( end return tree_mapreduce( t -> if t.constant - constructorof(N1)(T1, 0, true, convert(T1, t.val::T2)) + constructorof(N1)(; val=convert(T1, t.val::T2)) else - constructorof(N1)(T1, 0, false, nothing, t.feature) + constructorof(N1)(T1; feature=t.feature) end, identity, - (p, c...) -> constructorof(N1)(p.degree, false, nothing, 0, p.op, c...), + (p, children...) -> constructorof(N1)(; op=p.op, children), tree, N1, ) @@ -491,7 +491,7 @@ end function convert( ::Type{N1}, tree::N2 ) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}} - return convert(constructorof(N1){T2}, tree) + return convert(with_type_parameters(N1, T2), tree) end function (::Type{N})(tree::AbstractExpressionNode) where {N<:AbstractExpressionNode} return convert(N, tree) diff --git a/src/deprecated.jl b/src/deprecated.jl index de03e91a..f62df171 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,4 +1,71 @@ import Base: @deprecate +import .EquationModule: Node, GraphNode @deprecate set_constants set_constants! @deprecate simplify_tree simplify_tree! + +for N in (:Node, :GraphNode) + @eval begin + function $N(d::Integer, c::Bool, v::T) where {T} + Base.depwarn( + string($N) * + "(d, c, v) is deprecated. Use " * + string($(N)) * + "{T}(val=v) instead.", + $(Meta.quot(N)), + ) + @assert d == 1 + @assert c == true + return $N{T}(; val=v) + end + function $N(::Type{T}, d::Integer, c::Bool, v::_T) where {T,_T} + Base.depwarn( + string($N) * + "(T, d, c, v) is deprecated. Use " * + string($(N)) * + "{T}(val=v) instead.", + $(Meta.quot(N)), + ) + @assert d == 1 + @assert c == true + return $N{T}(; val=v) + end + function $N(::Type{T}, d::Integer, c::Bool, ::Nothing, f::Integer) where {T} + Base.depwarn( + string($N) * + "(T, d, c, v, f) is deprecated. Use " * + string($(N)) * + "{T}(feature=f) instead.", + $(Meta.quot(N)), + ) + + @assert d == 1 + @assert c == false + return $N{T}(; feature=f) + end + function $N(d::Integer, ::Bool, ::Nothing, ::Integer, o::Integer, l::$N) + Base.depwarn( + string($N) * + "(d, c, v, f, o, l) is deprecated. Use " * + string($(N)) * + "(op=o, l=l) instead.", + $(Meta.quot(N)), + ) + @assert d == 1 + return $N(; op=o, l=l) + end + function $N( + d::Integer, ::Bool, ::Nothing, ::Integer, o::Integer, l::$N{T}, r::$N{T} + ) where {T} + Base.depwarn( + string($N) * + "(d, c, v, f, o, l, r) is deprecated. Use " * + string($(N)) * + "(op=o, l=l, r=r) instead.", + $(Meta.quot(N)), + ) + @assert d == 2 + return $N(; op=o, l=l, r=r) + end + end +end diff --git a/test/test_base.jl b/test/test_base.jl index cf51b717..aaf0032d 100644 --- a/test/test_base.jl +++ b/test/test_base.jl @@ -73,7 +73,7 @@ end t.val *= 2 end end - @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 + @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) ≈ 11.6 * 2 end @testset "iterate" begin @@ -88,7 +88,7 @@ end t.val *= 2 end end - @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 + @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) ≈ 11.6 * 2 # iterate within iterate: counter = Ref(0) @@ -110,7 +110,7 @@ end @test sum(map(t -> t.degree == 1, ctree)) == 1 @test length(unique(map(objectid, copy_node(tree)))) == 24 map(t -> (t.degree == 0 && t.constant) ? (t.val *= 2) : nothing, ctree) - @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2 + @test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) ≈ 11.6 * 2 local T = fieldtype(typeof(ctree), :degree) @test typeof(map(t -> t.degree, ctree, T)) == Vector{T} @test first(map(t -> t.degree, ctree, T)) == 2 diff --git a/test/test_deprecations.jl b/test/test_deprecations.jl index 202d76a9..baef1d0b 100644 --- a/test/test_deprecations.jl +++ b/test/test_deprecations.jl @@ -20,3 +20,26 @@ for constructor in (OperatorEnum, GenericOperatorEnum) [1.0; 2.0;;] ) end + +if VERSION >= v"1.9" + @test_logs (:warn, r"Node\(d, c, v\) is deprecated.*") ( + n = Node(1, true, 1.0 + 0im); @assert (n.val isa ComplexF64) + ) + @test_logs (:warn, r"Node\(T, d, c, v\) is deprecated.*") ( + n = Node(Float32, 1, true, 1.0 + 0im); @assert (n.val isa Float32) + ) + @test_logs (:warn, r"Node\(T, d, c, v, f\) is deprecated.*") ( + n = Node(Float32, 1, false, nothing, 1); @assert (n.feature == 1) + ) + @test_logs (:warn, r"Node\(d, c, v, f, o, l\) is deprecated.*") ( + x1 = Node(; feature=1); + n = Node(1, true, nothing, 1, 3, x1); + @assert (n.op == 3 && n.l === x1) + ) + @test_logs (:warn, r"Node\(d, c, v, f, o, l, r\) is deprecated.*") ( + x1 = Node(; feature=1); + x2 = Node(; feature=2); + n = Node(2, true, nothing, 1, 1, x1, x2); + @assert (n.op == 1 && n.l === x1 && n.r === x2) + ) +end diff --git a/test/test_graphs.jl b/test/test_graphs.jl index ec4e0dcd..ac5c1fa3 100644 --- a/test/test_graphs.jl +++ b/test/test_graphs.jl @@ -127,7 +127,7 @@ end )::Node{T} where {T} if tree.degree == 0 if tree.constant - Node(; val=copy(tree.val::T)) + Node(; val=copy(tree.val)) else Node(T; feature=copy(tree.feature)) end @@ -141,7 +141,7 @@ end function _copy_node(tree::Node{T})::Node{T} where {T} if tree.degree == 0 if tree.constant - Node(; val=copy(tree.val::T)) + Node(; val=copy(tree.val)) else Node(T; feature=copy(tree.feature)) end @@ -158,7 +158,7 @@ end return begin if tree.degree == 0 if tree.constant - Node(; val=copy(tree.val::T)) + Node(; val=copy(tree.val)) else Node(T; feature=copy(tree.feature)) end diff --git a/test/test_simplification.jl b/test/test_simplification.jl index c8841458..2f0df9c0 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -1,6 +1,6 @@ include("test_params.jl") using DynamicExpressions, Test -import DynamicExpressions.EquationModule: strip_brackets +import DynamicExpressions.StringsModule: strip_brackets import SymbolicUtils: simplify, Symbolic import Random: MersenneTwister import Base: ≈ diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl index e1210cee..60899008 100644 --- a/test/test_tree_construction.jl +++ b/test/test_tree_construction.jl @@ -106,6 +106,17 @@ end @test repr(tree) == repr(tree2) end +@testset "Type inference" begin + @inferred Node(; feature=1) + @inferred Node(; val=1) + @inferred Node(Float32; val=1) + @inferred Node{Float32}(; val=1) + x1 = Node{Float32}(; feature=1) + @inferred Node(; op=1, l=x1) + @inferred Node(; op=1, l=x1, r=x1) + @inferred Node(; op=1, l=x1, r=Node{Float64}(x1)) +end + @testset "Miscellaneous" begin operators = OperatorEnum(; default_params...) for N in (Node, GraphNode)