diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index 3e116aed3a..973ded4818 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -28,7 +28,7 @@ ) from xdsl.transforms.experimental.dmp import stencil_global_to_local from xdsl.utils.exceptions import ParseError -from xdsl.utils.parse_pipeline import PipelinePassSpec, parse_pipeline +from xdsl.utils.parse_pipeline import parse_pipeline @pytest.mark.asyncio @@ -167,18 +167,12 @@ async def test_buttons(): # Select two passes app.pass_pipeline = ( *app.pass_pipeline, - ( - convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass, - PipelinePassSpec(name="convert-func-to-riscv-func", args={}), - ), + convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(), ) app.pass_pipeline = ( *app.pass_pipeline, - ( - convert_arith_to_riscv.ConvertArithToRiscvPass, - PipelinePassSpec(name="convert-arith-to-riscv", args={}), - ), + convert_arith_to_riscv.ConvertArithToRiscvPass(), ) # assert that pass selection affected Output Text Area @@ -350,13 +344,8 @@ async def test_rewrites(): # Select a rewrite app.pass_pipeline = ( *app.pass_pipeline, - ( - individual_rewrite.ApplyIndividualRewritePass, - list( - parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' - ) - )[0], + individual_rewrite.ApplyIndividualRewritePass( + 3, "arith.addi", "SignlessIntegerBinaryOperationZeroOrUnitRight" ), ) @@ -414,10 +403,7 @@ async def test_passes(): # Select a pass app.pass_pipeline = ( *app.pass_pipeline, - ( - convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass, - PipelinePassSpec(name="convert-func-to-riscv-func", args={}), - ), + convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(), ) # assert that the Output Text Area has changed accordingly await pilot.pause() diff --git a/tests/interactive/test_get_all_available_passes.py b/tests/interactive/test_get_all_available_passes.py index 66a2f83792..7949888966 100644 --- a/tests/interactive/test_get_all_available_passes.py +++ b/tests/interactive/test_get_all_available_passes.py @@ -68,12 +68,7 @@ def test_get_all_available_passes(): tuple((p.name, p) for p in (ABPass, ACPass, BCPass, BDPass)), '"test.op"() {key="a"} : () -> ()', # Transforms the above op from "a" to "b" before testing passes - ( - ( - ABPass, - PipelinePassSpec(name="ab", args={}), - ), - ), + (ABPass(),), condense_mode=True, rewrite_by_names_dict={ "test.op": { diff --git a/xdsl/interactive/app.py b/xdsl/interactive/app.py index 1fc7aa58c9..239d0057e5 100644 --- a/xdsl/interactive/app.py +++ b/xdsl/interactive/app.py @@ -103,7 +103,7 @@ class InputApp(App[None]): Reactive variable used to save the current state of the modified Input TextArea (i.e. is the Output TextArea). """ - pass_pipeline = reactive(tuple[tuple[type[ModulePass], PipelinePassSpec], ...]) + pass_pipeline = reactive(tuple[ModulePass, ...]) """Reactive variable that saves the list of selected passes.""" condense_mode = reactive(False, always_update=True) @@ -144,7 +144,7 @@ class InputApp(App[None]): pre_loaded_input_text: str current_file_path: str - pre_loaded_pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...] + pre_loaded_pass_pipeline: tuple[ModulePass, ...] def __init__( self, @@ -152,7 +152,7 @@ def __init__( all_passes: tuple[tuple[str, type[ModulePass]], ...], file_path: str | None = None, input_text: str | None = None, - pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...] = (), + pass_pipeline: tuple[ModulePass, ...] = (), ): self.all_dialects = all_dialects self.all_passes = all_passes @@ -324,7 +324,9 @@ def update_selected_passes_list_view(self) -> None: # last element is the node of the tree pass_pipeline = self.pass_pipeline[:-1] - for pass_value, value_spec in pass_pipeline: + for p in pass_pipeline: + pass_value = type(p) + value_spec = p.pipeline_pass_spec() self.selected_passes_list_view.append( PassListItem( Label(str(value_spec)), @@ -358,13 +360,13 @@ def update_root_of_passes_tree(self) -> None: updates the subtree of the root. """ # reset rootnode of tree - if self.pass_pipeline == (): + if not self.pass_pipeline: self.passes_tree.reset(".") else: - value, value_spec = self.pass_pipeline[-1] + p = self.pass_pipeline[-1] self.passes_tree.reset( - label=str(value_spec), - data=(value, value_spec), + label=str(p.pipeline_pass_spec()), + data=(type(p), p.pipeline_pass_spec()), ) # expand the node self.expand_node(self.passes_tree.root, self.available_pass_list) @@ -409,8 +411,8 @@ def add_pass_with_arguments_to_pass_pipeline( else: self.pass_pipeline = ( *self.pass_pipeline, - *root_to_child_pass_list, - (selected_pass_value, new_pass_with_arguments), + *tuple(p.from_pass_spec(s) for p, s in root_to_child_pass_list), + selected_pass_value.from_pass_spec(new_pass_with_arguments), ) return @@ -463,8 +465,8 @@ def update_pass_pipeline( # selected_pass_value is an "individual_rewrite", add the selected pass to pass_pipeline self.pass_pipeline = ( *self.pass_pipeline, - *root_to_child_pass_list, - (selected_pass_value, selected_pass_spec), + *tuple(p.from_pass_spec(s) for p, s in root_to_child_pass_list), + selected_pass_value.from_pass_spec(selected_pass_spec), ) @on(Tree.NodeExpanded, "#passes_tree") @@ -492,12 +494,15 @@ def expand_tree_node( selected_pass_spec = selected_pass_value().pipeline_pass_spec() # if selected_pass_value requires no arguments add the selected pass to pass_pipeline - root_to_child_pass_list = self.get_root_to_child_pass_list(expanded_node) + root_to_child_pass_list = tuple( + p.from_pass_spec(s) + for p, s in self.get_root_to_child_pass_list(expanded_node) + ) child_pass_pipeline = ( *self.pass_pipeline, *root_to_child_pass_list, - (selected_pass_value, selected_pass_spec), + selected_pass_value.from_pass_spec(selected_pass_spec), ) child_pass_list = get_available_pass_list( @@ -575,9 +580,7 @@ def get_query_string(self) -> str: if self.pass_pipeline: query += "'" - query += ",".join( - str(pipeline_pass_spec) for _, pipeline_pass_spec in self.pass_pipeline - ) + query += ",".join(str(p.pipeline_pass_spec()) for p in self.pass_pipeline) query += "'" return f"xdsl-opt {query}" @@ -659,7 +662,7 @@ def copy_query(self, event: Button.Pressed) -> None: @on(Button.Pressed, "#clear_passes_button") def clear_passes(self, event: Button.Pressed) -> None: """Selected passes cleared when "Clear Passes" button is pressed.""" - self.pass_pipeline = tuple[tuple[type[ModulePass], PipelinePassSpec], ...]() + self.pass_pipeline = () @on(Button.Pressed, "#condense_button") def condense(self, event: Button.Pressed) -> None: @@ -756,7 +759,12 @@ def main(): pass_spec_pipeline = list(parse_pipeline(args.passes)) pass_list = get_all_passes() - pipeline = tuple(PipelinePass.build_pipeline_tuples(pass_list, pass_spec_pipeline)) + pipeline = tuple( + pass_type.from_pass_spec(spec) + for pass_type, spec in PipelinePass.build_pipeline_tuples( + pass_list, pass_spec_pipeline + ) + ) return InputApp( tuple(get_all_dialects().items()), diff --git a/xdsl/interactive/get_all_available_passes.py b/xdsl/interactive/get_all_available_passes.py index 3547df2a7d..f9b1cffb73 100644 --- a/xdsl/interactive/get_all_available_passes.py +++ b/xdsl/interactive/get_all_available_passes.py @@ -11,14 +11,13 @@ from xdsl.parser import Parser from xdsl.passes import ModulePass from xdsl.pattern_rewriter import RewritePattern -from xdsl.utils.parse_pipeline import PipelinePassSpec def get_available_pass_list( all_dialects: tuple[tuple[str, Callable[[], Dialect]], ...], all_passes: tuple[tuple[str, type[ModulePass]], ...], input_text: str, - pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...], + pass_pipeline: tuple[ModulePass, ...], condense_mode: bool, rewrite_by_names_dict: dict[str, dict[str, RewritePattern]], ) -> tuple[AvailablePass, ...]: diff --git a/xdsl/interactive/passes.py b/xdsl/interactive/passes.py index 3e3387d540..37384e368a 100644 --- a/xdsl/interactive/passes.py +++ b/xdsl/interactive/passes.py @@ -35,17 +35,13 @@ def get_new_registered_context( def apply_passes_to_module( module: builtin.ModuleOp, ctx: MLContext, - pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...], + passes: tuple[ModulePass, ...], ) -> builtin.ModuleOp: """ - Function that takes a ModuleOp, an MLContext and a pass_pipeline (consisting of a type[ModulePass] and PipelinePassSpec), applies the pass(es) to the ModuleOp and returns the new ModuleOp. + Function that takes a ModuleOp, an MLContext and a pass_pipeline, applies the + passes to the ModuleOp and returns the modified ModuleOp. """ - pipeline = PipelinePass( - passes=tuple( - module_pass.from_pass_spec(pipeline_pass_spec) - for module_pass, pipeline_pass_spec in pass_pipeline - ) - ) + pipeline = PipelinePass(passes=passes) pipeline.apply(ctx, module) return module