From 47040259f4f26df07640a927bb1686a535c32608 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 19 Aug 2024 13:57:10 +0100 Subject: [PATCH] dialects: Update dmp.swap (#3056) Give it a proper signature for value-semantics and reference-semantics. It now needs to return values if taking values in, and need only an in-place operand if reference-semantics. Update the distribution pass, to replace the loaded values by the swapped values, building a consistent def-use chain. --------- Co-authored-by: n-io --- .../filecheck/dialects/dmp/canonicalize.mlir | 16 ++--- tests/filecheck/dialects/dmp/ops.mlir | 17 +++++- .../transforms/distribute-stencil.mlir | 60 +++++++++---------- .../transforms/stencil-to-csl-stencil.mlir | 5 +- xdsl/dialects/csl/csl_stencil.py | 2 +- xdsl/dialects/experimental/dmp.py | 34 +++++++++-- xdsl/transforms/canonicalize_dmp.py | 9 ++- .../dmp/stencil_global_to_local.py | 6 ++ .../shape_inference_patterns/dmp.py | 22 ++++++- xdsl/transforms/stencil_to_csl_stencil.py | 2 +- 10 files changed, 120 insertions(+), 53 deletions(-) diff --git a/tests/filecheck/dialects/dmp/canonicalize.mlir b/tests/filecheck/dialects/dmp/canonicalize.mlir index 200c83ba4c..e2fdb737b0 100644 --- a/tests/filecheck/dialects/dmp/canonicalize.mlir +++ b/tests/filecheck/dialects/dmp/canonicalize.mlir @@ -1,7 +1,8 @@ // RUN: xdsl-opt -p canonicalize-dmp %s | filecheck %s builtin.module { - %ref = "test.op"() : () -> (memref<1024x1024xf32>) + %ref = "test.op"() : () -> (!stencil.field<[0,1024]x[0,1024]xf32>) + %val = "test.op"() : () -> (!stencil.temp<[0,1024]x[0,1024]xf32>) "dmp.swap"(%ref) { strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, @@ -9,20 +10,21 @@ builtin.module { #dmp.exchange, #dmp.exchange ] - } : (memref<1024x1024xf32>) -> () + } : (!stencil.field<[0,1024]x[0,1024]xf32>) -> () - // CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange]} : (memref<1024x1024xf32>) -> () + // CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange]} : (!stencil.field<[0,1024]x[0,1024]xf32>) -> () - "dmp.swap"(%ref) { + %swap_val = "dmp.swap"(%val) { strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, swaps = [ #dmp.exchange, #dmp.exchange ] - } : (memref<1024x1024xf32>) -> () + } : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> (!stencil.temp<[0,1024]x[0,1024]xf32>) - "test.op"() : () -> () + "test.op"(%swap_val) : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> () // this op should be completely removed since both exchanges are empty, so we expect the next op to be a test.op - // CHECK-NEXT: "test.op"() : () -> () + // and its operand replaced by the unswapped input + // CHECK-NEXT: "test.op"(%val) : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> () } diff --git a/tests/filecheck/dialects/dmp/ops.mlir b/tests/filecheck/dialects/dmp/ops.mlir index 31f1de33fb..41d8515207 100644 --- a/tests/filecheck/dialects/dmp/ops.mlir +++ b/tests/filecheck/dialects/dmp/ops.mlir @@ -1,7 +1,8 @@ // RUN: XDSL_ROUNDTRIP builtin.module { - %ref = "test.op"() : () -> (memref<1024x1024xf32>) + %ref = "test.op"() : () -> (!stencil.field<[0,1024]x[0,1024]xf32>) + %val = "test.op"() : () -> (!stencil.temp<[0,1024]x[0,1024]xf32>) "dmp.swap"(%ref) { strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, @@ -9,7 +10,17 @@ builtin.module { #dmp.exchange, #dmp.exchange ] - } : (memref<1024x1024xf32>) -> () + } : (!stencil.field<[0,1024]x[0,1024]xf32>) -> () - // CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange]} : (memref<1024x1024xf32>) -> () + + %swap_val = "dmp.swap"(%val) { + strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, + swaps = [ + #dmp.exchange, + #dmp.exchange + ] + } : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> (!stencil.temp<[0,1024]x[0,1024]xf32>) + + // CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange]} : (!stencil.field<[0,1024]x[0,1024]xf32>) -> () + // CHECK-NEXT: %swap_val = "dmp.swap"(%val) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange]} : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> !stencil.temp<[0,1024]x[0,1024]xf32> } diff --git a/tests/filecheck/transforms/distribute-stencil.mlir b/tests/filecheck/transforms/distribute-stencil.mlir index bfe6e394b0..0ac4fdad35 100644 --- a/tests/filecheck/transforms/distribute-stencil.mlir +++ b/tests/filecheck/transforms/distribute-stencil.mlir @@ -24,45 +24,45 @@ // CHECK: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { // CHECK-NEXT: %3 = stencil.load %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp -// CHECK-NEXT: "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = []} : (!stencil.temp) -> () -// CHECK-NEXT: %4, %5 = stencil.apply(%6 = %3 : !stencil.temp) -> (!stencil.temp, !stencil.temp) { -// CHECK-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.temp -// CHECK-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.temp -// CHECK-NEXT: %9 = stencil.access %6[0, 1, 0] : !stencil.temp -// CHECK-NEXT: %10 = stencil.access %6[0, -1, 0] : !stencil.temp -// CHECK-NEXT: %11 = stencil.access %6[0, 0, 0] : !stencil.temp -// CHECK-NEXT: %12 = arith.addf %7, %8 : f64 -// CHECK-NEXT: %13 = arith.addf %9, %10 : f64 -// CHECK-NEXT: %14 = arith.addf %12, %13 : f64 +// CHECK-NEXT: %4 = "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = []} : (!stencil.temp) -> !stencil.temp +// CHECK-NEXT: %5, %6 = stencil.apply(%7 = %4 : !stencil.temp) -> (!stencil.temp, !stencil.temp) { +// CHECK-NEXT: %8 = stencil.access %7[-1, 0, 0] : !stencil.temp +// CHECK-NEXT: %9 = stencil.access %7[1, 0, 0] : !stencil.temp +// CHECK-NEXT: %10 = stencil.access %7[0, 1, 0] : !stencil.temp +// CHECK-NEXT: %11 = stencil.access %7[0, -1, 0] : !stencil.temp +// CHECK-NEXT: %12 = stencil.access %7[0, 0, 0] : !stencil.temp +// CHECK-NEXT: %13 = arith.addf %8, %9 : f64 +// CHECK-NEXT: %14 = arith.addf %10, %11 : f64 +// CHECK-NEXT: %15 = arith.addf %13, %14 : f64 // CHECK-NEXT: %cst = arith.constant -4.000000e+00 : f64 -// CHECK-NEXT: %15 = arith.mulf %11, %cst : f64 -// CHECK-NEXT: %16 = arith.addf %15, %14 : f64 -// CHECK-NEXT: stencil.return %16, %15 : f64, f64 +// CHECK-NEXT: %16 = arith.mulf %12, %cst : f64 +// CHECK-NEXT: %17 = arith.addf %16, %15 : f64 +// CHECK-NEXT: stencil.return %17, %16 : f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %4 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: stencil.store %5 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %6 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } // SHAPE: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { // SHAPE-NEXT: %3 = stencil.load %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// SHAPE-NEXT: "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> () -// SHAPE-NEXT: %4, %5 = stencil.apply(%6 = %3 : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> (!stencil.temp<[0,32]x[0,32]x[0,32]xf64>, !stencil.temp<[0,32]x[0,32]x[0,32]xf64>) { -// SHAPE-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// SHAPE-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// SHAPE-NEXT: %9 = stencil.access %6[0, 1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// SHAPE-NEXT: %10 = stencil.access %6[0, -1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// SHAPE-NEXT: %11 = stencil.access %6[0, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// SHAPE-NEXT: %12 = arith.addf %7, %8 : f64 -// SHAPE-NEXT: %13 = arith.addf %9, %10 : f64 -// SHAPE-NEXT: %14 = arith.addf %12, %13 : f64 +// SHAPE-NEXT: %4 = "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> +// SHAPE-NEXT: %5, %6 = stencil.apply(%7 = %4 : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> (!stencil.temp<[0,32]x[0,32]x[0,32]xf64>, !stencil.temp<[0,32]x[0,32]x[0,32]xf64>) { +// SHAPE-NEXT: %8 = stencil.access %7[-1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> +// SHAPE-NEXT: %9 = stencil.access %7[1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> +// SHAPE-NEXT: %10 = stencil.access %7[0, 1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> +// SHAPE-NEXT: %11 = stencil.access %7[0, -1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> +// SHAPE-NEXT: %12 = stencil.access %7[0, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> +// SHAPE-NEXT: %13 = arith.addf %8, %9 : f64 +// SHAPE-NEXT: %14 = arith.addf %10, %11 : f64 +// SHAPE-NEXT: %15 = arith.addf %13, %14 : f64 // SHAPE-NEXT: %cst = arith.constant -4.000000e+00 : f64 -// SHAPE-NEXT: %15 = arith.mulf %11, %cst : f64 -// SHAPE-NEXT: %16 = arith.addf %15, %14 : f64 -// SHAPE-NEXT: stencil.return %16, %15 : f64, f64 +// SHAPE-NEXT: %16 = arith.mulf %12, %cst : f64 +// SHAPE-NEXT: %17 = arith.addf %16, %15 : f64 +// SHAPE-NEXT: stencil.return %17, %16 : f64, f64 // SHAPE-NEXT: } -// SHAPE-NEXT: stencil.store %4 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// SHAPE-NEXT: stencil.store %5 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// SHAPE-NEXT: stencil.store %5 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// SHAPE-NEXT: stencil.store %6 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // SHAPE-NEXT: func.return // SHAPE-NEXT: } diff --git a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/stencil-to-csl-stencil.mlir index 7868c2c279..ccbf3a40a3 100644 --- a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/stencil-to-csl-stencil.mlir @@ -1,10 +1,11 @@ // RUN: xdsl-opt %s -p "stencil-to-csl-stencil{num_chunks=2}" | filecheck %s builtin.module { + func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> () - %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { + %24 = "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) + %1 = stencil.apply(%2 = %24 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { %3 = arith.constant 1.666600e-01 : f32 %4 = stencil.access %2[1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> %5 = "tensor.extract_slice"(%4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index 2c3015b159..4eb97404bc 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -124,7 +124,7 @@ class PrefetchOp(IRDLOperation): name = "csl_stencil.prefetch" input_stencil = operand_def( - base(stencil.TempType[Attribute]) + base(stencil.StencilType[Attribute]) | base(memref.MemRefType[Attribute]) | base(TensorType[Attribute]) ) diff --git a/xdsl/dialects/experimental/dmp.py b/xdsl/dialects/experimental/dmp.py index 143af94694..382f19b3a8 100644 --- a/xdsl/dialects/experimental/dmp.py +++ b/xdsl/dialects/experimental/dmp.py @@ -22,14 +22,15 @@ Operand, ParameterDef, attr_def, - base, irdl_attr_definition, irdl_op_definition, operand_def, + opt_result_def, ) from xdsl.parser import AttrParser from xdsl.printer import Printer from xdsl.traits import HasShapeInferencePatternsTrait +from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa # helpers for named dimensions: @@ -585,9 +586,12 @@ def coords(where: Literal["start", "end"]): class SwapOpHasShapeInferencePatterns(HasShapeInferencePatternsTrait): @classmethod def get_shape_inference_patterns(cls): - from xdsl.transforms.shape_inference_patterns.dmp import DmpSwapShapeInference + from xdsl.transforms.shape_inference_patterns.dmp import ( + DmpSwapShapeInference, + DmpSwapSwapsInference, + ) - return (DmpSwapShapeInference(),) + return (DmpSwapShapeInference(), DmpSwapSwapsInference()) @irdl_op_definition @@ -598,9 +602,8 @@ class SwapOp(IRDLOperation): name = "dmp.swap" - input_stencil: Operand = operand_def( - base(stencil.AnyTempType) | base(builtin.AnyMemRefType) - ) + input_stencil: Operand = operand_def(stencil.StencilType[Attribute]) + swapped_values = opt_result_def(stencil.TempType[Attribute]) swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr]) @@ -608,10 +611,29 @@ class SwapOp(IRDLOperation): traits = frozenset([SwapOpHasShapeInferencePatterns()]) + def verify_(self) -> None: + if self.swapped_values: + if isinstance(self.input_stencil.type, stencil.FieldType): + raise VerifyException( + "dmp.swap_op cannot have a result if input is a field" + ) + else: + if isinstance(self.input_stencil.type, stencil.TempType): + raise VerifyException( + "dmp.swap_op must have a result if input is a temporary" + ) + @staticmethod def get(input_stencil: SSAValue | Operation, strategy: DomainDecompositionStrategy): + input_type = SSAValue.get(input_stencil).type + + result_types = ( + input_type if isa(input_type, stencil.TempType[Attribute]) else None + ) + return SwapOp.build( operands=[input_stencil], + result_types=[result_types], attributes={ "strategy": strategy, "swaps": builtin.ArrayAttr[ExchangeDeclarationAttr](()), diff --git a/xdsl/transforms/canonicalize_dmp.py b/xdsl/transforms/canonicalize_dmp.py index 01b9a5b74d..312bb626c3 100644 --- a/xdsl/transforms/canonicalize_dmp.py +++ b/xdsl/transforms/canonicalize_dmp.py @@ -1,4 +1,4 @@ -from xdsl.dialects import builtin +from xdsl.dialects import builtin, stencil from xdsl.dialects.experimental import dmp from xdsl.passes import MLContext, ModulePass from xdsl.pattern_rewriter import ( @@ -17,7 +17,12 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): if swap.elem_count > 0: keeps.append(swap) if len(keeps) == 0: - rewriter.erase_matched_op() + new_result = ( + op.input_stencil + if isinstance(op.input_stencil.type, stencil.TempType) + else None + ) + rewriter.replace_matched_op([], [new_result]) else: op.swaps = builtin.ArrayAttr(keeps) diff --git a/xdsl/transforms/experimental/dmp/stencil_global_to_local.py b/xdsl/transforms/experimental/dmp/stencil_global_to_local.py index 54b1f93646..c24b86fc95 100644 --- a/xdsl/transforms/experimental/dmp/stencil_global_to_local.py +++ b/xdsl/transforms/experimental/dmp/stencil_global_to_local.py @@ -56,7 +56,13 @@ class AddHaloExchangeOps(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, /): swap_op = dmp.SwapOp.get(op.res, self.strategy) + assert swap_op.swapped_values rewriter.insert_op_after_matched_op(swap_op) + for use in op.res.uses.copy(): + if use.operation is swap_op: + continue + use.operation.operands[use.index] = swap_op.swapped_values + rewriter.handle_operation_modification(use.operation) @dataclass diff --git a/xdsl/transforms/shape_inference_patterns/dmp.py b/xdsl/transforms/shape_inference_patterns/dmp.py index cce31ebc21..86d204db48 100644 --- a/xdsl/transforms/shape_inference_patterns/dmp.py +++ b/xdsl/transforms/shape_inference_patterns/dmp.py @@ -10,6 +10,23 @@ class DmpSwapShapeInference(RewritePattern): + """ + Infer the shape of the `dmp.swap` operation. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter): + if not op.swapped_values: + return + swap_t = op.swapped_values.type + if not isinstance(swap_t, stencil.TempType): + return + if op.input_stencil.type != swap_t: + op.input_stencil.type = swap_t + rewrite.handle_operation_modification(op) + + +class DmpSwapSwapsInference(RewritePattern): """ Infer the exact exchanges this `dmp.swap` needs to perform. """ @@ -19,7 +36,10 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter): core_lb: stencil.IndexAttr | None = None core_ub: stencil.IndexAttr | None = None - for use in op.input_stencil.uses: + if not op.swapped_values: + return + + for use in op.swapped_values.uses: if not isinstance(use.operation, stencil.ApplyOp): continue assert use.operation.res diff --git a/xdsl/transforms/stencil_to_csl_stencil.py b/xdsl/transforms/stencil_to_csl_stencil.py index afa8db29f0..8c30c0177a 100644 --- a/xdsl/transforms/stencil_to_csl_stencil.py +++ b/xdsl/transforms/stencil_to_csl_stencil.py @@ -230,7 +230,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): ) # a little hack to get around a check that prevents replacing a no-results op with an n-results op - rewriter.replace_matched_op(prefetch_op, new_results=[]) + rewriter.replace_matched_op(prefetch_op, new_results=[op.input_stencil]) # uses have to be retrieved *before* the loop because of the rewriting happening inside the loop uses = list(op.input_stencil.uses)