From 1d86cc207a3683d5d086a345b05e503d4a447da2 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 24 Jun 2024 23:22:54 +0100 Subject: [PATCH 1/3] fix: generalize combine_operators --- src/ParametricExpression.jl | 1 - src/PatchMethods.jl | 9 ++++----- src/Simplify.jl | 3 ++- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 09ce8aa3..383af13d 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -15,7 +15,6 @@ import ..NodeUtilsModule: get_constants, set_constants! import ..StringsModule: string_tree -import ..SimplifyModule: combine_operators, simplify_tree! import ..EvaluateModule: eval_tree_array import ..EvaluateDerivativeModule: eval_grad_tree_array import ..EvaluationHelpersModule: _grad_evaluator diff --git a/src/PatchMethods.jl b/src/PatchMethods.jl index 62023fa2..525a8cc2 100644 --- a/src/PatchMethods.jl +++ b/src/PatchMethods.jl @@ -1,5 +1,6 @@ module PatchMethodsModule +using DynamicExpressions: get_contents, with_contents using ..OperatorEnumModule: AbstractOperatorEnum using ..NodeModule: constructorof using ..ExpressionModule: Expression, get_tree, get_operators @@ -11,17 +12,15 @@ function combine_operators( ex::Union{Expression{T,N},ParametricExpression{T,N}}, operators::Union{AbstractOperatorEnum,Nothing}=nothing, ) where {T,N} - return constructorof(typeof(ex))( - combine_operators(get_tree(ex)::N, get_operators(ex, operators)), ex.metadata + return with_contents( + ex, combine_operators(get_contents(ex), get_operators(ex, operators)) ) end function simplify_tree!( ex::Union{Expression{T,N},ParametricExpression{T,N}}, operators::Union{AbstractOperatorEnum,Nothing}=nothing, ) where {T,N} - return constructorof(typeof(ex))( - simplify_tree!(get_tree(ex)::N, get_operators(ex, operators)), ex.metadata - ) + return with_contents(ex, simplify_tree!(get_contents(ex), get_operators(ex, operators))) end end diff --git a/src/Simplify.jl b/src/Simplify.jl index 64fc174d..1fff2dbd 100644 --- a/src/Simplify.jl +++ b/src/Simplify.jl @@ -15,7 +15,8 @@ is_commutative(_) = false is_subtraction(::typeof(-)) = true is_subtraction(_) = false -# This is only defined for `Node` as it is not possible for +combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree +# This is only defined for `Node` as it is not possible for, e.g., # `GraphNode`. function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T} # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before. From bf6c4032c224d2a4e8a9f672819d05c185d53137 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 24 Jun 2024 23:27:43 +0100 Subject: [PATCH 2/3] feat: add `is_node_constant` to interface --- src/Interfaces.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 750aead6..5ab14d7f 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -26,6 +26,7 @@ using ..NodeModule: filter_map! using ..NodeUtilsModule: NodeIndex, + is_node_constant, count_constants, count_depth, index_constants, @@ -273,6 +274,9 @@ end function _check_count_depth(tree::AbstractExpressionNode) return count_depth(tree) isa Int64 end +function _check_is_node_constant(tree::AbstractExpressionNode) + return is_node_constant(tree) isa Bool +end function _check_count_constants(tree::AbstractExpressionNode) return count_constants(tree) isa Int64 end @@ -324,6 +328,7 @@ ni_components = ( branch_hash = "computes the hash of a branch node" => _check_branch_hash, branch_equal = "checks equality of two branch nodes" => _check_branch_equal, count_depth = "calculates the depth of the tree" => _check_count_depth, + is_node_constant = "checks if the node is a constant" => _check_is_node_constant, count_constants = "counts the number of constants" => _check_count_constants, filter_map = "applies a filter and map function to the tree" => _check_filter_map, has_constants = "checks if the tree has constants" => _check_has_constants, From 135b9db23cf6e0a4ef36e50087a93abc96f12a14 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 24 Jun 2024 23:29:36 +0100 Subject: [PATCH 3/3] chore: bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a15ec484..20faf9f0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "0.18.1" +version = "0.18.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"