From ce853406d1ead2e13ec5c3f5e7ed20ee50100126 Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Thu, 5 Dec 2024 15:29:39 +0000 Subject: [PATCH] transforms: use rewriter and listener in convert-stencil-to-csl-stencil (#3538) Stacked PRs: * #3540 * #3539 * __->__#3538 * #3537 --- --- --- ### transforms: use rewriter and listener in convert-stencil-to-csl-stencil The pass was not propagating the listener from the PatternRewriter, and thus some operations were modified without notifying the rewrite worklist. --- .../transforms/convert-stencil-to-csl-stencil.mlir | 4 ++-- xdsl/transforms/convert_stencil_to_csl_stencil.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index 1f43e07c93..807d542516 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -135,7 +135,7 @@ builtin.module { // CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({ +// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>]}> ({ // CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): // CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32> // CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32> @@ -191,7 +191,7 @@ builtin.module { // CHECK-NEXT: %2 = arith.constant 1 : index // CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { // CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> -// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>]}> ({ // CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): // CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> // CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32> diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 555c4eb394..0a7f6c0db2 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -249,7 +249,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): # replace stencil.access (operating on stencil.temp at arg_index) # with csl_stencil.access (operating on memref at last arg index) nested_rewriter = PatternRewriteWalker( - ConvertAccessOpFromPrefetchPattern(arg_idx) + ConvertAccessOpFromPrefetchPattern(arg_idx), listener=rewriter ) nested_rewriter.rewrite_region(new_apply_op.region) @@ -415,6 +415,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): PatternRewriteWalker( SplitVarithOpPattern(op.region.block.args[prefetch_idx]), apply_recursively=False, + listener=rewriter, ).rewrite_region(op.region) # determine how ops should be split across the two regions @@ -505,12 +506,13 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): # add operations from list to receive_chunk, use translation table to rebuild operands for o in chunk_region_ops: if isinstance(o, stencil.ReturnOp | csl_stencil.YieldOp): + rewriter.erase_op(o) break o.operands = [chunk_region_oprnd_table.get(x, x) for x in o.operands] - receive_chunk.block.add_op(o) + rewriter.insert_op(o, InsertPoint.at_end(receive_chunk.block)) # put `chunk_res` into `accumulator` (using tensor.insert_slice) and yield the result - receive_chunk.block.add_ops( + rewriter.insert_op( [ insert_slice_op := tensor.InsertSliceOp.get( source=chunk_res, @@ -519,13 +521,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): static_sizes=(prefetch.type.get_shape()[1] // self.num_chunks,), ), csl_stencil.YieldOp(insert_slice_op.result), - ] + ], + InsertPoint.at_end(receive_chunk.block), ) # add operations from list to done_exchange, use translation table to rebuild operands for o in done_exchange_ops: o.operands = [done_exchange_oprnd_table.get(x, x) for x in o.operands] - done_exchange.block.add_op(o) + rewriter.insert_op(o, InsertPoint.at_end(done_exchange.block)) if isinstance(o, stencil.ReturnOp): rewriter.replace_op(o, csl_stencil.YieldOp(*o.operands))