diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 5937ddad..a4f0e5f3 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -361,6 +361,7 @@ function Base.union!( ) g.classes[id_1] = new_eclass + modify!(g, new_eclass) return true end @@ -425,21 +426,23 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp eclass = g.classes[eclass_id_key] node_data = make(g, node) - if !isnothing(eclass.data) - joined_data = join(eclass.data, node_data) - - if joined_data != eclass.data - g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data) - # eclass.data = joined_data + if !isnothing(node_data) + if !isnothing(eclass.data) + joined_data = join(eclass.data, node_data) + + if joined_data != eclass.data + g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data) + # eclass.data = joined_data + modify!(g, eclass) + append!(g.analysis_pending, eclass.parents) + end + else + g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, node_data) + # eclass.data = node_data modify!(g, eclass) append!(g.analysis_pending, eclass.parents) end - else - g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, node_data) - # eclass.data = node_data - modify!(g, eclass) end - end end n_unions diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 41b2702f..fcc3555f 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -86,9 +86,8 @@ function eqsat_search!( @debug "$rule is banned" continue end - ids_left = cached_ids(g, rule.left) - ids_right = is_bidirectional(rule) ? cached_ids(g, rule.right) : UNDEF_ID_VEC + ids_left = cached_ids(g, rule.left) for i in ids_left cansearch(scheduler, rule_idx, i) || continue n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer) @@ -96,6 +95,7 @@ function eqsat_search!( end if is_bidirectional(rule) + ids_right = cached_ids(g, rule.right) for i in ids_right cansearch(scheduler, rule_idx, i) || continue n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer) diff --git a/src/utils.jl b/src/utils.jl index 2e03b08d..377132e5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -36,7 +36,7 @@ function buffer_readable(g, limit, ematch_buffer) k = length(ematch_buffer) while k > limit - delimiter = ematch_buffer[k] + delimiter = ematch_buffer.v[k] @assert delimiter == 0xffffffffffffffffffffffffffffffff n = k - 1 @@ -44,19 +44,19 @@ function buffer_readable(g, limit, ematch_buffer) n_elems = 0 for i in n:-1:1 n_elems += 1 - if ematch_buffer[i] == 0xffffffffffffffffffffffffffffffff + if ematch_buffer.v[i] == 0xffffffffffffffffffffffffffffffff n_elems -= 1 next_delimiter_idx = i break end end - match_info = ematch_buffer[next_delimiter_idx + 1] + match_info = ematch_buffer.v[next_delimiter_idx + 1] id = v_pair_first(match_info) rule_idx = reinterpret(Int, v_pair_last(match_info)) rule_idx = abs(rule_idx) - bindings = @view ematch_buffer[(next_delimiter_idx + 2):n] + bindings = @view ematch_buffer.v[(next_delimiter_idx + 2):n] print("$id E-Classes: ", map(x -> reinterpret(Int, v_pair_first(x)), bindings)) print(" Nodes: ", map(x -> reinterpret(Int, v_pair_last(x)), bindings), "\n") diff --git a/test/egraphs/ematch.jl b/test/egraphs/ematch.jl index d2c5fb6d..414c7439 100644 --- a/test/egraphs/ematch.jl +++ b/test/egraphs/ematch.jl @@ -289,6 +289,22 @@ end @test test_equality(some_theory, :(a * b * 0), 0) end +@testset "Dynamic rule predicates in EMatcher" begin + g = EGraph(:(2 * 3)) + zero_id = addexpr!(g, 0) + + some_theory = @theory begin + ~a * ~b => 0 where (iszero(a) || iszero(b)) + ~a * ~b --> ~b * ~a + end + + Base.iszero(ec::EClass) = in_same_class(g, zero_id, ec.id) + + saturate!(g, some_theory) + + @test test_equality(some_theory, :(a * b * 0), 0) +end + @testset "Inequalities" begin failme = @theory p begin p != !p diff --git a/test/integration/stream_fusion.jl b/test/integration/stream_fusion.jl index 5b24d3e7..b119c069 100644 --- a/test/integration/stream_fusion.jl +++ b/test/integration/stream_fusion.jl @@ -49,7 +49,7 @@ fold_theory = @theory x y z begin x::Number * y::Number => x * y x::Number + y::Number => x + y x::Number / y::Number => x / y - x::Number - y::Number => x / y + x::Number - y::Number => x - y # etc... end