Skip to content

Commit

Permalink
Merge branch 'main' into emilien/dmp-strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal authored Aug 18, 2024
2 parents 193143d + 8adc82d commit 3bb67ab
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 33 deletions.
18 changes: 9 additions & 9 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,13 @@ class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter):
if matched_op.regs and matched_op.regs[0].blocks:
rewriter.modify_block_argument_type(
matched_op.regs[0].blocks[0].args[0], i64
)
rewriter.modify_value_type(matched_op.regs[0].blocks[0].args[0], i64)

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
op_modified=1,
)


Expand Down Expand Up @@ -1423,7 +1422,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=5,
op_removed=5,
op_replaced=5,
op_modified=4,
op_modified=5,
)
rewrite_and_compare(
prog,
Expand All @@ -1432,7 +1431,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=5,
op_removed=5,
op_replaced=5,
op_modified=4,
op_modified=5,
)
rewrite_and_compare(
prog,
Expand All @@ -1443,7 +1442,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=5,
op_removed=5,
op_replaced=5,
op_modified=4,
op_modified=5,
)

non_rec_expected = """\
Expand All @@ -1467,7 +1466,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=2,
op_removed=2,
op_replaced=2,
op_modified=3,
op_modified=4,
)
rewrite_and_compare(
prog,
Expand All @@ -1476,7 +1475,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=2,
op_removed=2,
op_replaced=2,
op_modified=3,
op_modified=4,
)
rewrite_and_compare(
prog,
Expand All @@ -1485,7 +1484,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=2,
op_removed=2,
op_replaced=2,
op_modified=3,
op_modified=4,
)

expected = """\
Expand Down Expand Up @@ -1640,6 +1639,7 @@ def convert_type(self, typ: IntegerType) -> IndexType:
op_inserted=1,
op_removed=1,
op_replaced=1,
op_modified=1,
)


Expand Down
11 changes: 8 additions & 3 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,16 @@ def replace_op(
# Then, erase the original operation
self.erase_op(op, safe_erase=safe_erase)

def modify_block_argument_type(self, arg: BlockArgument, new_type: Attribute):
"""Modify the type of a block argument."""
def modify_value_type(self, arg: SSAValue, new_type: Attribute):
"""Modify the type of a value."""
self.has_done_action = True
arg.type = new_type

owner = arg.owner
if isinstance(owner, Block):
owner = owner.parent_op()
if owner is not None:
self.handle_operation_modification(owner)
for use in arg.uses:
self.handle_operation_modification(use.operation)

Expand Down Expand Up @@ -523,7 +528,7 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
for arg in block.args:
converted = self._convert_type_rec(arg.type)
if converted is not None and converted != arg.type:
rewriter.modify_block_argument_type(arg, converted)
rewriter.modify_value_type(arg, converted)
if changed:
regions = [op.detach_region(r) for r in op.regions]
new_op = type(op).create(
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/convert_memref_stream_to_snitch_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def match_and_rewrite(
)
arg.replace_by(cast_op.results[0])
cast_op.operands = (arg,)
rewriter.modify_block_argument_type(arg, stream_type)
rewriter.modify_value_type(arg, stream_type)


def strides_for_affine_map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):

if n_dims == 3:
stream = self.shift_streams[current_stream][k]
rewriter.modify_block_argument_type(
rewriter.modify_value_type(
apply_clone.region.block.args[i], stream.results[0].type
)

Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/memref_stream_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def match_and_rewrite(
for i, arg in enumerate(new_body.block.args):
if i not in legalizations:
continue
rewriter.modify_block_argument_type(arg, legalizations[i])
rewriter.modify_value_type(arg, legalizations[i])
to_be_legalized.update(use.operation for use in arg.uses)
# Legalize payload
_legalize_block(new_body.block, to_be_legalized, rewriter)
Expand Down
21 changes: 3 additions & 18 deletions xdsl/transforms/shape_inference_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StoreOp,
TempType,
)
from xdsl.ir import Attribute, Block, Operation, SSAValue
from xdsl.ir import Attribute, Operation, SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
Expand Down Expand Up @@ -59,21 +59,6 @@ def infer_core_size(op: LoadOp) -> tuple[IndexAttr, IndexAttr]:
return shape_lb, shape_ub


def modify_value_type(value: SSAValue, new_type: Attribute, rewriter: PatternRewriter):
"""Modify the type of a value, triggering more rewrites on its uses and owner."""
rewriter.has_done_action = True
value.type = new_type

for use in value.uses:
rewriter.handle_operation_modification(use.operation)

owner = value.owner
if isinstance(owner, Block):
owner = owner.parent_op()
if owner is not None:
rewriter.handle_operation_modification(owner)


def update_result_size(
value: SSAValue, size: StencilBoundsAttr, rewriter: PatternRewriter
):
Expand Down Expand Up @@ -103,14 +88,14 @@ def update_result_size(
newsize, cast(TempType[Attribute], res.type).element_type
)
if newtype != res.type:
modify_value_type(res, newtype, rewriter)
rewriter.modify_value_type(res, newtype)
for use in res.uses:
if isinstance(use.operation, BufferOp):
update_result_size(use.operation.res, newsize, rewriter)
newsize = size | cast(TempType[Attribute], value.type).bounds
newtype = TempType(newsize, cast(TempType[Attribute], value.type).element_type)
if newtype != value.type:
modify_value_type(value, newtype, rewriter)
rewriter.modify_value_type(value, newtype)


class CombineOpShapeInference(RewritePattern):
Expand Down

0 comments on commit 3bb67ab

Please sign in to comment.