Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interactive: AvailablePass now has either a pass type or instance, no spec #3725

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from xdsl.transforms.experimental.dmp import stencil_global_to_local
from xdsl.utils.exceptions import ParseError
from xdsl.utils.parse_pipeline import parse_pipeline


@pytest.mark.asyncio
Expand Down Expand Up @@ -324,12 +323,9 @@ async def test_rewrites():

addi_pass = AvailablePass(
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight",
module_pass=individual_rewrite.ApplyIndividualRewritePass,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
module_pass=individual_rewrite.ApplyIndividualRewritePass(
3, "arith.addi", "SignlessIntegerBinaryOperationZeroOrUnitRight"
),
)

await pilot.pause()
Expand Down
17 changes: 3 additions & 14 deletions tests/interactive/test_get_all_available_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
op_type_rewrite_pattern,
)
from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass
from xdsl.utils.parse_pipeline import PipelinePassSpec


@dataclass
Expand Down Expand Up @@ -82,25 +81,15 @@ def test_get_all_available_passes():
(
AvailablePass(
display_name="bc",
module_pass=BCPass,
pass_spec=None,
module_pass=BCPass(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It looks like AvailablePass can still be built with a type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is used when "Condense Mode" is not on, and for mlir-opt pass as we want to present the pass as an option to the user without instantiating it

),
AvailablePass(
display_name="bd",
module_pass=BDPass,
pass_spec=None,
module_pass=BDPass(),
),
AvailablePass(
display_name='TestOp("test.op"() {key = "b"} : () -> ()):test.op:be',
module_pass=ApplyIndividualRewritePass,
pass_spec=PipelinePassSpec(
"apply-individual-rewrite",
{
"matched_operation_index": (1,),
"operation_name": ("test.op",),
"pattern_name": ("be",),
},
),
module_pass=ApplyIndividualRewritePass(1, "test.op", "be"),
),
)
)
21 changes: 2 additions & 19 deletions tests/interactive/test_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
op_type_rewrite_pattern,
)
from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass
from xdsl.utils.parse_pipeline import PipelinePassSpec


class Rewrite(RewritePattern):
Expand Down Expand Up @@ -44,27 +43,11 @@ def test_get_all_possible_rewrite():
expected_res = [
AvailablePass(
display_name='TestOp("test.op"() {label = "a"} : () -> ()):test.op:TestRewrite',
module_pass=ApplyIndividualRewritePass,
pass_spec=PipelinePassSpec(
"apply-individual-rewrite",
{
"matched_operation_index": (1,),
"operation_name": ("test.op",),
"pattern_name": ("TestRewrite",),
},
),
module_pass=ApplyIndividualRewritePass(1, "test.op", "TestRewrite"),
),
AvailablePass(
display_name='TestOp("test.op"() {label = "a"} : () -> ()):test.op:TestRewrite',
module_pass=ApplyIndividualRewritePass,
pass_spec=PipelinePassSpec(
"apply-individual-rewrite",
{
"matched_operation_index": (2,),
"operation_name": ("test.op",),
"pattern_name": ("TestRewrite",),
},
),
module_pass=ApplyIndividualRewritePass(2, "test.op", "TestRewrite"),
),
]

Expand Down
8 changes: 5 additions & 3 deletions xdsl/interactive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def compute_available_pass_list(self) -> tuple[AvailablePass, ...]:
"""
match self.current_module:
case None:
return tuple(AvailablePass(p.name, p, None) for _, p in self.all_passes)
return tuple(AvailablePass(p.name, p) for _, p in self.all_passes)
case Exception():
return ()
case ModuleOp():
Expand Down Expand Up @@ -347,10 +347,12 @@ def expand_node(
# remove potential children nodes in case expand node has been clicked multiple times on the same node
expanded_pass.remove_children()

for pass_name, value, value_spec in child_pass_list:
for pass_name, value in child_pass_list:
expanded_pass.add(
label=pass_name,
data=(value, value_spec),
data=(type(value), value.pipeline_pass_spec())
if isinstance(value, ModulePass)
else (value, None),
Comment on lines +353 to +355
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a newly introduced ugly bit due to the tree node data still being in terms of pass type and spec, next PR will remove this.

)

def update_root_of_passes_tree(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion xdsl/interactive/get_all_available_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ def get_available_pass_list(
if condense_mode:
pass_list = get_condensed_pass_list(current_module, all_passes)
else:
pass_list = tuple(AvailablePass(p.name, p, None) for _, p in all_passes)
pass_list = tuple(AvailablePass(p.name, p) for _, p in all_passes)
return pass_list + tuple(individual_rewrites)
22 changes: 11 additions & 11 deletions xdsl/interactive/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xdsl.ir import Dialect
from xdsl.passes import ModulePass, PipelinePass
from xdsl.transforms.mlir_opt import MLIROptPass
from xdsl.utils.parse_pipeline import PipelinePassSpec


class AvailablePass(NamedTuple):
Expand All @@ -16,8 +15,7 @@ class AvailablePass(NamedTuple):
"""

display_name: str
module_pass: type[ModulePass]
pass_spec: PipelinePassSpec | None
module_pass: type[ModulePass] | ModulePass


def get_new_registered_context(
Expand Down Expand Up @@ -55,19 +53,21 @@ def iter_condensed_passes(
for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)

for _, value in all_passes:
if value is MLIROptPass:
for _, pass_type in all_passes:
if pass_type is MLIROptPass:
# Always keep MLIROptPass as an option in condensed list
yield AvailablePass(value.name, value, None), None
yield AvailablePass(pass_type.name, pass_type), None
continue
cloned_module = input.clone()
cloned_ctx = ctx.clone()
try:
cloned_module = input.clone()
cloned_ctx = ctx.clone()
value().apply(cloned_ctx, cloned_module)
pass_instance = pass_type()
pass_instance.apply(cloned_ctx, cloned_module)
if input.is_structurally_equivalent(cloned_module):
continue
yield AvailablePass(value.name, value, None), cloned_module
except Exception:
pass
continue
yield AvailablePass(pass_type.name, pass_instance), cloned_module


def get_condensed_pass_list(
Expand Down
8 changes: 3 additions & 5 deletions xdsl/interactive/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,12 @@ def get_all_possible_rewrites(
rewriter = PatternRewriter(cloned_op)
pattern.match_and_rewrite(cloned_op, rewriter)
if rewriter.has_done_action:
p = individual_rewrite.ApplyIndividualRewritePass(
op_idx, cloned_op.name, pattern_name
)
res.append(
AvailablePass(
f"{cloned_op}:{cloned_op.name}:{pattern_name}",
individual_rewrite.ApplyIndividualRewritePass,
p.pipeline_pass_spec(),
individual_rewrite.ApplyIndividualRewritePass(
op_idx, cloned_op.name, pattern_name
),
)
)

Expand Down
Loading