diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..f5ba090ac 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,13 +6,18 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # function matcher(val::Any) - iscall(val) && return term_matcher(val) + matcher(val, false) +end + +# `fullac_flag == true` enables fully nested associative-commutative pattern matching +function matcher(val::Any, fullac_flag) + iscall(val) && return term_matcher(val, fullac_flag) function literal_matcher(next, data, bindings) islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing end end -function matcher(slot::Slot) +function matcher(slot::Slot, fullac_flag) # fullac_flag unused but needed to keep the interface uniform function slot_matcher(next, data, bindings) !islist(data) && return val = get(bindings, slot.name, nothing) @@ -56,7 +61,7 @@ function trymatchexpr(data, value, n) end end -function matcher(segment::Segment) +function matcher(segment::Segment, fullac_flag) # fullac_flag unused but needed to keep the interface uniform function segment_matcher(success, data, bindings) val = get(bindings, segment.name, nothing) @@ -84,8 +89,8 @@ function matcher(segment::Segment) end end -function term_matcher(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) +function term_matcher(term, fullac_flag = false) + matchers = (matcher(operation(term), fullac_flag), map(a -> matcher(a, fullac_flag), arguments(term))...,) function term_matcher(success, data, bindings) !islist(data) && return nothing @@ -103,6 +108,23 @@ function term_matcher(term) end end - loop(car(data), bindings, matchers) # Try to eat exactly one term + if !(fullac_flag && iscall(term) && operation(term) in ((+), (*))) + loop(car(data), bindings, matchers) # Try to eat exactly one term + else # try all permutations of `car(data)` to see if a match is possible + data1 = car(data) + args = arguments(data1) + op = operation(data1) + data_arg_perms = permutations(args) + result = nothing + T = symtype(data) + for perm in data_arg_perms + data_permuted = Term{T}(op, perm) + result = loop(data_permuted, bindings, matchers) # Try to eat exactly one term + if !(result isa Nothing) + break + end + end + return result + end end end diff --git a/src/rule.jl b/src/rule.jl index 5de0aa79c..6ab49eb17 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -297,9 +297,30 @@ whether the predicate holds or not. _In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side of an expression. + +**Full (nested) associative-commutative matching**: + + @rule LHS => RHS fullac + +creates a rule that fully respects associative-commutative (AC) operations. Unlike `@acrule LHS => RHS` which only considers AC properties of the top-level function, here we impose AC properties on all subexpressions. + +``` +julia> @syms a b; + +julia> r = @rule ~a + ~a*~b => ~a * (1+~b) fullac; + +julia> r(b + a*b) +(1 + a)*b + +``` """ -macro rule(expr) +macro rule(expr, option...) @assert expr.head == :call && expr.args[1] == :(=>) + fullac = false + if length(option) > 0 + @assert option[1] == :fullac "@rule only accepts one option `fullac` after the rule itself" + fullac = true + end lhs = expr.args[2] rhs = rewrite_rhs(expr.args[3]) keys = Symbol[] @@ -310,7 +331,7 @@ macro rule(expr) lhs_pattern = $(lhs_term) Rule($(QuoteNode(expr)), lhs_pattern, - matcher(lhs_pattern), + matcher(lhs_pattern, $fullac), __MATCHES__ -> $(makeconsequent(rhs)), rule_depth($lhs_term)) end diff --git a/test/rewrite.jl b/test/rewrite.jl index 3bb2621e3..8ae28276a 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -43,6 +43,14 @@ end @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) end +@testset "Full associative-commutative matching" begin + @eqtest (@rule ~a + ~a*~b => ~a * (1+~b) fullac)(a + a*b) == a * (1+b) + @eqtest (@rule ~a + ~a*~b => ~a * (1+~b) fullac)(b + a*b) == b * (1+a) # fails with @acrule + @eqtest (@rule ~a*~b + ~a => ~a * (1+~b) fullac)(b + a*b) == b * (1+a) # fails with @acrule + @eqtest (@rule ~a*~b + ~a*~c => ~a * (~b+~c) fullac)(a*b + a*c) == a * (b+c) + @eqtest (@rule ~a*~b + ~a*~c => ~a * (~b+~c) fullac)(a*b + b*c) == b * (a+c) # fails with @acrule +end + using SymbolicUtils: @capture @testset "Capture form" begin