From dd58ef01d38a68f04c93a5ae480e73b104f2559c Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Thu, 5 Dec 2024 15:37:37 +0000 Subject: [PATCH] interpreter: (pdl) add an initial version of EqSat PDLMatcher (#3574) This is an initial attempt of adding a PDL matcher for EqSat e-class ops. --------- Co-authored-by: Jianyi Cheng --- .../test_eqsat_pdl_interpreter.py | 51 +++++++++++++++++++ xdsl/interpreters/eqsat_pdl.py | 22 ++++++++ 2 files changed, 73 insertions(+) create mode 100644 tests/interpreters/test_eqsat_pdl_interpreter.py create mode 100644 xdsl/interpreters/eqsat_pdl.py diff --git a/tests/interpreters/test_eqsat_pdl_interpreter.py b/tests/interpreters/test_eqsat_pdl_interpreter.py new file mode 100644 index 0000000000..9e08851d35 --- /dev/null +++ b/tests/interpreters/test_eqsat_pdl_interpreter.py @@ -0,0 +1,51 @@ +from xdsl.dialects import eqsat, pdl +from xdsl.dialects.builtin import ( + IntegerType, + StringAttr, +) +from xdsl.interpreters.eqsat_pdl import EqsatPDLMatcher + + +def test_match_type(): + matcher = EqsatPDLMatcher() + + pdl_op = pdl.TypeOp() + ssa_value = pdl_op.result + ssa_value = eqsat.EClassOp(ssa_value).result + xdsl_value = StringAttr("a") + + # New value + assert matcher.match_type(ssa_value, pdl_op, xdsl_value) + assert matcher.matching_context == {ssa_value: xdsl_value} + + # Same value + assert matcher.match_type(ssa_value, pdl_op, xdsl_value) + assert matcher.matching_context == {ssa_value: xdsl_value} + + # Other value + assert not matcher.match_type(ssa_value, pdl_op, StringAttr("b")) + assert matcher.matching_context == {ssa_value: xdsl_value} + + +def test_match_fixed_type(): + matcher = EqsatPDLMatcher() + + pdl_op = pdl.TypeOp(IntegerType(32)) + xdsl_value = IntegerType(32) + ssa_value = pdl_op.result + ssa_value = eqsat.EClassOp(ssa_value).result + + assert matcher.match_type(ssa_value, pdl_op, xdsl_value) + assert matcher.matching_context == {ssa_value: xdsl_value} + + +def test_not_match_fixed_type(): + matcher = EqsatPDLMatcher() + + pdl_op = pdl.TypeOp(IntegerType(64)) + xdsl_value = IntegerType(32) + ssa_value = pdl_op.result + ssa_value = eqsat.EClassOp(ssa_value).result + + assert not matcher.match_type(ssa_value, pdl_op, xdsl_value) + assert matcher.matching_context == {} diff --git a/xdsl/interpreters/eqsat_pdl.py b/xdsl/interpreters/eqsat_pdl.py new file mode 100644 index 0000000000..c06dda9ba0 --- /dev/null +++ b/xdsl/interpreters/eqsat_pdl.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from xdsl.dialects import eqsat, pdl +from xdsl.interpreters.pdl import PDLMatcher +from xdsl.ir import SSAValue + + +@dataclass +class EqsatPDLMatcher(PDLMatcher): + def match_operand( + self, ssa_val: SSAValue, pdl_op: pdl.OperandOp, xdsl_val: SSAValue + ): + owner = xdsl_val.owner + assert isinstance(owner, eqsat.EClassOp) + assert ( + len(owner.operands) == 1 + ), "newly converted eqsat always has 1 element in eclass" + arg = owner.operands[0] + res = super().match_operand(ssa_val, pdl_op, arg) + return res