Skip to content

Commit

Permalink
interactive: AvailablePass now has either a pass type or instance, no…
Browse files Browse the repository at this point in the history
… spec (#3725)

Same deal as #3723, we used to need the spec due to potential
mutability, now don't need it.
  • Loading branch information
superlopuh authored Jan 8, 2025
1 parent be78282 commit 0d39a26
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 60 deletions.
10 changes: 3 additions & 7 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from xdsl.transforms.experimental.dmp import stencil_global_to_local
from xdsl.utils.exceptions import ParseError
from xdsl.utils.parse_pipeline import parse_pipeline


@pytest.mark.asyncio
Expand Down Expand Up @@ -324,12 +323,9 @@ async def test_rewrites():

addi_pass = AvailablePass(
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight",
module_pass=individual_rewrite.ApplyIndividualRewritePass,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
module_pass=individual_rewrite.ApplyIndividualRewritePass(
3, "arith.addi", "SignlessIntegerBinaryOperationZeroOrUnitRight"
),
)

await pilot.pause()
Expand Down
17 changes: 3 additions & 14 deletions tests/interactive/test_get_all_available_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
op_type_rewrite_pattern,
)
from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass
from xdsl.utils.parse_pipeline import PipelinePassSpec


@dataclass
Expand Down Expand Up @@ -82,25 +81,15 @@ def test_get_all_available_passes():
(
AvailablePass(
display_name="bc",
module_pass=BCPass,
pass_spec=None,
module_pass=BCPass(),
),
AvailablePass(
display_name="bd",
module_pass=BDPass,
pass_spec=None,
module_pass=BDPass(),
),
AvailablePass(
display_name='TestOp("test.op"() {key = "b"} : () -> ()):test.op:be',
module_pass=ApplyIndividualRewritePass,
pass_spec=PipelinePassSpec(
"apply-individual-rewrite",
{
"matched_operation_index": (1,),
"operation_name": ("test.op",),
"pattern_name": ("be",),
},
),
module_pass=ApplyIndividualRewritePass(1, "test.op", "be"),
),
)
)
21 changes: 2 additions & 19 deletions tests/interactive/test_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
op_type_rewrite_pattern,
)
from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass
from xdsl.utils.parse_pipeline import PipelinePassSpec


class Rewrite(RewritePattern):
Expand Down Expand Up @@ -44,27 +43,11 @@ def test_get_all_possible_rewrite():
expected_res = [
AvailablePass(
display_name='TestOp("test.op"() {label = "a"} : () -> ()):test.op:TestRewrite',
module_pass=ApplyIndividualRewritePass,
pass_spec=PipelinePassSpec(
"apply-individual-rewrite",
{
"matched_operation_index": (1,),
"operation_name": ("test.op",),
"pattern_name": ("TestRewrite",),
},
),
module_pass=ApplyIndividualRewritePass(1, "test.op", "TestRewrite"),
),
AvailablePass(
display_name='TestOp("test.op"() {label = "a"} : () -> ()):test.op:TestRewrite',
module_pass=ApplyIndividualRewritePass,
pass_spec=PipelinePassSpec(
"apply-individual-rewrite",
{
"matched_operation_index": (2,),
"operation_name": ("test.op",),
"pattern_name": ("TestRewrite",),
},
),
module_pass=ApplyIndividualRewritePass(2, "test.op", "TestRewrite"),
),
]

Expand Down
8 changes: 5 additions & 3 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def compute_available_pass_list(self) -> tuple[AvailablePass, ...]:
"""
match self.current_module:
case None:
return tuple(AvailablePass(p.name, p, None) for _, p in self.all_passes)
return tuple(AvailablePass(p.name, p) for _, p in self.all_passes)
case Exception():
return ()
case ModuleOp():
Expand Down Expand Up @@ -347,10 +347,12 @@ def expand_node(
# remove potential children nodes in case expand node has been clicked multiple times on the same node
expanded_pass.remove_children()

for pass_name, value, value_spec in child_pass_list:
for pass_name, value in child_pass_list:
expanded_pass.add(
label=pass_name,
data=(value, value_spec),
data=(type(value), value.pipeline_pass_spec())
if isinstance(value, ModulePass)
else (value, None),
)

def update_root_of_passes_tree(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion xdsl/interactive/get_all_available_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ def get_available_pass_list(
if condense_mode:
pass_list = get_condensed_pass_list(current_module, all_passes)
else:
pass_list = tuple(AvailablePass(p.name, p, None) for _, p in all_passes)
pass_list = tuple(AvailablePass(p.name, p) for _, p in all_passes)
return pass_list + tuple(individual_rewrites)
22 changes: 11 additions & 11 deletions xdsl/interactive/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xdsl.ir import Dialect
from xdsl.passes import ModulePass, PipelinePass
from xdsl.transforms.mlir_opt import MLIROptPass
from xdsl.utils.parse_pipeline import PipelinePassSpec


class AvailablePass(NamedTuple):
Expand All @@ -16,8 +15,7 @@ class AvailablePass(NamedTuple):
"""

display_name: str
module_pass: type[ModulePass]
pass_spec: PipelinePassSpec | None
module_pass: type[ModulePass] | ModulePass


def get_new_registered_context(
Expand Down Expand Up @@ -55,19 +53,21 @@ def iter_condensed_passes(
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)

for _, value in all_passes:
if value is MLIROptPass:
for _, pass_type in all_passes:
if pass_type is MLIROptPass:
# Always keep MLIROptPass as an option in condensed list
yield AvailablePass(value.name, value, None), None
yield AvailablePass(pass_type.name, pass_type), None
continue
cloned_module = input.clone()
cloned_ctx = ctx.clone()
try:
cloned_module = input.clone()
cloned_ctx = ctx.clone()
value().apply(cloned_ctx, cloned_module)
pass_instance = pass_type()
pass_instance.apply(cloned_ctx, cloned_module)
if input.is_structurally_equivalent(cloned_module):
continue
yield AvailablePass(value.name, value, None), cloned_module
except Exception:
pass
continue
yield AvailablePass(pass_type.name, pass_instance), cloned_module


def get_condensed_pass_list(
Expand Down
8 changes: 3 additions & 5 deletions xdsl/interactive/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,12 @@ def get_all_possible_rewrites(
rewriter = PatternRewriter(cloned_op)
pattern.match_and_rewrite(cloned_op, rewriter)
if rewriter.has_done_action:
p = individual_rewrite.ApplyIndividualRewritePass(
op_idx, cloned_op.name, pattern_name
)
res.append(
AvailablePass(
f"{cloned_op}:{cloned_op.name}:{pattern_name}",
individual_rewrite.ApplyIndividualRewritePass,
p.pipeline_pass_spec(),
individual_rewrite.ApplyIndividualRewritePass(
op_idx, cloned_op.name, pattern_name
),
)
)

Expand Down

0 comments on commit 0d39a26

Please sign in to comment.