Skip to content

Commit

Permalink
interactive: store passes in app instead of class+spec tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Jan 8, 2025
1 parent 460bd6e commit 1038494
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 55 deletions.
26 changes: 6 additions & 20 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from xdsl.transforms.experimental.dmp import stencil_global_to_local
from xdsl.utils.exceptions import ParseError
from xdsl.utils.parse_pipeline import PipelinePassSpec, parse_pipeline
from xdsl.utils.parse_pipeline import parse_pipeline


@pytest.mark.asyncio
Expand Down Expand Up @@ -167,18 +167,12 @@ async def test_buttons():
# Select two passes
app.pass_pipeline = (
*app.pass_pipeline,
(
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
PipelinePassSpec(name="convert-func-to-riscv-func", args={}),
),
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(),
)

app.pass_pipeline = (
*app.pass_pipeline,
(
convert_arith_to_riscv.ConvertArithToRiscvPass,
PipelinePassSpec(name="convert-arith-to-riscv", args={}),
),
convert_arith_to_riscv.ConvertArithToRiscvPass(),
)

# assert that pass selection affected Output Text Area
Expand Down Expand Up @@ -350,13 +344,8 @@ async def test_rewrites():
# Select a rewrite
app.pass_pipeline = (
*app.pass_pipeline,
(
individual_rewrite.ApplyIndividualRewritePass,
list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
individual_rewrite.ApplyIndividualRewritePass(
3, "arith.addi", "SignlessIntegerBinaryOperationZeroOrUnitRight"
),
)

Expand Down Expand Up @@ -414,10 +403,7 @@ async def test_passes():
# Select a pass
app.pass_pipeline = (
*app.pass_pipeline,
(
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
PipelinePassSpec(name="convert-func-to-riscv-func", args={}),
),
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(),
)
# assert that the Output Text Area has changed accordingly
await pilot.pause()
Expand Down
7 changes: 1 addition & 6 deletions tests/interactive/test_get_all_available_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,7 @@ def test_get_all_available_passes():
tuple((p.name, p) for p in (ABPass, ACPass, BCPass, BDPass)),
'"test.op"() {key="a"} : () -> ()',
# Transforms the above op from "a" to "b" before testing passes
(
(
ABPass,
PipelinePassSpec(name="ab", args={}),
),
),
(ABPass(),),
condense_mode=True,
rewrite_by_names_dict={
"test.op": {
Expand Down
46 changes: 27 additions & 19 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class InputApp(App[None]):
Reactive variable used to save the current state of the modified Input TextArea
(i.e. is the Output TextArea).
"""
pass_pipeline = reactive(tuple[tuple[type[ModulePass], PipelinePassSpec], ...])
pass_pipeline = reactive(tuple[ModulePass, ...])
"""Reactive variable that saves the list of selected passes."""

condense_mode = reactive(False, always_update=True)
Expand Down Expand Up @@ -144,15 +144,15 @@ class InputApp(App[None]):

pre_loaded_input_text: str
current_file_path: str
pre_loaded_pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...]
pre_loaded_pass_pipeline: tuple[ModulePass, ...]

def __init__(
self,
all_dialects: tuple[tuple[str, Callable[[], Dialect]], ...],
all_passes: tuple[tuple[str, type[ModulePass]], ...],
file_path: str | None = None,
input_text: str | None = None,
pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...] = (),
pass_pipeline: tuple[ModulePass, ...] = (),
):
self.all_dialects = all_dialects
self.all_passes = all_passes
Expand Down Expand Up @@ -324,7 +324,9 @@ def update_selected_passes_list_view(self) -> None:

# last element is the node of the tree
pass_pipeline = self.pass_pipeline[:-1]
for pass_value, value_spec in pass_pipeline:
for p in pass_pipeline:
pass_value = type(p)
value_spec = p.pipeline_pass_spec()
self.selected_passes_list_view.append(
PassListItem(
Label(str(value_spec)),
Expand Down Expand Up @@ -358,13 +360,13 @@ def update_root_of_passes_tree(self) -> None:
updates the subtree of the root.
"""
# reset rootnode of tree
if self.pass_pipeline == ():
if not self.pass_pipeline:
self.passes_tree.reset(".")
else:
value, value_spec = self.pass_pipeline[-1]
p = self.pass_pipeline[-1]
self.passes_tree.reset(
label=str(value_spec),
data=(value, value_spec),
label=str(p.pipeline_pass_spec()),
data=(type(p), p.pipeline_pass_spec()),
)
# expand the node
self.expand_node(self.passes_tree.root, self.available_pass_list)
Expand Down Expand Up @@ -409,8 +411,8 @@ def add_pass_with_arguments_to_pass_pipeline(
else:
self.pass_pipeline = (
*self.pass_pipeline,
*root_to_child_pass_list,
(selected_pass_value, new_pass_with_arguments),
*tuple(p.from_pass_spec(s) for p, s in root_to_child_pass_list),
selected_pass_value.from_pass_spec(new_pass_with_arguments),
)
return

Expand Down Expand Up @@ -463,8 +465,8 @@ def update_pass_pipeline(
# selected_pass_value is an "individual_rewrite", add the selected pass to pass_pipeline
self.pass_pipeline = (
*self.pass_pipeline,
*root_to_child_pass_list,
(selected_pass_value, selected_pass_spec),
*tuple(p.from_pass_spec(s) for p, s in root_to_child_pass_list),
selected_pass_value.from_pass_spec(selected_pass_spec),
)

@on(Tree.NodeExpanded, "#passes_tree")
Expand Down Expand Up @@ -492,12 +494,15 @@ def expand_tree_node(
selected_pass_spec = selected_pass_value().pipeline_pass_spec()

# if selected_pass_value requires no arguments add the selected pass to pass_pipeline
root_to_child_pass_list = self.get_root_to_child_pass_list(expanded_node)
root_to_child_pass_list = tuple(
p.from_pass_spec(s)
for p, s in self.get_root_to_child_pass_list(expanded_node)
)

child_pass_pipeline = (
*self.pass_pipeline,
*root_to_child_pass_list,
(selected_pass_value, selected_pass_spec),
selected_pass_value.from_pass_spec(selected_pass_spec),
)

child_pass_list = get_available_pass_list(
Expand Down Expand Up @@ -575,9 +580,7 @@ def get_query_string(self) -> str:

if self.pass_pipeline:
query += "'"
query += ",".join(
str(pipeline_pass_spec) for _, pipeline_pass_spec in self.pass_pipeline
)
query += ",".join(str(p.pipeline_pass_spec()) for p in self.pass_pipeline)
query += "'"
return f"xdsl-opt {query}"

Expand Down Expand Up @@ -659,7 +662,7 @@ def copy_query(self, event: Button.Pressed) -> None:
@on(Button.Pressed, "#clear_passes_button")
def clear_passes(self, event: Button.Pressed) -> None:
"""Selected passes cleared when "Clear Passes" button is pressed."""
self.pass_pipeline = tuple[tuple[type[ModulePass], PipelinePassSpec], ...]()
self.pass_pipeline = ()

@on(Button.Pressed, "#condense_button")
def condense(self, event: Button.Pressed) -> None:
Expand Down Expand Up @@ -756,7 +759,12 @@ def main():

pass_spec_pipeline = list(parse_pipeline(args.passes))
pass_list = get_all_passes()
pipeline = tuple(PipelinePass.build_pipeline_tuples(pass_list, pass_spec_pipeline))
pipeline = tuple(
pass_type.from_pass_spec(spec)
for pass_type, spec in PipelinePass.build_pipeline_tuples(
pass_list, pass_spec_pipeline
)
)

return InputApp(
tuple(get_all_dialects().items()),
Expand Down
3 changes: 1 addition & 2 deletions xdsl/interactive/get_all_available_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
from xdsl.parser import Parser
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import RewritePattern
from xdsl.utils.parse_pipeline import PipelinePassSpec


def get_available_pass_list(
all_dialects: tuple[tuple[str, Callable[[], Dialect]], ...],
all_passes: tuple[tuple[str, type[ModulePass]], ...],
input_text: str,
pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...],
pass_pipeline: tuple[ModulePass, ...],
condense_mode: bool,
rewrite_by_names_dict: dict[str, dict[str, RewritePattern]],
) -> tuple[AvailablePass, ...]:
Expand Down
12 changes: 4 additions & 8 deletions xdsl/interactive/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,13 @@ def get_new_registered_context(
def apply_passes_to_module(
module: builtin.ModuleOp,
ctx: MLContext,
pass_pipeline: tuple[tuple[type[ModulePass], PipelinePassSpec], ...],
passes: tuple[ModulePass, ...],
) -> builtin.ModuleOp:
"""
Function that takes a ModuleOp, an MLContext and a pass_pipeline (consisting of a type[ModulePass] and PipelinePassSpec), applies the pass(es) to the ModuleOp and returns the new ModuleOp.
Function that takes a ModuleOp, an MLContext and a pass_pipeline, applies the
passes to the ModuleOp and returns the modified ModuleOp.
"""
pipeline = PipelinePass(
passes=tuple(
module_pass.from_pass_spec(pipeline_pass_spec)
for module_pass, pipeline_pass_spec in pass_pipeline
)
)
pipeline = PipelinePass(passes=passes)
pipeline.apply(ctx, module)
return module

Expand Down

0 comments on commit 1038494

Please sign in to comment.