diff --git a/Project.toml b/Project.toml index ad54c11d..eac9803c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "0.18.0-alpha.1" +version = "0.18.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 64deaec3..0eebeadc 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -66,7 +66,8 @@ import .NodeModule: @reexport import .EvaluationHelpersModule @reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node @reexport import .RandomModule: NodeSampler -@reexport import .ExpressionModule: AbstractExpression, Expression, with_tree +@reexport import .ExpressionModule: + AbstractExpression, Expression, with_contents, with_metadata, get_contents, get_metadata import .ExpressionModule: get_tree, get_operators, get_variable_names, Metadata, default_node_type, node_type @reexport import .ParseModule: @parse_expression, parse_expression diff --git a/src/Expression.jl b/src/Expression.jl index 7113337d..986b85de 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -6,6 +6,7 @@ using ..NodeModule: AbstractExpressionNode, Node using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum using ..UtilsModule: Undefined +import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof import ..NodeUtilsModule: preserve_sharing, count_constants, @@ -85,7 +86,7 @@ end end node_type(::Union{E,Type{E}}) where {N,E<:AbstractExpression{<:Any,N}} = N -@unstable default_node_type(::Type{<:AbstractExpression}) = Node +@unstable default_node_type(_) = Node default_node_type(::Type{<:AbstractExpression{T}}) where {T} = Node{T} ######################################################## @@ -128,28 +129,52 @@ end function Base.copy(ex::AbstractExpression; break_sharing::Val=Val(false)) return error("`copy` function must be implemented for $(typeof(ex)) types.") end -function Base.hash(ex::AbstractExpression, h::UInt) - return error("`hash` function must be implemented for $(typeof(ex)) types.") -end -function Base.:(==)(x::AbstractExpression, y::AbstractExpression) - return error("`==` function must be implemented for $(typeof(x)) types.") -end function get_constants(ex::AbstractExpression) return error("`get_constants` function must be implemented for $(typeof(ex)) types.") end function set_constants!(ex::AbstractExpression{T}, constants, refs) where {T} return error("`set_constants!` function must be implemented for $(typeof(ex)) types.") end +function get_contents(ex::AbstractExpression) + return error("`get_contents` function must be implemented for $(typeof(ex)) types.") +end +function get_metadata(ex::AbstractExpression) + return error("`get_metadata` function must be implemented for $(typeof(ex)) types.") +end ######################################################## """ - with_tree(ex::AbstractExpression, tree::AbstractExpressionNode) + with_contents(ex::AbstractExpression, tree::AbstractExpressionNode) + with_contents(ex::AbstractExpression, tree::AbstractExpression) Create a new expression based on `ex` but with a different `tree` """ -function with_tree(ex::AbstractExpression, tree) - return constructorof(typeof(ex))(tree, ex.metadata) +function with_contents(ex::AbstractExpression, tree::AbstractExpression) + return with_contents(ex, get_contents(tree)) +end +function with_contents(ex::AbstractExpression, tree) + return constructorof(typeof(ex))(tree, get_metadata(ex)) +end +function get_contents(ex::Expression) + return ex.tree +end + +""" + with_metadata(ex::AbstractExpression, metadata) + with_metadata(ex::AbstractExpression; metadata...) + +Create a new expression based on `ex` but with a different `metadata`. +""" +function with_metadata(ex::AbstractExpression; metadata...) + return with_metadata(ex, Metadata((; metadata...))) +end +function with_metadata(ex::AbstractExpression, metadata::Metadata) + return constructorof(typeof(ex))(get_contents(ex), metadata) +end +function get_metadata(ex::Expression) + return ex.metadata end + function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T,N}} return preserve_sharing(N) end @@ -169,25 +194,17 @@ end function Base.copy(ex::Expression; break_sharing::Val=Val(false)) return Expression(copy(ex.tree; break_sharing), copy(ex.metadata)) end -function Base.hash(ex::Expression, h::UInt) - return hash(ex.tree, hash(ex.metadata, h)) +function Base.hash(ex::AbstractExpression, h::UInt) + return hash(get_contents(ex), hash(get_metadata(ex), h)) end - -""" - Base.:(==)(x::Expression, y::Expression) - -Check equality of two expressions `x` and `y` by comparing their trees and metadata. -""" -function Base.:(==)(x::Expression, y::Expression) - return x.tree == y.tree && x.metadata == y.metadata +function Base.:(==)(x::AbstractExpression, y::AbstractExpression) + return get_contents(x) == get_contents(y) && get_metadata(x) == get_metadata(y) end # Overload all methods on AbstractExpressionNode that return an aggregation, or can # return an entire tree. Methods that only return the nodes are *not* overloaded, so # that the user must use the low-level interface. -import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof - #! format: off @unstable constructorof(::Type{E}) where {E<:AbstractExpression} = Base.typename(E).wrapper @unstable constructorof(::Type{<:Expression}) = Expression diff --git a/src/Interfaces.jl b/src/Interfaces.jl index c6b0a467..750aead6 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -43,7 +43,10 @@ using ..ExpressionModule: get_tree, get_operators, get_variable_names, - with_tree, + get_contents, + get_metadata, + with_contents, + with_metadata, default_node_type using ..ParametricExpressionModule: ParametricExpression, ParametricNode @@ -52,6 +55,14 @@ using ..ParametricExpressionModule: ParametricExpression, ParametricNode ############################################################################### ## mandatory +function _check_get_contents(ex::AbstractExpression) + new_ex = with_contents(ex, get_contents(ex)) + return new_ex == ex && new_ex isa typeof(ex) +end +function _check_get_metadata(ex::AbstractExpression) + new_ex = with_metadata(ex, get_metadata(ex)) + return new_ex == ex && new_ex isa typeof(ex) +end function _check_get_tree(ex::AbstractExpression{T,N}) where {T,N} return get_tree(ex) isa N end @@ -67,6 +78,15 @@ function _check_copy(ex::AbstractExpression) # TODO: Could include checks for aliasing here return preserves end +function _check_with_contents(ex::AbstractExpression) + new_ex = with_contents(ex, get_contents(ex)) + new_ex2 = with_contents(ex, ex) + return new_ex == ex && new_ex isa typeof(ex) && new_ex2 == ex && new_ex2 isa typeof(ex) +end +function _check_with_metadata(ex::AbstractExpression) + new_ex = with_metadata(ex, get_metadata(ex)) + return new_ex == ex && new_ex isa typeof(ex) +end ## optional function _check_count_nodes(ex::AbstractExpression) @@ -116,10 +136,14 @@ end #! format: off ei_components = ( mandatory = ( + get_contents = "extracts the runtime contents of an expression" => _check_get_contents, + get_metadata = "extracts the runtime metadata of an expression" => _check_get_metadata, get_tree = "extracts the expression tree from [`AbstractExpression`](@ref)" => _check_get_tree, get_operators = "returns the operators used in the expression (or pass `operators` explicitly to override)" => _check_get_operators, get_variable_names = "returns the variable names used in the expression (or pass `variable_names` explicitly to override)" => _check_get_variable_names, copy = "returns a copy of the expression" => _check_copy, + with_contents = "returns the expression with different tree" => _check_with_contents, + with_metadata = "returns the expression with different metadata" => _check_with_metadata, ), optional = ( count_nodes = "counts the number of nodes in the expression tree" => _check_count_nodes, diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index d0692cc5..d4ecac10 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -19,7 +19,13 @@ import ..EvaluateModule: eval_tree_array import ..EvaluateDerivativeModule: eval_grad_tree_array import ..EvaluationHelpersModule: _grad_evaluator import ..ExpressionModule: - get_tree, get_operators, get_variable_names, max_feature, default_node_type + get_contents, + get_metadata, + get_tree, + get_operators, + get_variable_names, + max_feature, + default_node_type import ..ParseModule: parse_leaf """A type of expression node that also stores a parameter index""" @@ -48,9 +54,13 @@ end """ ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N} -An expression to store parameters for a tree +(Experimental) An expression to store parameters for a tree """ -struct ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N} +struct ParametricExpression{ + T, + N<:ParametricNode{T}, + D<:NamedTuple{(:operators, :variable_names, :parameters, :parameter_names)}, +} <: AbstractExpression{T,N} tree::N metadata::Metadata{D} @@ -65,8 +75,9 @@ function ParametricExpression( parameters::AbstractMatrix{T2}, parameter_names, ) where {T1,T2} - @assert (isempty(parameters) && isnothing(parameter_names)) || - size(parameters, 1) == length(parameter_names) + if !isnothing(parameter_names) + @assert size(parameters, 1) == length(parameter_names) + end T = promote_type(T1, T2) t = T === T1 ? tree : convert(ParametricNode{T}, tree) m = Metadata((; @@ -127,9 +138,9 @@ end ############################################################################### # Abstract expression interface ############################################### ############################################################################### -function get_tree(ex::ParametricExpression) - return ex.tree -end +get_contents(ex::ParametricExpression) = ex.tree +get_metadata(ex::ParametricExpression) = ex.metadata +get_tree(ex::ParametricExpression) = ex.tree function get_operators(ex::ParametricExpression, operators=nothing) return operators === nothing ? ex.metadata.operators : operators end @@ -147,12 +158,6 @@ function Base.copy(ex::ParametricExpression; break_sharing::Val=Val(false)) parameter_names=_copy_with_nothing(ex.metadata.parameter_names), ) end -function Base.hash(ex::ParametricExpression, h::UInt) - return hash(ex.tree, hash(ex.metadata, h)) -end -function Base.:(==)(x::ParametricExpression, y::ParametricExpression) - return x.tree == y.tree && x.metadata == y.metadata -end ############################################################################### ############################################################################### @@ -283,10 +288,16 @@ function string_tree( UInt16(0) end end + _parameter_names = ex.metadata.parameter_names + parameter_names = if _parameter_names === nothing + ["p$(i)" for i in 1:num_params] + else + _parameter_names + end variable_names3 = if variable_names2 === nothing - vcat(["p$(i)" for i in 1:num_params], ["x$(i)" for i in 1:max_feature]) + vcat(parameter_names, ["x$(i)" for i in 1:max_feature]) else - vcat(ex.metadata.parameter_names, variable_names2) + vcat(parameter_names, variable_names2) end @assert length(variable_names3) >= num_params + max_feature return string_tree( diff --git a/test/test_expressions.jl b/test/test_expressions.jl index d5ffc259..75fede26 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -168,14 +168,14 @@ end @test has_constants(ex) == false end -@testitem "Expression with_tree" begin +@testitem "Expression with_contents" begin using DynamicExpressions ex = @parse_expression(x1 + 1.5, binary_operators = [+, *], variable_names = ["x1"]) ex2 = @parse_expression(x1 + 3.0, binary_operators = [+], variable_names = ["x1"]) - t2 = DynamicExpressions.get_tree(ex2) - ex_modified = DynamicExpressions.with_tree(ex, t2) + t2 = DynamicExpressions.get_contents(ex2) + ex_modified = DynamicExpressions.with_contents(ex, t2) @test DynamicExpressions.get_tree(ex_modified) == t2 end diff --git a/test/test_multi_expression.jl b/test/test_multi_expression.jl index e1811ba9..56ee6d0e 100644 --- a/test/test_multi_expression.jl +++ b/test/test_multi_expression.jl @@ -10,6 +10,13 @@ trees::TREES metadata::Metadata{D} + function MultiScalarExpression(trees::NamedTuple, metadata::Metadata{D}) where {D} + example_tree = first(values(trees)) + N = typeof(example_tree) + T = eltype(example_tree) + return new{T,N,typeof(trees),D}(trees, metadata) + end + """ Create a multi-expression expression type. @@ -54,8 +61,6 @@ ) @test_throws "`get_tree` function must be implemented for" DE.get_tree(multi_ex) @test_throws "`copy` function must be implemented for" copy(multi_ex) - @test_throws "`hash` function must be implemented for" hash(multi_ex, UInt(0)) - @test_throws "`==` function must be implemented for" multi_ex == multi_ex @test_throws "`get_constants` function must be implemented for" get_constants( multi_ex ) @@ -65,6 +70,12 @@ end tree_factory(f::F, trees) where {F} = f(; trees...) + function DE.get_contents(ex::MultiScalarExpression) + return ex.trees + end + function DE.get_metadata(ex::MultiScalarExpression) + return ex.metadata + end function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N} fused_expression = parse_expression( tree_factory(ex.metadata.tree_factory, ex.trees)::Expr;