diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index 973ded4818..924f50dfa7 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -28,7 +28,6 @@ ) from xdsl.transforms.experimental.dmp import stencil_global_to_local from xdsl.utils.exceptions import ParseError -from xdsl.utils.parse_pipeline import parse_pipeline @pytest.mark.asyncio @@ -324,12 +323,9 @@ async def test_rewrites(): addi_pass = AvailablePass( display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight", - module_pass=individual_rewrite.ApplyIndividualRewritePass, - pass_spec=list( - parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' - ) - )[0], + module_pass=individual_rewrite.ApplyIndividualRewritePass( + 3, "arith.addi", "SignlessIntegerBinaryOperationZeroOrUnitRight" + ), ) await pilot.pause() diff --git a/tests/interactive/test_get_all_available_passes.py b/tests/interactive/test_get_all_available_passes.py index 843ddfe87a..2f8c6c6b0d 100644 --- a/tests/interactive/test_get_all_available_passes.py +++ b/tests/interactive/test_get_all_available_passes.py @@ -14,7 +14,6 @@ op_type_rewrite_pattern, ) from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass -from xdsl.utils.parse_pipeline import PipelinePassSpec @dataclass @@ -82,25 +81,15 @@ def test_get_all_available_passes(): ( AvailablePass( display_name="bc", - module_pass=BCPass, - pass_spec=None, + module_pass=BCPass(), ), AvailablePass( display_name="bd", - module_pass=BDPass, - pass_spec=None, + module_pass=BDPass(), ), AvailablePass( display_name='TestOp("test.op"() {key = "b"} : () -> ()):test.op:be', - module_pass=ApplyIndividualRewritePass, - pass_spec=PipelinePassSpec( - "apply-individual-rewrite", - { - "matched_operation_index": (1,), - "operation_name": ("test.op",), - "pattern_name": ("be",), - }, - ), + module_pass=ApplyIndividualRewritePass(1, "test.op", "be"), ), ) ) diff --git a/tests/interactive/test_rewrites.py b/tests/interactive/test_rewrites.py index 79fab78a63..a661272341 100644 --- a/tests/interactive/test_rewrites.py +++ b/tests/interactive/test_rewrites.py @@ -13,7 +13,6 @@ op_type_rewrite_pattern, ) from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass -from xdsl.utils.parse_pipeline import PipelinePassSpec class Rewrite(RewritePattern): @@ -44,27 +43,11 @@ def test_get_all_possible_rewrite(): expected_res = [ AvailablePass( display_name='TestOp("test.op"() {label = "a"} : () -> ()):test.op:TestRewrite', - module_pass=ApplyIndividualRewritePass, - pass_spec=PipelinePassSpec( - "apply-individual-rewrite", - { - "matched_operation_index": (1,), - "operation_name": ("test.op",), - "pattern_name": ("TestRewrite",), - }, - ), + module_pass=ApplyIndividualRewritePass(1, "test.op", "TestRewrite"), ), AvailablePass( display_name='TestOp("test.op"() {label = "a"} : () -> ()):test.op:TestRewrite', - module_pass=ApplyIndividualRewritePass, - pass_spec=PipelinePassSpec( - "apply-individual-rewrite", - { - "matched_operation_index": (2,), - "operation_name": ("test.op",), - "pattern_name": ("TestRewrite",), - }, - ), + module_pass=ApplyIndividualRewritePass(2, "test.op", "TestRewrite"), ), ] diff --git a/xdsl/interactive/app.py b/xdsl/interactive/app.py index 239d0057e5..68f602ba11 100644 --- a/xdsl/interactive/app.py +++ b/xdsl/interactive/app.py @@ -256,7 +256,7 @@ def compute_available_pass_list(self) -> tuple[AvailablePass, ...]: """ match self.current_module: case None: - return tuple(AvailablePass(p.name, p, None) for _, p in self.all_passes) + return tuple(AvailablePass(p.name, p) for _, p in self.all_passes) case Exception(): return () case ModuleOp(): @@ -347,10 +347,12 @@ def expand_node( # remove potential children nodes in case expand node has been clicked multiple times on the same node expanded_pass.remove_children() - for pass_name, value, value_spec in child_pass_list: + for pass_name, value in child_pass_list: expanded_pass.add( label=pass_name, - data=(value, value_spec), + data=(type(value), value.pipeline_pass_spec()) + if isinstance(value, ModulePass) + else (value, None), ) def update_root_of_passes_tree(self) -> None: diff --git a/xdsl/interactive/get_all_available_passes.py b/xdsl/interactive/get_all_available_passes.py index f9b1cffb73..c235ffbdcf 100644 --- a/xdsl/interactive/get_all_available_passes.py +++ b/xdsl/interactive/get_all_available_passes.py @@ -39,5 +39,5 @@ def get_available_pass_list( if condense_mode: pass_list = get_condensed_pass_list(current_module, all_passes) else: - pass_list = tuple(AvailablePass(p.name, p, None) for _, p in all_passes) + pass_list = tuple(AvailablePass(p.name, p) for _, p in all_passes) return pass_list + tuple(individual_rewrites) diff --git a/xdsl/interactive/passes.py b/xdsl/interactive/passes.py index 37384e368a..07b56c251f 100644 --- a/xdsl/interactive/passes.py +++ b/xdsl/interactive/passes.py @@ -6,7 +6,6 @@ from xdsl.ir import Dialect from xdsl.passes import ModulePass, PipelinePass from xdsl.transforms.mlir_opt import MLIROptPass -from xdsl.utils.parse_pipeline import PipelinePassSpec class AvailablePass(NamedTuple): @@ -16,8 +15,7 @@ class AvailablePass(NamedTuple): """ display_name: str - module_pass: type[ModulePass] - pass_spec: PipelinePassSpec | None + module_pass: type[ModulePass] | ModulePass def get_new_registered_context( @@ -55,19 +53,21 @@ def iter_condensed_passes( for dialect_name, dialect_factory in get_all_dialects().items(): ctx.register_dialect(dialect_name, dialect_factory) - for _, value in all_passes: - if value is MLIROptPass: + for _, pass_type in all_passes: + if pass_type is MLIROptPass: # Always keep MLIROptPass as an option in condensed list - yield AvailablePass(value.name, value, None), None + yield AvailablePass(pass_type.name, pass_type), None + continue + cloned_module = input.clone() + cloned_ctx = ctx.clone() try: - cloned_module = input.clone() - cloned_ctx = ctx.clone() - value().apply(cloned_ctx, cloned_module) + pass_instance = pass_type() + pass_instance.apply(cloned_ctx, cloned_module) if input.is_structurally_equivalent(cloned_module): continue - yield AvailablePass(value.name, value, None), cloned_module except Exception: - pass + continue + yield AvailablePass(pass_type.name, pass_instance), cloned_module def get_condensed_pass_list( diff --git a/xdsl/interactive/rewrites.py b/xdsl/interactive/rewrites.py index 50ca4ed908..ce9a96d09d 100644 --- a/xdsl/interactive/rewrites.py +++ b/xdsl/interactive/rewrites.py @@ -33,14 +33,12 @@ def get_all_possible_rewrites( rewriter = PatternRewriter(cloned_op) pattern.match_and_rewrite(cloned_op, rewriter) if rewriter.has_done_action: - p = individual_rewrite.ApplyIndividualRewritePass( - op_idx, cloned_op.name, pattern_name - ) res.append( AvailablePass( f"{cloned_op}:{cloned_op.name}:{pattern_name}", - individual_rewrite.ApplyIndividualRewritePass, - p.pipeline_pass_spec(), + individual_rewrite.ApplyIndividualRewritePass( + op_idx, cloned_op.name, pattern_name + ), ) )