Skip to content

Commit

Permalink
transforms: use rewriter and listener in convert-stencil-to-csl-stencil
Browse files Browse the repository at this point in the history
The pass was not propagating the listener from the PatternRewriter, and
thus some operations were modified without notifying the rewrite
worklist.
  • Loading branch information
math-fehr committed Nov 29, 2024
1 parent dae783a commit f9c786f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ builtin.module {

// CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
- // CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
+ // CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>]}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>):
// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32>
// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32>
Expand Down Expand Up @@ -191,7 +191,7 @@ builtin.module {
// CHECK-NEXT: %2 = arith.constant 1 : index
// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32>
- // CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({
+ // CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>]}> ({
// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>):
// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32>
// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32>
Expand Down
13 changes: 8 additions & 5 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
# replace stencil.access (operating on stencil.temp at arg_index)
# with csl_stencil.access (operating on memref at last arg index)
nested_rewriter = PatternRewriteWalker(
- ConvertAccessOpFromPrefetchPattern(arg_idx)
+ ConvertAccessOpFromPrefetchPattern(arg_idx), listener=rewriter
)

nested_rewriter.rewrite_region(new_apply_op.region)
Expand Down Expand Up @@ -415,6 +415,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
PatternRewriteWalker(
SplitVarithOpPattern(op.region.block.args[prefetch_idx]),
apply_recursively=False,
+ listener=rewriter,
).rewrite_region(op.region)

# determine how ops should be split across the two regions
Expand Down Expand Up @@ -505,12 +506,13 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
# add operations from list to receive_chunk, use translation table to rebuild operands
for o in chunk_region_ops:
if isinstance(o, stencil.ReturnOp | csl_stencil.YieldOp):
+ rewriter.erase_op(o)
break
o.operands = [chunk_region_oprnd_table.get(x, x) for x in o.operands]
- receive_chunk.block.add_op(o)
+ rewriter.insert_op(o, InsertPoint.at_end(receive_chunk.block))

# put `chunk_res` into `accumulator` (using tensor.insert_slice) and yield the result
- receive_chunk.block.add_ops(
+ rewriter.insert_op(
[
insert_slice_op := tensor.InsertSliceOp.get(
source=chunk_res,
Expand All @@ -519,13 +521,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
static_sizes=(prefetch.type.get_shape()[1] // self.num_chunks,),
),
csl_stencil.YieldOp(insert_slice_op.result),
- ]
+ ],
+ InsertPoint.at_end(receive_chunk.block),
)

# add operations from list to done_exchange, use translation table to rebuild operands
for o in done_exchange_ops:
o.operands = [done_exchange_oprnd_table.get(x, x) for x in o.operands]
- done_exchange.block.add_op(o)
+ rewriter.insert_op(o, InsertPoint.at_end(done_exchange.block))
if isinstance(o, stencil.ReturnOp):
rewriter.replace_op(o, csl_stencil.YieldOp(*o.operands))

Expand Down

0 comments on commit f9c786f

Please sign in to comment.