Skip to content

Commit

Permalink
transformations: (memref-to-dsd) Support pre-existing GetMemDsd ops (#…
Browse files Browse the repository at this point in the history
…3279)

Adds support for pre-existent `GetMemDsdOp`s in the `memref-to-dsd`
pass, which is needed for #3271

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Oct 10, 2024
1 parent 964834a commit 3b844c9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
8 changes: 8 additions & 0 deletions tests/filecheck/transforms/memref-to-dsd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ builtin.module {
// CHECK-NEXT: %28 = "csl.load_var"(%27) : (!csl.var<!csl<dsd mem1d_dsd>>) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.store_var"(%27, %28) : (!csl.var<!csl<dsd mem1d_dsd>>, !csl<dsd mem1d_dsd>) -> ()

// ensure that pre-existing get_mem_dsd ops access the underlying buffer, not the get_mem_dsd created on top of it

%36 = arith.constant 510 : i16
%37 = "csl.get_mem_dsd"(%b, %36) : (memref<510xf32>, i16) -> !csl<dsd mem1d_dsd>

// CHECK-NEXT: %29 = arith.constant 510 : i16
// CHECK-NEXT: %30 = "csl.get_mem_dsd"(%b, %29) : (memref<510xf32>, i16) -> !csl<dsd mem1d_dsd>

}) {sym_name = "program"} : () -> ()
}
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
Expand Down
31 changes: 30 additions & 1 deletion xdsl/transforms/memref_to_dsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
StridedLayoutAttr,
UnrealizedConversionCastOp,
)
from xdsl.ir import Attribute, Operation, SSAValue
from xdsl.ir import Attribute, Operation, OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down Expand Up @@ -64,6 +64,31 @@ def match_and_rewrite(self, op: memref.Alloc, rewriter: PatternRewriter, /):
rewriter.replace_matched_op([zeros_op, *shape, dsd_op])


class FixGetDsdOnGetDsd(RewritePattern):
"""
This rewrite pattern resolves GetMemDsdOp being called on GetMemDsdOp instead of the underlying buffer,
a side effect created by `LowerAllocOpPass` in case of pre-existing GetMemDsdOp ops being present in
the program that were created outside of this pass.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /):
if isinstance(op.base_addr.type, csl.DsdType):
if isinstance(op.base_addr, OpResult) and isinstance(
op.base_addr.op, csl.GetMemDsdOp
):
rewriter.replace_matched_op(
csl.GetMemDsdOp.build(
operands=[op.base_addr.op.base_addr, op.sizes],
properties=op.properties,
attributes=op.attributes,
result_types=op.result_types,
)
)
else:
raise ValueError("Failed to resolve GetMemDsdOp called on dsd type")


class LowerSubviewOpPass(RewritePattern):
"""Lowers memref.subview to dsd ops"""

Expand Down Expand Up @@ -335,3 +360,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
apply_recursively=False,
)
forward_pass.rewrite_module(op)
cleanup_pass = PatternRewriteWalker(
FixGetDsdOnGetDsd(),
)
cleanup_pass.rewrite_module(op)

0 comments on commit 3b844c9

Please sign in to comment.