Skip to content

Commit

Permalink
Merge pull request #243 from JuliaSymbolics/3.0_minor_fixes_and_impro…
Browse files Browse the repository at this point in the history
…vements

3.0 minor fixes and improvements
  • Loading branch information
0x0f0f0f authored Sep 29, 2024
2 parents dab995c + 3373cb6 commit 081a9e6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 18 deletions.
25 changes: 14 additions & 11 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ function Base.union!(
)

g.classes[id_1] = new_eclass
modify!(g, new_eclass)

return true
end
Expand Down Expand Up @@ -425,21 +426,23 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
eclass = g.classes[eclass_id_key]

node_data = make(g, node)
if !isnothing(eclass.data)
joined_data = join(eclass.data, node_data)

if joined_data != eclass.data
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data)
# eclass.data = joined_data
if !isnothing(node_data)
if !isnothing(eclass.data)
joined_data = join(eclass.data, node_data)

if joined_data != eclass.data
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data)
# eclass.data = joined_data
modify!(g, eclass)
append!(g.analysis_pending, eclass.parents)
end
else
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, node_data)
# eclass.data = node_data
modify!(g, eclass)
append!(g.analysis_pending, eclass.parents)
end
else
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, node_data)
# eclass.data = node_data
modify!(g, eclass)
end

end
end
n_unions
Expand Down
4 changes: 2 additions & 2 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ function eqsat_search!(
@debug "$rule is banned"
continue
end
ids_left = cached_ids(g, rule.left)
ids_right = is_bidirectional(rule) ? cached_ids(g, rule.right) : UNDEF_ID_VEC

ids_left = cached_ids(g, rule.left)
for i in ids_left
cansearch(scheduler, rule_idx, i) || continue
n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer)
inform!(scheduler, rule_idx, i, n_matches)
end

if is_bidirectional(rule)
ids_right = cached_ids(g, rule.right)
for i in ids_right
cansearch(scheduler, rule_idx, i) || continue
n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer)
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,27 @@ function buffer_readable(g, limit, ematch_buffer)
k = length(ematch_buffer)

while k > limit
delimiter = ematch_buffer[k]
delimiter = ematch_buffer.v[k]
@assert delimiter == 0xffffffffffffffffffffffffffffffff
n = k - 1

next_delimiter_idx = 0
n_elems = 0
for i in n:-1:1
n_elems += 1
if ematch_buffer[i] == 0xffffffffffffffffffffffffffffffff
if ematch_buffer.v[i] == 0xffffffffffffffffffffffffffffffff
n_elems -= 1
next_delimiter_idx = i
break
end
end

match_info = ematch_buffer[next_delimiter_idx + 1]
match_info = ematch_buffer.v[next_delimiter_idx + 1]
id = v_pair_first(match_info)
rule_idx = reinterpret(Int, v_pair_last(match_info))
rule_idx = abs(rule_idx)

bindings = @view ematch_buffer[(next_delimiter_idx + 2):n]
bindings = @view ematch_buffer.v[(next_delimiter_idx + 2):n]

print("$id E-Classes: ", map(x -> reinterpret(Int, v_pair_first(x)), bindings))
print(" Nodes: ", map(x -> reinterpret(Int, v_pair_last(x)), bindings), "\n")
Expand Down
16 changes: 16 additions & 0 deletions test/egraphs/ematch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,22 @@ end
@test test_equality(some_theory, :(a * b * 0), 0)
end

@testset "Dynamic rule predicates in EMatcher" begin
g = EGraph(:(2 * 3))
zero_id = addexpr!(g, 0)

some_theory = @theory begin
~a * ~b => 0 where (iszero(a) || iszero(b))
~a * ~b --> ~b * ~a
end

Base.iszero(ec::EClass) = in_same_class(g, zero_id, ec.id)

saturate!(g, some_theory)

@test test_equality(some_theory, :(a * b * 0), 0)
end

@testset "Inequalities" begin
failme = @theory p begin
p != !p
Expand Down
2 changes: 1 addition & 1 deletion test/integration/stream_fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fold_theory = @theory x y z begin
x::Number * y::Number => x * y
x::Number + y::Number => x + y
x::Number / y::Number => x / y
x::Number - y::Number => x / y
x::Number - y::Number => x - y
# etc...
end

Expand Down

0 comments on commit 081a9e6

Please sign in to comment.