From e70919f57d630254f422a2c2ff5581642eb5208c Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 16 Aug 2024 14:56:36 +0100 Subject: [PATCH 1/5] core: Improved treatment of whitespace in assembly format (#3044) This makes the following changes: - Makes assembly format recognise the empty whitespace directive ``` `` ``` - Fixes a bug in the assembly format parsing so that this is actually parsed correctly - Allows whitespace at the start of an optional group - Modifies existing stencil tests now that they print correctly (i.e. without a space in `stencil.store`) - Adds tests for new functionality These changes bring the assembly format closer to the assembly format of MLIR --- tests/dialects/test_stencil.py | 2 +- .../csl/csl-stencil-canonicalize.mlir | 12 +++--- .../dialects/csl/csl-stencil-ops.mlir | 8 ++-- .../dialects/stencil/canonicalize.mlir | 6 +-- .../stencil/oec-kernels/fvtp2d_qi.mlir | 8 ++-- .../dialects/stencil/stencil_ops.mlir | 24 ++++++------ .../with-mlir/dialects/stencil/ops.mlir | 6 +-- .../csl-stencil-to-csl-wrapper.mlir | 4 +- .../transforms/distribute-stencil.mlir | 26 ++++++------- .../transforms/stencil-shape-inference.mlir | 38 +++++++++---------- .../stencil-storage-materialization.mlir | 20 +++++----- .../stencil-tensorize-z-dimension.mlir | 6 +-- .../transforms/stencil-to-csl-stencil.mlir | 4 +- .../filecheck/transforms/stencil-unroll.mlir | 16 ++++---- .../irdl/test_declarative_assembly_format.py | 34 ++++++++++++++++- xdsl/irdl/declarative_assembly_format.py | 13 +++++-- .../declarative_assembly_format_parser.py | 37 +++++++++++++----- 17 files changed, 160 insertions(+), 104 deletions(-) diff --git a/tests/dialects/test_stencil.py b/tests/dialects/test_stencil.py index 350131ab98..ed4531f2ce 100644 --- a/tests/dialects/test_stencil.py +++ b/tests/dialects/test_stencil.py @@ -756,7 +756,7 @@ def module(): %9 = arith.addf %8, %7 : f32 stencil.return %9 : f32 } - stencil.store %3 to %1 (<[0], [6]>) : !stencil.temp to 
!stencil.field<[-1,7]xf32> + stencil.store %3 to %1(<[0], [6]>) : !stencil.temp to !stencil.field<[-1,7]xf32> func.return } } diff --git a/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir b/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir index ed37d8aba9..6a95b40014 100644 --- a/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir +++ b/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir @@ -15,7 +15,7 @@ builtin.module { ^0(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>): csl_stencil.yield %9 : tensor<510xf32> }) - stencil.store %2 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> %10 = tensor.empty() : tensor<510xf32> %11 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ @@ -27,7 +27,7 @@ builtin.module { ^0(%17 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %18 : tensor<510xf32>): csl_stencil.yield %18 : tensor<510xf32> }) - stencil.store %11 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %11 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> %19 = tensor.empty() : tensor<510xf32> %20 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %19 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, 
#csl_stencil.exchange], "operandSegmentSizes" = array}> ({ @@ -39,7 +39,7 @@ builtin.module { ^0(%26 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %27 : tensor<510xf32>): csl_stencil.yield %27 : tensor<510xf32> }) - stencil.store %20 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %20 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return } } @@ -58,7 +58,7 @@ builtin.module { // CHECK-NEXT: ^1(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>): // CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> // CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %2 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %3 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%4 : tensor<4x255xf32>, %5 : index, %6 : tensor<510xf32>): // CHECK-NEXT: %7 = csl_stencil.access %4[1, 0] : tensor<4x255xf32> @@ -68,7 +68,7 @@ builtin.module { // CHECK-NEXT: ^1(%9 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>): // CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32> // CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %3 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %3 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to 
!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %4 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%5 : tensor<4x255xf32>, %6 : index, %7 : tensor<510xf32>): // CHECK-NEXT: %8 = csl_stencil.access %5[1, 0] : tensor<4x255xf32> @@ -78,7 +78,7 @@ builtin.module { // CHECK-NEXT: ^1(%10 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %11 : tensor<510xf32>): // CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32> // CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %4 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %4 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir index 599a35e00a..c7fe8e65d9 100644 --- a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir +++ b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir @@ -25,7 +25,7 @@ builtin.module { %20 = arith.mulf %17, %19 : tensor<510xf32> stencil.return %20 : tensor<510xf32> } - stencil.store %1 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return } } @@ -54,7 +54,7 @@ builtin.module { // CHECK-NEXT: %20 = arith.mulf %17, %19 : tensor<510xf32> // CHECK-NEXT: stencil.return %20 : tensor<510xf32> // CHECK-NEXT: } -// CHECK-NEXT: 
stencil.store %1 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } @@ -133,7 +133,7 @@ builtin.module { csl_stencil.yield %21 : tensor<510xf32> }) - stencil.store %2 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return } } @@ -167,7 +167,7 @@ builtin.module { // CHECK-NEXT: %21 = arith.mulf %17, %20 : tensor<510xf32> // CHECK-NEXT: csl_stencil.yield %21 : tensor<510xf32> // CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %2 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/stencil/canonicalize.mlir b/tests/filecheck/dialects/stencil/canonicalize.mlir index 4987bfe37a..ed92e65ab9 100644 --- a/tests/filecheck/dialects/stencil/canonicalize.mlir +++ b/tests/filecheck/dialects/stencil/canonicalize.mlir @@ -18,8 +18,8 @@ func.func @dup_operand(%f : !stencil.field<[0,64]xf64>, %of1 : !stencil.field<[0 // CHECK-NEXT: %0 = stencil.access %one[0] : !stencil.temp // CHECK-NEXT: stencil.return %0, %0 : f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %o1 to %of1 (<[0], [64]>) : !stencil.temp to !stencil.field<[0,64]xf64> -// CHECK-NEXT: stencil.store %o2 to %of2 (<[0], [64]>) : !stencil.temp to !stencil.field<[0,64]xf64> +// CHECK-NEXT: 
stencil.store %o1 to %of1(<[0], [64]>) : !stencil.temp to !stencil.field<[0,64]xf64> +// CHECK-NEXT: stencil.store %o2 to %of2(<[0], [64]>) : !stencil.temp to !stencil.field<[0,64]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -44,6 +44,6 @@ func.func @unused_res(%f1 : !stencil.field<[0,64]xf64>, %f2 : !stencil.field<[0, // CHECK-NEXT: %0 = stencil.access %one[0] : !stencil.temp // CHECK-NEXT: stencil.return %0 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %o1 to %of (<[0], [64]>) : !stencil.temp to !stencil.field<[0,64]xf64> +// CHECK-NEXT: stencil.store %o1 to %of(<[0], [64]>) : !stencil.temp to !stencil.field<[0,64]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir index 9abed5e24e..b94016cd75 100644 --- a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir +++ b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir @@ -227,8 +227,8 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // CHECK-NEXT: stencil.return %29 : !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %17 to %6 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: stencil.store %19 to %5 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %17 to %6(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %19 to %5(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -348,8 +348,8 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // SHAPE-NEXT: stencil.return %35 : !stencil.result // SHAPE-NEXT: } -// SHAPE-NEXT: stencil.store %22 to %6 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,65]x[0,64]xf64> to 
!stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// SHAPE-NEXT: stencil.store %25 to %5 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// SHAPE-NEXT: stencil.store %22 to %6(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,65]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// SHAPE-NEXT: stencil.store %25 to %5(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // SHAPE-NEXT: func.return // SHAPE-NEXT: } diff --git a/tests/filecheck/dialects/stencil/stencil_ops.mlir b/tests/filecheck/dialects/stencil/stencil_ops.mlir index 359866b512..88618fb8f5 100644 --- a/tests/filecheck/dialects/stencil/stencil_ops.mlir +++ b/tests/filecheck/dialects/stencil/stencil_ops.mlir @@ -10,7 +10,7 @@ builtin.module { %8 = stencil.store_result %7 : !stencil.result stencil.return %8 : !stencil.result } - stencil.store %5 to %3 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> + stencil.store %5 to %3(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } } @@ -25,7 +25,7 @@ builtin.module { // CHECK-NEXT: %8 = stencil.store_result %7 : !stencil.result // CHECK-NEXT: stencil.return %8 : !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %5 to %3 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %3(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -44,7 +44,7 @@ builtin.module { %v = stencil.access %ti_[0, 0, 0] : !stencil.temp stencil.return %v : f32 } - stencil.store %tip1 to %fip1 (<[0, 0, 0], [50, 80, 40]>) : !stencil.temp to !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32> + 
stencil.store %tip1 to %fip1(<[0, 0, 0], [50, 80, 40]>) : !stencil.temp to !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32> scf.yield %fip1, %fi : !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32>, !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32> } func.return @@ -63,7 +63,7 @@ builtin.module { // CHECK-NEXT: %v = stencil.access %ti_[0, 0, 0] : !stencil.temp // CHECK-NEXT: stencil.return %v : f32 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %tip1 to %fip1 (<[0, 0, 0], [50, 80, 40]>) : !stencil.temp to !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32> +// CHECK-NEXT: stencil.store %tip1 to %fip1(<[0, 0, 0], [50, 80, 40]>) : !stencil.temp to !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32> // CHECK-NEXT: scf.yield %fip1, %fi : !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32>, !stencil.field<[-4,54]x[-4,84]x[-4,44]xf32> // CHECK-NEXT: } // CHECK-NEXT: func.return @@ -90,7 +90,7 @@ builtin.module { %17 = arith.mulf %16, %13 : f64 stencil.return %17 : f64 } - stencil.store %5 to %3 (<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> + stencil.store %5 to %3(<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> func.return } } @@ -114,7 +114,7 @@ builtin.module { // CHECK-NEXT: %17 = arith.mulf %16, %13 : f64 // CHECK-NEXT: stencil.return %17 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %5 to %3 (<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %3(<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -133,7 +133,7 @@ builtin.module { stencil.return %9 : f64 } %10 = stencil.combine 0 at 11 lower = (%6 : !stencil.temp) upper = (%7 : !stencil.temp) : !stencil.temp - stencil.store %10 to %1 (<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %10 to %1(<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> func.return } } 
@@ -151,7 +151,7 @@ builtin.module { // CHECK-NEXT: stencil.return %7 : f64 // CHECK-NEXT: } // CHECK-NEXT: %6 = stencil.combine 0 at 11 lower = (%4 : !stencil.temp) upper = (%5 : !stencil.temp) : !stencil.temp -// CHECK-NEXT: stencil.store %6 to %1 (<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %6 to %1(<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -167,7 +167,7 @@ builtin.module { %8 = stencil.access %6[0, _] : !stencil.temp<[-1,65]xf64> stencil.return %8 : f64 } - stencil.store %5 to %3 (<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> + stencil.store %5 to %3(<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> func.return } } @@ -182,7 +182,7 @@ builtin.module { // CHECK-NEXT: %8 = stencil.access %6[0, _] : !stencil.temp<[-1,65]xf64> // CHECK-NEXT: stencil.return %8 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %5 to %3 (<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %3(<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } @@ -200,7 +200,7 @@ builtin.module { %7 = stencil.dyn_access %6[%i, %j] in <[-1, -1]> : <[1, 1]> : !stencil.temp<[-1,65]x[-1,65]xf64> stencil.return %7 : f64 } - stencil.store %5 to %3 (<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> + stencil.store %5 to %3(<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> func.return } } @@ -216,7 +216,7 @@ builtin.module { // CHECK-NEXT: %7 = stencil.dyn_access %6[%i, %j] in <[-1, -1]> : <[1, 1]> : !stencil.temp<[-1,65]x[-1,65]xf64> // CHECK-NEXT: stencil.return %7 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %5 to %3 (<[0, 0], [64, 
64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %3(<[0, 0], [64, 64]>) : !stencil.temp<[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/stencil/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/stencil/ops.mlir index eb973b5702..eb30f4b01a 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/stencil/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/stencil/ops.mlir @@ -48,7 +48,7 @@ func.func @dyn_access(%arg0 : !stencil.field, %arg1 : !stencil.field< %43 = stencil.store_result %42 : !stencil.result stencil.return %8, %13, %18, %23, %28, %33, %38, %43 unroll <[1, 8, 1]> : !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result } - stencil.store %3 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %3 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> func.return } @@ -100,7 +100,7 @@ func.func @dyn_access(%arg0 : !stencil.field, %arg1 : !stencil.field< // CHECK-NEXT: %43 = stencil.store_result %42 : !stencil.result // CHECK-NEXT: stencil.return %8, %13, %18, %23, %28, %33, %38, %43 unroll <[1, 8, 1]> : !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %3 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %3 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } -// CHECK-NEXT: 
} \ No newline at end of file +// CHECK-NEXT: } diff --git a/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir b/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir index 6829ad4440..3e124d9860 100644 --- a/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir +++ b/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir @@ -29,7 +29,7 @@ builtin.module { %25 = arith.mulf %22, %24 : tensor<510xf32> csl_stencil.yield %25 : tensor<510xf32> }) - stencil.store %2 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return } } @@ -87,7 +87,7 @@ builtin.module { // CHECK-NEXT: %60 = arith.mulf %57, %59 : tensor<510xf32> // CHECK-NEXT: csl_stencil.yield %60 : tensor<510xf32> // CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %37 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %37 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () diff --git a/tests/filecheck/transforms/distribute-stencil.mlir b/tests/filecheck/transforms/distribute-stencil.mlir index 6719bcded6..467b97447d 100644 --- a/tests/filecheck/transforms/distribute-stencil.mlir +++ b/tests/filecheck/transforms/distribute-stencil.mlir @@ -19,8 +19,8 @@ builtin.module { %46 = arith.addf %45, %44 : f64 stencil.return %46, %45 : f64, f64 } - stencil.store %34 to %28 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> - stencil.store %35 to %29 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> + stencil.store %34 to %28(<[0, 
0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> + stencil.store %35 to %29(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } @@ -41,8 +41,8 @@ builtin.module { // CHECK-NEXT: %16 = arith.addf %15, %14 : f64 // CHECK-NEXT: stencil.return %16, %15 : f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %4 to %1 (<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: stencil.store %5 to %2 (<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %4 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -73,7 +73,7 @@ builtin.module { %xyz = arith.addi %xy, %z : index stencil.return %xyz : index } - stencil.store %92 to %91 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> + stencil.store %92 to %91(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> func.return } @@ -86,7 +86,7 @@ builtin.module { // CHECK-NEXT: %xyz = arith.addi %xy, %z : index // CHECK-NEXT: stencil.return %xyz : index // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %1 to %0 (<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> +// CHECK-NEXT: stencil.store %1 to %0(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -99,7 +99,7 @@ builtin.module { %xyz_1 = arith.addi %xy_1, 
%z_1 : index stencil.return %xyz_1 : index } - stencil.store %94 to %93 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> + stencil.store %94 to %93(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> func.return } @@ -112,7 +112,7 @@ builtin.module { // CHECK-NEXT: %xyz = arith.addi %xy, %z : index // CHECK-NEXT: stencil.return %xyz : index // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %1 to %0 (<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> +// CHECK-NEXT: stencil.store %1 to %0(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xindex> to !stencil.field<[0,64]x[0,64]x[0,64]xindex> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -151,8 +151,8 @@ func.func @if_lowering(%arg0_1 : f64, %b0 : !stencil.field<[0,8]x[0,8]x[0,8]xf64 %107 = stencil.store_result %104 : !stencil.result stencil.return %103, %107 : !stencil.result, !stencil.result } - stencil.store %101 to %b0 (<[0, 0, 0], [8, 8, 8]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> - stencil.store %102 to %b1 (<[0, 0, 0], [8, 8, 8]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> + stencil.store %101 to %b0(<[0, 0, 0], [8, 8, 8]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> + stencil.store %102 to %b1(<[0, 0, 0], [8, 8, 8]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> func.return } @@ -169,11 +169,11 @@ func.func @if_lowering(%arg0_1 : f64, %b0 : !stencil.field<[0,8]x[0,8]x[0,8]xf64 // CHECK-NEXT: %6 = stencil.store_result %3 : !stencil.result // CHECK-NEXT: stencil.return %2, %6 : !stencil.result, !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %0 to %b0 (<[0, 0, 0], [4, 4, 4]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to 
!stencil.field<[0,8]x[0,8]x[0,8]xf64> -// CHECK-NEXT: stencil.store %1 to %b1 (<[0, 0, 0], [4, 4, 4]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> +// CHECK-NEXT: stencil.store %0 to %b0(<[0, 0, 0], [4, 4, 4]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> +// CHECK-NEXT: stencil.store %1 to %b1(<[0, 0, 0], [4, 4, 4]>) : !stencil.temp<[0,8]x[0,8]x[0,8]xf64> to !stencil.field<[0,8]x[0,8]x[0,8]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } - + } // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/stencil-shape-inference.mlir b/tests/filecheck/transforms/stencil-shape-inference.mlir index cf48430c51..9791b02ac5 100644 --- a/tests/filecheck/transforms/stencil-shape-inference.mlir +++ b/tests/filecheck/transforms/stencil-shape-inference.mlir @@ -10,7 +10,7 @@ builtin.module { %o = arith.addf %l, %r : f64 stencil.return %o : f64 } - stencil.store %tout to %out (<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %tout to %out(<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> func.return } } @@ -24,7 +24,7 @@ builtin.module { // CHECK-NEXT: %o = arith.addf %l, %r : f64 // CHECK-NEXT: stencil.return %o : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %tout to %out (<[0], [64]>) : !stencil.temp<[0,64]xf64> to !stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %tout to %out(<[0], [64]>) : !stencil.temp<[0,64]xf64> to !stencil.field<[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -49,7 +49,7 @@ builtin.module { %16 = arith.addf %15, %14 : f64 stencil.return %16 : f64 } - stencil.store %5 to %3 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> + stencil.store %5 to %3(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } } @@ -73,7 +73,7 @@ builtin.module { // CHECK-NEXT: %16 = arith.addf %15, %14 : f64 
// CHECK-NEXT: stencil.return %16 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %5 to %3 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %3(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -98,7 +98,7 @@ builtin.module { %16 = arith.addf %15, %14 : f64 stencil.return %16 : f64 } - stencil.store %5 to %3 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> + stencil.store %5 to %3(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } } @@ -118,7 +118,7 @@ builtin.module { %6 = arith.addf %4, %5 : f64 stencil.return %6 : f64 } - stencil.store %3 to %2 (<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + stencil.store %3 to %2(<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> func.return } } @@ -130,7 +130,7 @@ builtin.module { // CHECK-NEXT: %6 = arith.addf %4, %5 : f64 // CHECK-NEXT: stencil.return %6 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %3 to %2 (<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: stencil.store %3 to %2(<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -148,7 +148,7 @@ builtin.module { %9 = stencil.access %8[0] : !stencil.temp stencil.return %9 : f64 } - stencil.store %7 to %1 (<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %7 to %1(<[0], [64]>) : !stencil.temp to !stencil.field<[-4,68]xf64> func.return } } @@ -164,7 +164,7 @@ builtin.module { // CHECK-NEXT: %7 = stencil.access 
%6[0] : !stencil.temp<[0,64]xf64> // CHECK-NEXT: stencil.return %7 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %5 to %1 (<[0], [64]>) : !stencil.temp<[0,64]xf64> to !stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %5 to %1(<[0], [64]>) : !stencil.temp<[0,64]xf64> to !stencil.field<[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -180,7 +180,7 @@ func.func @dyn_access(%arg0 : !stencil.field, %arg1 : !stencil.field< %5 = stencil.store_result %4 : !stencil.result stencil.return %5 : !stencil.result } - stencil.store %3 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %3 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> func.return } @@ -194,7 +194,7 @@ func.func @dyn_access(%arg0 : !stencil.field, %arg1 : !stencil.field< // CHECK-NEXT: %5 = stencil.store_result %4 : !stencil.result // CHECK-NEXT: stencil.return %5 : !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %3 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %3 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -215,7 +215,7 @@ func.func @combine(%arg0 : !stencil.field, %arg1 : !stencil.field } %9 = stencil.combine 0 at 32 lower = (%3 : !stencil.temp) upper = (%6 : !stencil.temp) : !stencil.temp - stencil.store %9 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %9 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> func.return } @@ -234,7 +234,7 @@ func.func @combine(%arg0 : !stencil.field, %arg1 : !stencil.field // CHECK-NEXT: } // CHECK-NEXT: %5 = stencil.combine 0 at 32 lower = (%3 : !stencil.temp<[0,32]x[0,64]x[0,60]xf64>) upper = (%4 : 
!stencil.temp<[32,64]x[0,64]x[0,60]xf64>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> -// CHECK-NEXT: stencil.store %5 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %5 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -278,9 +278,9 @@ func.func @buffer(%arg0 : !stencil.field, %arg1 : !stencil.field stencil.return %23 : !stencil.result } - stencil.store %15 to %0 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> - stencil.store %18 to %1 (<[0, 0, 0], [16, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> - stencil.store %21 to %2 (<[48, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %15 to %0(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %18 to %1(<[0, 0, 0], [16, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %21 to %2(<[48, 0, 0], [64, 64, 60]>) : !stencil.temp to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> func.return } @@ -322,8 +322,8 @@ func.func @buffer(%arg0 : !stencil.field, %arg1 : !stencil.field // CHECK-NEXT: stencil.return %16 : !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %12 to %0 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> -// CHECK-NEXT: stencil.store %13 to %1 (<[0, 0, 0], [16, 64, 60]>) : !stencil.temp<[0,16]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> -// CHECK-NEXT: stencil.store %14 to %2 (<[48, 0, 0], [64, 64, 60]>) : !stencil.temp<[48,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %12 to %0(<[0, 0, 0], [64, 64, 60]>) : 
!stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %13 to %1(<[0, 0, 0], [16, 64, 60]>) : !stencil.temp<[0,16]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %14 to %2(<[48, 0, 0], [64, 64, 60]>) : !stencil.temp<[48,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/stencil-storage-materialization.mlir b/tests/filecheck/transforms/stencil-storage-materialization.mlir index 9dfcb655ab..f2a27c9e7d 100644 --- a/tests/filecheck/transforms/stencil-storage-materialization.mlir +++ b/tests/filecheck/transforms/stencil-storage-materialization.mlir @@ -9,7 +9,7 @@ builtin.module{ %v = stencil.access %inb[-1] : !stencil.temp stencil.return %v : f64 } - stencil.store %outt to %out (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %outt to %out(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> func.return } @@ -19,7 +19,7 @@ builtin.module{ // CHECK-NEXT: %v = stencil.access %inb[-1] : !stencil.temp // CHECK-NEXT: stencil.return %v : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %outt to %out (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %outt to %out(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -35,7 +35,7 @@ builtin.module{ %v = stencil.access %midb[-1] : !stencil.temp stencil.return %v : f64 } - stencil.store %outt to %out (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %outt to %out(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> func.return } @@ -50,7 +50,7 @@ builtin.module{ // CHECK-NEXT: %v = stencil.access %midb[-1] : !stencil.temp // CHECK-NEXT: stencil.return %v : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %outt to %out (<[0], [68]>) : !stencil.temp to 
!stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %outt to %out(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -62,12 +62,12 @@ builtin.module{ %v = stencil.access %inb[-1] : !stencil.temp stencil.return %v : f64 } - stencil.store %midt to %midout (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %midt to %midout(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> %outt = stencil.apply(%midb = %midt : !stencil.temp) -> (!stencil.temp) { %v = stencil.access %midb[-1] : !stencil.temp stencil.return %v : f64 } - stencil.store %outt to %out (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> + stencil.store %outt to %out(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> func.return } @@ -77,12 +77,12 @@ builtin.module{ // CHECK-NEXT: %v = stencil.access %inb[-1] : !stencil.temp // CHECK-NEXT: stencil.return %v : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %midt to %midout (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %midt to %midout(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> // CHECK-NEXT: %outt = stencil.apply(%midb = %midt : !stencil.temp) -> (!stencil.temp) { // CHECK-NEXT: %v = stencil.access %midb[-1] : !stencil.temp // CHECK-NEXT: stencil.return %v : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %outt to %out (<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> +// CHECK-NEXT: stencil.store %outt to %out(<[0], [68]>) : !stencil.temp to !stencil.field<[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -103,7 +103,7 @@ builtin.module{ %11 = arith.addf %9, %10 : f64 stencil.return %11 : f64 } - stencil.store %7 to %1 (<[1, 2], [65, 66]>) : !stencil.temp<[1,65]x[2,66]xf64> to !stencil.field<[-3,67]x[-3,67]xf64> + stencil.store %7 to %1(<[1, 2], [65, 66]>) : !stencil.temp<[1,65]x[2,66]xf64> to !stencil.field<[-3,67]x[-3,67]xf64> func.return } @@ -125,7 
+125,7 @@ builtin.module{ // CHECK-NEXT: %10 = arith.addf %8, %9 : f64 // CHECK-NEXT: stencil.return %10 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %6 to %1 (<[1, 2], [65, 66]>) : !stencil.temp<[1,65]x[2,66]xf64> to !stencil.field<[-3,67]x[-3,67]xf64> +// CHECK-NEXT: stencil.store %6 to %1(<[1, 2], [65, 66]>) : !stencil.temp<[1,65]x[2,66]xf64> to !stencil.field<[-3,67]x[-3,67]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir b/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir index 4913677e12..f08e6e1e16 100644 --- a/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir +++ b/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir @@ -55,7 +55,7 @@ builtin.module { // CHECK-NEXT: %24 = arith.mulf %23, %5 : tensor<510xf32> // CHECK-NEXT: stencil.return %24 : tensor<510xf32> // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %3 to %2 (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %3 to %2(<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -107,9 +107,9 @@ builtin.module { // CHECK-NEXT: %22 = arith.mulf %21, %3 : tensor<510xf32> // CHECK-NEXT: stencil.return %22 : tensor<510xf32> // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %1 to %b (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } } -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } diff --git a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir 
b/tests/filecheck/transforms/stencil-to-csl-stencil.mlir index 179217f805..b438945507 100644 --- a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/stencil-to-csl-stencil.mlir @@ -28,7 +28,7 @@ builtin.module { %23 = arith.mulf %20, %22 : tensor<510xf32> stencil.return %23 : tensor<510xf32> } - stencil.store %1 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return } } @@ -62,7 +62,7 @@ builtin.module { // CHECK-NEXT: %25 = arith.mulf %22, %24 : tensor<510xf32> // CHECK-NEXT: csl_stencil.yield %25 : tensor<510xf32> // CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %2 to %b (<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/stencil-unroll.mlir b/tests/filecheck/transforms/stencil-unroll.mlir index b6b2cc1dd3..26fac4a1b8 100644 --- a/tests/filecheck/transforms/stencil-unroll.mlir +++ b/tests/filecheck/transforms/stencil-unroll.mlir @@ -7,7 +7,7 @@ %6 = arith.addf %4, %5 : f64 stencil.return %6 : f64 } - stencil.store %3 to %2 (<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + stencil.store %3 to %2(<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> func.return } @@ -32,7 +32,7 @@ // CHECK-NEXT: %20 = arith.addf %4, %19 : f64 // CHECK-NEXT: stencil.return %6, %8, %10, %12, %14, %16, %18, %20 unroll <[1, 8, 1]> : f64, f64, f64, f64, f64, f64, f64, f64 // CHECK-NEXT: 
} -// CHECK-NEXT: stencil.store %3 to %2 (<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: stencil.store %3 to %2(<[1, 2, 3], [65, 66, 63]>) : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -44,7 +44,7 @@ %12 = stencil.access %11[-1] : !stencil.temp<[-1,68]xf64> stencil.return %12 : f64 } - stencil.store %10 to %outc (<[0], [68]>) : !stencil.temp<[0,68]xf64> to !stencil.field<[0,1024]xf64> + stencil.store %10 to %outc(<[0], [68]>) : !stencil.temp<[0,68]xf64> to !stencil.field<[0,1024]xf64> func.return } @@ -56,7 +56,7 @@ // CHECK-NEXT: %5 = stencil.access %4[-1] : !stencil.temp<[-1,68]xf64> // CHECK-NEXT: stencil.return %5 : f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %3 to %outc (<[0], [68]>) : !stencil.temp<[0,68]xf64> to !stencil.field<[0,1024]xf64> +// CHECK-NEXT: stencil.store %3 to %outc(<[0], [68]>) : !stencil.temp<[0,68]xf64> to !stencil.field<[0,1024]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -79,7 +79,7 @@ %32 = arith.addf %31, %30 : f64 stencil.return %32, %31 : f64, f64 } - stencil.store %20 to %17 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> + stencil.store %20 to %17(<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } @@ -179,7 +179,7 @@ // CHECK-NEXT: %89 = arith.addf %88, %87 : f64 // CHECK-NEXT: stencil.return %19, %18, %29, %28, %39, %38, %49, %48, %59, %58, %69, %68, %79, %78, %89, %88 unroll <[1, 8, 1]> : f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %7 to %4 (<[0, 0, 0], [64, 64, 64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.store %7 to %4(<[0, 0, 0], [64, 64, 
64]>) : !stencil.temp<[0,64]x[0,64]x[0,64]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -195,7 +195,7 @@ func.func @dyn_access(%arg0 : !stencil.field, %arg1 : !stencil.field< %41 = stencil.store_result %40 : !stencil.result stencil.return %41 : !stencil.result } - stencil.store %36 to %34 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> + stencil.store %36 to %34(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> func.return } @@ -246,6 +246,6 @@ func.func @dyn_access(%arg0 : !stencil.field, %arg1 : !stencil.field< // CHECK-NEXT: %43 = stencil.store_result %42 : !stencil.result // CHECK-NEXT: stencil.return %8, %13, %18, %23, %28, %33, %38, %43 unroll <[1, 8, 1]> : !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result, !stencil.result // CHECK-NEXT: } -// CHECK-NEXT: stencil.store %3 to %1 (<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> +// CHECK-NEXT: stencil.store %3 to %1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[0,60]xf64> // CHECK-NEXT: func.return // CHECK-NEXT: } diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 5d23421d14..b8f2d3084b 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -455,6 +455,37 @@ class OptionalPropertyOp(IRDLOperation): check_equivalence(program, generic_program, ctx) +@pytest.mark.parametrize( + "program, generic_program", + [ + ( + "test.optional_property()", + '"test.optional_property"() : () -> ()', + ), + ( + "test.optional_property( prop i32 )", + '"test.optional_property"() <{"prop" = i32}> : () -> ()', + ), + ], +) +def 
test_optional_property_with_whitespace(program: str, generic_program: str): + """Test the parsing of optional operands""" + + @irdl_op_definition + class OptionalPropertyOp(IRDLOperation): + name = "test.optional_property" + prop = opt_prop_def(Attribute) + + assembly_format = "`(` (` ` `prop` $prop^ ` `)? `)` attr-dict" + + ctx = MLContext() + ctx.load_op(OptionalPropertyOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + + @pytest.mark.parametrize( "program, generic_program", [ @@ -603,6 +634,7 @@ class TypedAttributeOp(IRDLOperation): "test.punctuation keyword, keyword", ), ("`keyword` ` ` `,` `keyword` attr-dict", "test.punctuation keyword , keyword"), + ("`keyword` `,` `` `keyword` attr-dict", "test.punctuation keyword,keyword"), ( "`keyword` `\\n` `,` `keyword` attr-dict", "test.punctuation keyword\n, keyword", @@ -1480,7 +1512,7 @@ class OptionalGroupOp(IRDLOperation): @pytest.mark.parametrize( "format, error", ( - ("()?", "An optional group cannot be empty"), + ("()?", "An optional group must have a non-whitespace directive"), ("(`keyword`)?", "Every optional group must have an anchor."), ( "($args^ type($rets)^)?", diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index c7f70d6db7..8259a5b1b3 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -822,13 +822,13 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No class WhitespaceDirective(FormatDirective): """ A whitespace directive, with the following format: - whitespace-directive ::= `\n` | ` ` + whitespace-directive ::= `\n` | ` ` | `` This directive is only applied during printing, and has no effect during parsing. The directive will not request any space to be printed after. 
""" - whitespace: Literal[" ", "\n"] + whitespace: Literal[" ", "\n", ""] """The whitespace that should be printed.""" def parse(self, parser: Parser, state: ParsingState) -> None: @@ -836,7 +836,7 @@ def parse(self, parser: Parser, state: ParsingState) -> None: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: printer.print(self.whitespace) - state.last_was_punctuation = False + state.last_was_punctuation = self.whitespace == "" state.should_emit_space = False @@ -908,6 +908,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No @dataclass(frozen=True) class OptionalGroupDirective(FormatDirective): anchor: AnchorableDirective + then_whitespace: tuple[WhitespaceDirective, ...] then_first: OptionallyParsableDirective then_elements: tuple[FormatDirective, ...] @@ -944,5 +945,9 @@ def parse(self, parser: Parser, state: ParsingState) -> None: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if self.anchor.is_present(op): - for element in (self.then_first, *self.then_elements): + for element in ( + *self.then_whitespace, + self.then_first, + *self.then_elements, + ): element.print(printer, state, op) diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 79f7b13997..53aabf17d7 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -487,12 +487,21 @@ def parse_optional_group(self) -> FormatDirective: anchor = then_elements[-1] self.parse_punctuation("?") - if not then_elements: - self.raise_error("An optional group cannot be empty") + # Pull whitespace element of front, as they are not parsed + first_non_whitespace_index = None + for i, x in enumerate(then_elements): + if not isinstance(x, WhitespaceDirective): + first_non_whitespace_index = i + break + + if first_non_whitespace_index is None: + self.raise_error("An optional group must have a 
non-whitespace directive") if anchor is None: self.raise_error("Every optional group must have an anchor.") # TODO: allow attribute and region variables when implemented. - if not isinstance(then_elements[0], OptionallyParsableDirective): + if not isinstance( + then_elements[first_non_whitespace_index], OptionallyParsableDirective + ): self.raise_error( "First element of an optional group must be optionally parsable." ) @@ -501,15 +510,25 @@ def parse_optional_group(self) -> FormatDirective: "An optional group's anchor must be an achorable directive." ) - return OptionalGroupDirective(anchor, then_elements[0], then_elements[1:]) + return OptionalGroupDirective( + anchor, + cast( + tuple[WhitespaceDirective, ...], + then_elements[:first_non_whitespace_index], + ), + cast( + OptionallyParsableDirective, then_elements[first_non_whitespace_index] + ), + then_elements[first_non_whitespace_index + 1 :], + ) def parse_keyword_or_punctuation(self) -> FormatDirective: """ Parse a keyword or a punctuation directive, with the following format: keyword-or-punctuation-directive ::= `\\`` (bare-ident | punctuation) `\\`` """ - self.parse_characters("`") start_token = self._current_token + self.parse_characters("`") # New line case if self.parse_optional_keyword("\\"): @@ -518,16 +537,16 @@ def parse_keyword_or_punctuation(self) -> FormatDirective: return WhitespaceDirective("\n") # Space case + end_token = self._current_token if self.parse_optional_characters("`"): - end_token = self._current_token whitespace = self.lexer.input.content[ start_token.span.end : end_token.span.start ] - if whitespace != " ": + if whitespace != " " and whitespace != "": self.raise_error( - "unexpected whitespace in directive, only ` ` whitespace is allowed" + "unexpected whitespace in directive, only ` ` or `` whitespace is allowed" ) - return WhitespaceDirective(" ") + return WhitespaceDirective(whitespace) # Punctuation case if self._current_token.kind.is_punctuation(): From 
44230382f9c5e58b5fe07546a46e26dc982d4bd4 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 16 Aug 2024 15:12:19 +0100 Subject: [PATCH 2/5] misc: Fix typo achorable -> anchorable (#3048) Annoyed me enough while looking at other stuff to fix it --- tests/irdl/test_declarative_assembly_format.py | 5 ++++- xdsl/irdl/declarative_assembly_format_parser.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index b8f2d3084b..c3a7949494 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -1518,7 +1518,10 @@ class OptionalGroupOp(IRDLOperation): "($args^ type($rets)^)?", "An optional group can only have one anchor.", ), - ("(`keyword`^)?", "An optional group's anchor must be an achorable directive."), + ( + "(`keyword`^)?", + "An optional group's anchor must be an anchorable directive.", + ), ( "($mandatory_arg^)?", "First element of an optional group must be optionally parsable.", diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 53aabf17d7..87fc3064cf 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -507,7 +507,7 @@ def parse_optional_group(self) -> FormatDirective: ) if not isinstance(anchor, AnchorableDirective): self.raise_error( - "An optional group's anchor must be an achorable directive." + "An optional group's anchor must be an anchorable directive." ) return OptionalGroupDirective( From efdc7cfc01ca0e5006a852020b32603fcb03f3a2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 17 Aug 2024 10:35:36 +0100 Subject: [PATCH 3/5] pip prod(deps): bump ruff from 0.5.7 to 0.6.0 (#3043) Bumps [ruff](https://github.com/astral-sh/ruff) from 0.5.7 to 0.6.0.
Release notes

Sourced from ruff's releases.

0.6.0

Release Notes

Check out the blog post for a migration guide and overview of the changes!

Breaking changes

See also, the "Remapped rules" section which may result in disabled rules.

  • Lint and format Jupyter Notebook by default (#12878).
  • Detect imports in src layouts by default for isort rules (#12848)
  • The pytest rules PT001 and PT023 now default to omitting the decorator parentheses when there are no arguments (#12838).

Deprecations

The following rules are now deprecated:

Remapped rules

The following rules have been remapped to new rule codes:

Stabilization

The following rules have been stabilized and are no longer in preview:

The following behaviors have been stabilized:

... (truncated)

Changelog

Sourced from ruff's changelog.

0.6.0

Check out the blog post for a migration guide and overview of the changes!

Breaking changes

See also, the "Remapped rules" section which may result in disabled rules.

  • Lint and format Jupyter Notebook by default (#12878).
  • Detect imports in src layouts by default for isort rules (#12848)
  • The pytest rules PT001 and PT023 now default to omitting the decorator parentheses when there are no arguments (#12838).

Deprecations

The following rules are now deprecated:

Remapped rules

The following rules have been remapped to new rule codes:

Stabilization

The following rules have been stabilized and are no longer in preview:

The following behaviors have been stabilized:

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.5.7&new-version=0.6.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b3011bc609..9d399d9bc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dev = [ "lit<19.0.0", "marimo==0.7.20", "pre-commit==3.8.0", - "ruff==0.5.7", + "ruff==0.6.0", "asv<0.7", "nbconvert>=7.7.2,<8.0.0", "textual-dev==1.5.1", From 7473b8763b8408443b254e8f5969c758d4d3e611 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 17 Aug 2024 10:35:53 +0100 Subject: [PATCH 4/5] dependencies: update ruff-pre-commit to v0.6.0 (#3039) bleep bloop --- .pre-commit-config.yaml | 2 +- tests/interactive/test_app.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a304b925eb..8079e00950 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: hooks: - id: check-yaml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.7 + rev: v0.6.0 hooks: - id: ruff types_or: [ python, pyi, jupyter ] diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index eca2aafadf..c0995abe67 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -28,7 +28,7 @@ from xdsl.utils.parse_pipeline import PipelinePassSpec, parse_pipeline -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_inputs(): """Test different inputs produce desired result.""" async with InputApp().run_test() as pilot: @@ -94,7 +94,7 @@ async def test_inputs(): assert app.current_module.is_structurally_equivalent(expected_module) -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_buttons(): """Test pressing keys has the desired result.""" async with InputApp().run_test() as pilot: @@ -275,7 +275,7 @@ async def test_buttons(): assert app.condense_mode is False -@pytest.mark.asyncio() 
+@pytest.mark.asyncio async def test_rewrites(): """Test rewrite application has the desired result.""" async with InputApp().run_test() as pilot: @@ -342,7 +342,7 @@ async def test_rewrites(): ) -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_passes(): """Test pass application has the desired result.""" async with InputApp().run_test() as pilot: @@ -434,7 +434,7 @@ async def test_passes(): assert app.current_module.is_structurally_equivalent(expected_module) -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_argument_pass_screen(): """Test that clicking on a pass that requires passes opens a screen to specify them.""" async with InputApp().run_test() as pilot: From b01afed9567a09a9a7000ca439a1e4121a3bcc26 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Sat, 17 Aug 2024 12:31:34 +0100 Subject: [PATCH 5/5] transformations: stencil bufferization polish (#3001) Step back from the idea of allocating all apply results and then folding them away; instead, just fold stores into apply. It's cheaper, simpler, and so far works just as well! Also rely on the new side-effect value information to improve side-effect analysis. Together, these changes now make it possible to bufferize a full Open Earth Compiler example! Minor test changes: the bufferization used to blindly allocate buffers for unused results; now it scraps them. I guess it's debatable if that should be its role? Note that it does so by reusing the canonicalization pattern, so no logic is actually duplicated.
--------- Co-authored-by: Sasha Lopoukhine --- .../stencil/oec-kernels/fvtp2d_qi.mlir | 115 +++++++++ .../transforms/stencil-bufferize.mlir | 75 ++---- xdsl/dialects/stencil.py | 43 +++- .../canonicalization_patterns/stencil.py | 10 +- xdsl/transforms/stencil_bufferize.py | 223 +++++++++--------- 5 files changed, 300 insertions(+), 166 deletions(-) diff --git a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir index b94016cd75..7e2b79a268 100644 --- a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir +++ b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir @@ -1,6 +1,7 @@ // RUN: XDSL_ROUNDTRIP // RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference | filecheck %s --check-prefix SHAPE // RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference,convert-stencil-to-ll-mlir | filecheck %s --check-prefix MLIR +// RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference,stencil-bufferize | filecheck %s --check-prefix BUFF func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field, %arg2: !stencil.field, %arg3: !stencil.field, %arg4: !stencil.field, %arg5: !stencil.field, %arg6: !stencil.field) attributes {stencil.program} { %0 = stencil.cast %arg0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> @@ -547,3 +548,117 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field> // MLIR-NEXT: func.return // MLIR-NEXT: } + +// BUFF: func.func @fvtp2d_qi(%arg0 : !stencil.field, %arg1 : !stencil.field, %arg2 : !stencil.field, %arg3 : !stencil.field, %arg4 : !stencil.field, %arg5 : !stencil.field, %arg6 : !stencil.field) attributes {"stencil.program"}{ +// BUFF-NEXT: %0 = stencil.alloc : !stencil.field<[0,64]x[0,65]x[0,64]xf64> +// BUFF-NEXT: %1 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %2 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %3 = 
stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %4 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %5 = stencil.alloc : !stencil.field<[0,64]x[-1,66]x[0,64]xf64> +// BUFF-NEXT: %6 = stencil.cast %arg0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %7 = stencil.cast %arg1 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %8 = stencil.cast %arg2 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %9 = stencil.cast %arg3 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %10 = stencil.cast %arg4 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %11 = stencil.cast %arg5 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %12 = stencil.cast %arg6 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%5 : !stencil.field<[0,64]x[-1,66]x[0,64]xf64>) { +// BUFF-NEXT: %cst = arith.constant 1.000000e+00 : f64 +// BUFF-NEXT: %cst_1 = arith.constant 7.000000e+00 : f64 +// BUFF-NEXT: %cst_2 = arith.constant 1.200000e+01 : f64 +// BUFF-NEXT: %13 = arith.divf %cst_1, %cst_2 : f64 +// BUFF-NEXT: %14 = arith.divf %cst, %cst_2 : f64 +// BUFF-NEXT: %15 = stencil.access %arg7[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %16 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %17 = arith.addf %15, %16 : f64 +// BUFF-NEXT: %18 = stencil.access %arg7[0, -2, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %19 = stencil.access %arg7[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %20 = arith.addf %18, %19 : f64 +// BUFF-NEXT: %21 = arith.mulf %13, %17 : f64 +// BUFF-NEXT: %22 = arith.mulf %14, %20 : f64 +// BUFF-NEXT: %23 = arith.addf %21, %22 : f64 +// BUFF-NEXT: 
%24 = stencil.store_result %23 : !stencil.result +// BUFF-NEXT: stencil.return %24 : !stencil.result +// BUFF-NEXT: } to <[0, -1, 0], [64, 66, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %5 : !stencil.field<[0,64]x[-1,66]x[0,64]xf64>) outs (%4 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %3 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %2 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %1 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>) { +// BUFF-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// BUFF-NEXT: %cst_1 = arith.constant 1.000000e+00 : f64 +// BUFF-NEXT: %13 = stencil.access %arg8[0, 0, 0] : !stencil.field<[0,64]x[-1,66]x[0,64]xf64> +// BUFF-NEXT: %14 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %15 = arith.subf %13, %14 : f64 +// BUFF-NEXT: %16 = stencil.access %arg8[0, 1, 0] : !stencil.field<[0,64]x[-1,66]x[0,64]xf64> +// BUFF-NEXT: %17 = arith.subf %16, %14 : f64 +// BUFF-NEXT: %18 = arith.addf %15, %17 : f64 +// BUFF-NEXT: %19 = arith.mulf %15, %17 : f64 +// BUFF-NEXT: %20 = arith.cmpf olt, %19, %cst : f64 +// BUFF-NEXT: %21 = arith.select %20, %cst_1, %cst : f64 +// BUFF-NEXT: %22 = stencil.store_result %15 : !stencil.result +// BUFF-NEXT: %23 = stencil.store_result %17 : !stencil.result +// BUFF-NEXT: %24 = stencil.store_result %18 : !stencil.result +// BUFF-NEXT: %25 = stencil.store_result %21 : !stencil.result +// BUFF-NEXT: stencil.return %22, %23, %24, %25 : !stencil.result, !stencil.result, !stencil.result, !stencil.result +// BUFF-NEXT: } to <[0, -1, 0], [64, 65, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %7 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg9 = %4 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %arg10 = %3 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %arg11 = %2 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %arg12 = %1 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>) outs (%12 : 
!stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// BUFF-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// BUFF-NEXT: %cst_1 = arith.constant 1.000000e+00 : f64 +// BUFF-NEXT: %13 = stencil.access %arg12[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %14 = arith.cmpf oeq, %13, %cst : f64 +// BUFF-NEXT: %15 = arith.select %14, %cst_1, %cst : f64 +// BUFF-NEXT: %16 = stencil.access %arg12[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %17 = arith.mulf %16, %15 : f64 +// BUFF-NEXT: %18 = arith.addf %13, %17 : f64 +// BUFF-NEXT: %19 = stencil.access %arg8[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %20 = arith.cmpf ogt, %19, %cst : f64 +// BUFF-NEXT: %21 = "scf.if"(%20) ({ +// BUFF-NEXT: %22 = stencil.access %arg10[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %23 = stencil.access %arg11[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %24 = arith.mulf %19, %23 : f64 +// BUFF-NEXT: %25 = arith.subf %22, %24 : f64 +// BUFF-NEXT: %26 = arith.subf %cst_1, %19 : f64 +// BUFF-NEXT: %27 = arith.mulf %26, %25 : f64 +// BUFF-NEXT: scf.yield %27 : f64 +// BUFF-NEXT: }, { +// BUFF-NEXT: %28 = stencil.access %arg9[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %29 = stencil.access %arg11[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %30 = arith.mulf %19, %29 : f64 +// BUFF-NEXT: %31 = arith.addf %28, %30 : f64 +// BUFF-NEXT: %32 = arith.addf %cst_1, %19 : f64 +// BUFF-NEXT: %33 = arith.mulf %32, %31 : f64 +// BUFF-NEXT: scf.yield %33 : f64 +// BUFF-NEXT: }) : (i1) -> f64 +// BUFF-NEXT: %34 = arith.mulf %21, %18 : f64 +// BUFF-NEXT: %35 = "scf.if"(%20) ({ +// BUFF-NEXT: %36 = stencil.access %arg7[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %37 = arith.addf %36, %34 : f64 +// BUFF-NEXT: scf.yield %37 : f64 +// BUFF-NEXT: }, { +// BUFF-NEXT: %38 = stencil.access %arg7[0, 0, 0] : 
!stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %39 = arith.addf %38, %34 : f64 +// BUFF-NEXT: scf.yield %39 : f64 +// BUFF-NEXT: }) : (i1) -> f64 +// BUFF-NEXT: %40 = stencil.store_result %35 : !stencil.result +// BUFF-NEXT: stencil.return %40 : !stencil.result +// BUFF-NEXT: } to <[0, 0, 0], [64, 65, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %9 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %12 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%0 : !stencil.field<[0,64]x[0,65]x[0,64]xf64>) { +// BUFF-NEXT: %13 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %14 = stencil.access %arg8[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %15 = arith.mulf %13, %14 : f64 +// BUFF-NEXT: %16 = stencil.store_result %15 : !stencil.result +// BUFF-NEXT: stencil.return %16 : !stencil.result +// BUFF-NEXT: } to <[0, 0, 0], [64, 65, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %10 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg9 = %0 : !stencil.field<[0,64]x[0,65]x[0,64]xf64>, %arg10 = %8 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%11 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// BUFF-NEXT: %13 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %14 = stencil.access %arg8[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %15 = arith.mulf %13, %14 : f64 +// BUFF-NEXT: %16 = stencil.access %arg9[0, 0, 0] : !stencil.field<[0,64]x[0,65]x[0,64]xf64> +// BUFF-NEXT: %17 = stencil.access %arg9[0, 1, 0] : !stencil.field<[0,64]x[0,65]x[0,64]xf64> +// BUFF-NEXT: %18 = arith.subf %16, %17 : f64 +// BUFF-NEXT: %19 = arith.addf %15, %18 : f64 +// BUFF-NEXT: %20 = stencil.access %arg10[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %21 = arith.divf %19, %20 : f64 +// BUFF-NEXT: %22 = stencil.store_result %21 : !stencil.result +// BUFF-NEXT: 
stencil.return %22 : !stencil.result +// BUFF-NEXT: } to <[0, 0, 0], [64, 64, 64]> +// BUFF-NEXT: func.return +// BUFF-NEXT: } diff --git a/tests/filecheck/transforms/stencil-bufferize.mlir b/tests/filecheck/transforms/stencil-bufferize.mlir index 9c193299af..4a7287ba03 100644 --- a/tests/filecheck/transforms/stencil-bufferize.mlir +++ b/tests/filecheck/transforms/stencil-bufferize.mlir @@ -73,43 +73,43 @@ func.func @copy_1d(%0 : !stencil.field, %out : !stencil.field) { // CHECK-NEXT: func.return // CHECK-NEXT: } -func.func @copy_2d(%0 : !stencil.field) { +func.func @copy_2d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]xf64>) { %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]xf64> %2 = stencil.load %1 : !stencil.field<[-4,68]x[-4,68]xf64> -> !stencil.temp<[-1,64]x[0,68]xf64> %3 = stencil.apply(%4 = %2 : !stencil.temp<[-1,64]x[0,68]xf64>) -> (!stencil.temp<[0,64]x[0,68]xf64>) { %5 = stencil.access %4[-1, 0] : !stencil.temp<[-1,64]x[0,68]xf64> stencil.return %5 : f64 } + stencil.store %3 to %out (<[0, 0], [64, 68]>) : !stencil.temp<[0,64]x[0,68]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> func.return } -// CHECK: func.func @copy_2d(%0 : !stencil.field) { +// CHECK: func.func @copy_2d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]xf64>) { // CHECK-NEXT: %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %2 = stencil.alloc : !stencil.field<[0,64]x[0,68]xf64> -// CHECK-NEXT: stencil.apply(%3 = %1 : !stencil.field<[-4,68]x[-4,68]xf64>) outs (%2 : !stencil.field<[0,64]x[0,68]xf64>) { -// CHECK-NEXT: %4 = stencil.access %3[-1, 0] : !stencil.field<[-4,68]x[-4,68]xf64> -// CHECK-NEXT: stencil.return %4 : f64 +// CHECK-NEXT: stencil.apply(%2 = %1 : !stencil.field<[-4,68]x[-4,68]xf64>) outs (%out : !stencil.field<[-4,68]x[-4,68]xf64>) { +// CHECK-NEXT: %3 = stencil.access %2[-1, 0] : !stencil.field<[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.return %3 : f64 // CHECK-NEXT: } to <[0, 0], 
[64, 68]> // CHECK-NEXT: func.return // CHECK-NEXT: } -func.func @copy_3d(%0 : !stencil.field) { +func.func @copy_3d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> %2 = stencil.load %1 : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> -> !stencil.temp<[-1,64]x[0,64]x[0,69]xf64> %3 = stencil.apply(%4 = %2 : !stencil.temp<[-1,64]x[0,64]x[0,69]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,68]xf64>) { %5 = stencil.access %4[-1, 0, 1] : !stencil.temp<[-1,64]x[0,64]x[0,69]xf64> stencil.return %5 : f64 } + stencil.store %3 to %out (<[0, 0, 0], [64, 64, 68]>) : !stencil.temp<[0,64]x[0,64]x[0,68]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } -// CHECK: func.func @copy_3d(%0 : !stencil.field) { +// CHECK: func.func @copy_3d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { // CHECK-NEXT: %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> -// CHECK-NEXT: %2 = stencil.alloc : !stencil.field<[0,64]x[0,64]x[0,68]xf64> -// CHECK-NEXT: stencil.apply(%3 = %1 : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64>) outs (%2 : !stencil.field<[0,64]x[0,64]x[0,68]xf64>) { -// CHECK-NEXT: %4 = stencil.access %3[-1, 0, 1] : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> -// CHECK-NEXT: stencil.return %4 : f64 +// CHECK-NEXT: stencil.apply(%2 = %1 : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64>) outs (%out : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// CHECK-NEXT: %3 = stencil.access %2[-1, 0, 1] : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> +// CHECK-NEXT: stencil.return %3 : f64 // CHECK-NEXT: } to <[0, 0, 0], [64, 64, 68]> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -140,20 +140,19 @@ func.func @offsets(%0 : !stencil.field, %1 : !stencil.field, %1 : !stencil.field, %2 : !stencil.field) { // CHECK-NEXT: %3 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: %4 = 
stencil.cast %1 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %5 = stencil.alloc : !stencil.field<[0,64]x[0,64]x[0,64]xf64> -// CHECK-NEXT: stencil.apply(%6 = %3 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%4 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %5 : !stencil.field<[0,64]x[0,64]x[0,64]xf64>) { -// CHECK-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %9 = stencil.access %6[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %10 = stencil.access %6[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %11 = stencil.access %6[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %12 = arith.addf %7, %8 : f64 -// CHECK-NEXT: %13 = arith.addf %9, %10 : f64 -// CHECK-NEXT: %14 = arith.addf %12, %13 : f64 +// CHECK-NEXT: stencil.apply(%5 = %3 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%4 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// CHECK-NEXT: %6 = stencil.access %5[-1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %7 = stencil.access %5[1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %8 = stencil.access %5[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %9 = stencil.access %5[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %10 = stencil.access %5[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %11 = arith.addf %6, %7 : f64 +// CHECK-NEXT: %12 = arith.addf %8, %9 : f64 +// CHECK-NEXT: %13 = arith.addf %11, %12 : f64 // CHECK-NEXT: %cst = arith.constant -4.000000e+00 : f64 -// CHECK-NEXT: %15 = arith.mulf %11, %cst : f64 -// CHECK-NEXT: %16 = arith.addf %15, %14 : f64 -// CHECK-NEXT: stencil.return %16, %15 : f64, f64 +// CHECK-NEXT: %14 = arith.mulf %10, %cst : f64 +// CHECK-NEXT: %15 = 
arith.addf %14, %13 : f64 +// CHECK-NEXT: stencil.return %15 : f64 // CHECK-NEXT: } to <[0, 0, 0], [64, 64, 64]> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -342,31 +341,9 @@ func.func @stencil_init_index_offset(%0 : !stencil.field<[0,64]x[0,64]x[0,64]xin // CHECK-NEXT: func.return // CHECK-NEXT: } -func.func @store_result_lowering(%arg0 : f64) { - %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { - %2 = stencil.store_result %arg1 : !stencil.result - %3 = stencil.store_result %arg1 : !stencil.result - stencil.return %2, %3 : !stencil.result, !stencil.result - } - %2 = stencil.buffer %1 : !stencil.temp<[0,7]x[0,7]x[0,7]xf64> -> !stencil.temp<[0,7]x[0,7]x[0,7]xf64> - %3 = stencil.buffer %0 : !stencil.temp<[0,7]x[0,7]x[0,7]xf64> -> !stencil.temp<[0,7]x[0,7]x[0,7]xf64> - func.return -} - -// CHECK: func.func @store_result_lowering(%arg0 : f64) { -// CHECK-NEXT: %0 = stencil.alloc : !stencil.field<[0,7]x[0,7]x[0,7]xf64> -// CHECK-NEXT: %1 = stencil.alloc : !stencil.field<[0,7]x[0,7]x[0,7]xf64> -// CHECK-NEXT: stencil.apply(%arg1 = %arg0 : f64) outs (%0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) { -// CHECK-NEXT: %2 = stencil.store_result %arg1 : !stencil.result -// CHECK-NEXT: %3 = stencil.store_result %arg1 : !stencil.result -// CHECK-NEXT: stencil.return %2, %3 : !stencil.result, !stencil.result -// CHECK-NEXT: } to <[0, 0, 0], [7, 7, 7]> -// CHECK-NEXT: func.return -// CHECK-NEXT: } - func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { - %true = "test.op"() : () -> i1 + %true = "test.pureop"() : () -> i1 %2, %3 = "scf.if"(%true) ({ %4 = stencil.store_result %arg1 : !stencil.result scf.yield %4, %arg1 : !stencil.result, 
f64 @@ -384,7 +361,7 @@ func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, // CHECK: func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ // CHECK-NEXT: stencil.apply(%arg1 = %arg0 : f64) outs (%b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) { -// CHECK-NEXT: %true = "test.op"() : () -> i1 +// CHECK-NEXT: %true = "test.pureop"() : () -> i1 // CHECK-NEXT: %0, %1 = "scf.if"(%true) ({ // CHECK-NEXT: %2 = stencil.store_result %arg1 : !stencil.result // CHECK-NEXT: scf.yield %2, %arg1 : !stencil.result, f64 diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index e70b5120f0..ce57a85c77 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -66,7 +66,6 @@ IsTerminator, MemoryEffect, MemoryEffectKind, - MemoryReadEffect, NoMemoryEffect, Pure, RecursiveMemoryEffect, @@ -426,10 +425,11 @@ class ApplyMemoryEffect(RecursiveMemoryEffect): def get_effects(cls, op: Operation): effects = super().get_effects(op) if effects is not None: - if len(cast(ApplyOp, op).dest) > 0: - effects.add(EffectInstance(MemoryEffectKind.WRITE)) - if any(isinstance(o.type, FieldType) for o in op.operands): - effects.add(EffectInstance(MemoryEffectKind.READ)) + for d in cast(ApplyOp, op).dest: + effects.add(EffectInstance(MemoryEffectKind.WRITE, d)) + for o in cast(ApplyOp, op).args: + if isinstance(o.type, FieldType): + effects.add(EffectInstance(MemoryEffectKind.READ, o)) return effects @@ -561,17 +561,20 @@ def parse_operand(): def get( args: Sequence[SSAValue] | Sequence[Operation], body: Block | Region, - result_types: Sequence[TempType[Attribute]], + result_types: Sequence[TempType[Attribute]] = (), + bounds: StencilBoundsAttr | None = None, ): - assert len(result_types) > 0 - + assert result_types or bounds if isinstance(body, Block): body = Region(body) + properties = {"bounds": bounds} if 
bounds else {} + return ApplyOp.build( operands=[list(args), []], regions=[body], result_types=[result_types], + properties=properties, ) def verify_(self) -> None: @@ -646,6 +649,14 @@ def get_accesses(self) -> Iterable[AccessPattern]: accesses.append(offsets) yield AccessPattern(tuple(accesses)) + def get_bounds(self): + if self.bounds is not None: + return self.bounds + else: + assert self.res + res_type = cast(TempType[Attribute], self.res[0].type) + return res_type.bounds + class AllocOpEffect(MemoryEffect): @classmethod @@ -1132,6 +1143,12 @@ def get_apply(self): return cast(ApplyOp, ancestor) +class LoadOpMemoryEffect(MemoryEffect): + @classmethod + def get_effects(cls, op: Operation): + return {EffectInstance(MemoryEffectKind.READ, cast(LoadOp, op).field)} + + @irdl_op_definition class LoadOp(IRDLOperation): """ @@ -1170,7 +1187,7 @@ class LoadOp(IRDLOperation): assembly_format = "$field attr-dict-with-keyword `:` type($field) `->` type($res)" - traits = frozenset([MemoryReadEffect()]) + traits = frozenset([LoadOpMemoryEffect()]) @staticmethod def get( @@ -1294,6 +1311,12 @@ def verify( super().verify(attr, constraint_context) +class StoreOpMemoryEffect(MemoryEffect): + @classmethod + def get_effects(cls, op: Operation): + return {EffectInstance(MemoryEffectKind.WRITE, cast(StoreOp, op).field)} + + @irdl_op_definition class StoreOp(IRDLOperation): """ @@ -1346,6 +1369,8 @@ class StoreOp(IRDLOperation): assembly_format = "$temp `to` $field `` `(` $bounds `)` attr-dict-with-keyword `:` type($temp) `to` type($field)" + traits = frozenset([StoreOpMemoryEffect()]) + @staticmethod def get( temp: SSAValue | Operation, diff --git a/xdsl/transforms/canonicalization_patterns/stencil.py b/xdsl/transforms/canonicalization_patterns/stencil.py index b3bed37913..8cee81a291 100644 --- a/xdsl/transforms/canonicalization_patterns/stencil.py +++ b/xdsl/transforms/canonicalization_patterns/stencil.py @@ -1,7 +1,7 @@ from typing import cast from xdsl.dialects import stencil 
-from xdsl.ir import Attribute, Block, SSAValue +from xdsl.ir import Attribute, Block, Region, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, @@ -95,8 +95,12 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N results.pop(i) return_args.pop(i) - new = stencil.ApplyOp.get( - op.args, block, [cast(stencil.TempType[Attribute], r.type) for r in results] + new = stencil.ApplyOp.build( + operands=[op.args, op.dest], + regions=[Region(block)], + result_types=[[cast(stencil.TempType[Attribute], r.type) for r in results]], + properties=op.properties.copy(), + attributes=op.attributes.copy(), ) replace_results: list[SSAValue | None] = list(new.res) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 72bd54db43..9b8fbf84f5 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -12,6 +12,7 @@ FieldType, IndexAttr, LoadOp, + ReturnOp, StencilBoundsAttr, StoreOp, TempType, @@ -32,7 +33,8 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint -from xdsl.traits import is_side_effect_free +from xdsl.traits import MemoryEffectKind, get_effects +from xdsl.transforms.canonicalization_patterns.stencil import ApplyUnusedResults from xdsl.transforms.dead_code_elimination import RemoveUnusedOperations from xdsl.utils.hints import isa @@ -43,15 +45,24 @@ def field_from_temp(temp: TempType[_TypeElement]) -> FieldType[_TypeElement]: return FieldType[_TypeElement].new(temp.parameters) +def might_effect( + operation: Operation, effects: set[MemoryEffectKind], value: SSAValue +) -> bool: + """ + Return True if the operation might have any of the given effects on the given value. + """ + op_effects = get_effects(operation) + return op_effects is None or any( + e.kind in effects and e.value in (None, value) for e in op_effects + ) + + class ApplyBufferizePattern(RewritePattern): """ Naive partial `stencil.apply` bufferization. 
- Just replace all operands with the field result of a stencil.buffer on them, meaning - "The buffer those value are allocated to"; and allocate buffers for every result, - loading them back after the apply, to keep types fine with users. - - Point is to fold as much as possible all the allocations and loads. + Just replace all temp arguments with the field result of a stencil.buffer on them, meaning + "The buffer those value are allocated to". Example: ```mlir @@ -62,26 +73,20 @@ class ApplyBufferizePattern(RewritePattern): yields: ```mlir %in_buf = stencil.buffer %in : !stencil.temp<[0,32]xf64> -> !stencil.field<[0,32]xf64> - %out_buf = stencil.alloc : !stencil.field<[0,32]>xf64 stencil.apply(%0 = %in_buf : !stencil.field<[0,32]>xf64) outs (%out_buf : !stencil.field<[0,32]>xf64) { // [...] } - %out = stencil.load %out_buf : !stencil.field<[0,32]>xf64 -> !stencil.temp<[0,32]>xf64 ``` """ @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): - if not op.res: + if all(not isinstance(o.type, TempType) for o in op.args): return - bounds = cast(TempType[Attribute], op.res[0].type).bounds + bounds = op.get_bounds() - dests = [ - AllocOp(result_types=[field_from_temp(cast(TempType[Attribute], r.type))]) - for r in op.res - ] - operands = [ + args = [ ( BufferOp.create( operands=[o], @@ -90,28 +95,18 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): if isa(o.type, TempType[Attribute]) else o ) - for o in op.operands - ] - - loads = [ - LoadOp(operands=[d], result_types=[r.type]) for d, r in zip(dests, op.res) + for o in op.args ] new = ApplyOp( - operands=[operands, dests], - regions=[Region(Block(arg_types=[SSAValue.get(a).type for a in operands]))], - result_types=[[]], + operands=[args, op.dest], + regions=[op.detach_region(0)], + result_types=[op.res.types], properties={"bounds": bounds}, ) - rewriter.inline_block( - op.region.block, - InsertPoint.at_start(new.region.block), - new.region.block.args, - ) 
rewriter.replace_matched_op( - [*(o for o in operands if isinstance(o, Operation)), *dests, new, *loads], - [SSAValue.get(l) for l in loads], + [*(o for o in args if isinstance(o, Operation)), new] ) @@ -182,9 +177,7 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter): effecting = [ o for o in walk_from_to(load, user) - if underlying in o.operands - and (not is_side_effect_free(o)) - and (o not in (load, op, user)) + if might_effect(o, {MemoryEffectKind.WRITE}, underlying) ] if effecting: return @@ -192,100 +185,119 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter): rewriter.replace_matched_op(new_ops=[], new_results=[underlying]) -class ApplyLoadStoreFoldPattern(RewritePattern): +class ApplyStoreFoldPattern(RewritePattern): """ - If an allocated field is only used by an apply to write its output and loaded - to be stored in a destination field, make the apply work on the destination directly. + Fold stores of applys result Example: ```mlir - %temp = stencil.alloc : !stencil.field<[0,32]> - stencil.apply() outs (%temp : !stencil.field<[0,32]>) { + %temp = stencil.apply() -> (!stencil.temp<[0,32]>) { // [...] } - // [... %temp, %dest not affected] - %loaded = stencil.load %temp : !stencil.field<[0,32]> -> !stencil.temp<[0,32]> - // [... %dest not affected] - stencil.store %loaded to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]> + // [... %dest not read] + stencil.store %temp to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]> ``` yields: ```mlir - // Will be simplified away by the canonicalizer - %temp = stencil.alloc : !stencil.field<[0,32]> - // Outputs on dest - stencil.apply() outs (%dest : !stencil.field<[0,32]>) { + // Outputs on dest directly + stencil.apply() outs (%dest : !stencil.field<[-2,34]>) { // [...] 
} - // Load same values from %dest instead for next operations - %loaded = stencil.load %dest : !stencil.field<[0,32]> -> !stencil.temp<[0,32]> ``` """ - @op_type_rewrite_pattern - def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): - temp = op.temp - - # We are looking for a loaded destination of an apply - if not isinstance(load := temp.owner, LoadOp): - return - - infield = load.field - - other_uses = [u for u in infield.uses if u.operation is not load] - - if len(other_uses) != 1: - return - - other_use = other_uses.pop() - - if not isinstance( - apply := other_use.operation, ApplyOp - ) or other_use.index < len(apply.args): - print(other_use) - print() - return - - # Get first occurence of the field, to walk from it - start = op.field.owner + @staticmethod + def is_dest_safe(apply: ApplyOp, store: StoreOp) -> bool: + # Check that the destination is not used between the apply and store. + dest = store.field + start = dest.owner if isinstance(start, Block): - if start is not op.parent: - return start = cast(Operation, start.first_op) effecting = [ o - for o in walk_from_to(start, op) - if infield in o.operands - and (not is_side_effect_free(o)) - and (o not in (load, apply)) + for o in walk_from_to(apply, store) + if might_effect(o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest) ] - if effecting: - print("effecting: ", effecting) - print(load) - return + return not effecting - new_operands = list(apply.operands) - new_operands[other_use.index] = op.field + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): + apply = op + for temp_index, stored in enumerate(op.res): + # We are looking for a result that is stored and foldable + stores = [ + use.operation + for use in stored.uses + if isinstance(use.operation, StoreOp) + and self.is_dest_safe(apply, use.operation) + ] + if not stores: + continue - new_apply = ApplyOp.create( - operands=new_operands, - result_types=[], - 
properties=apply.properties.copy(), - attributes=apply.attributes.copy(), - regions=[ - apply.detach_region(0), - ], - ) + bounds = apply.get_bounds() + if not isinstance(bounds, StencilBoundsAttr): + raise ValueError( + "Stencil shape inference must be ran before bufferization." + ) - new_load = LoadOp.create( - operands=[op.field], - result_types=load.result_types, - attributes=load.attributes.copy(), - properties=load.properties.copy(), - ) + new_apply = ApplyOp.build( + # We add new destinations for each store of the removed result + operands=[ + apply.args, + (*apply.dest, *(store.field for store in stores)), + ], + # We only remove the considered result + result_types=[ + [ + r.type + for r in apply.results[:temp_index] + + apply.results[temp_index + 1 :] + ] + ], + properties=apply.properties.copy() | {"bounds": bounds}, + attributes=apply.attributes.copy(), + # The block signature is the same + regions=[ + Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), + ], + ) + + # The body is the same + rewriter.inline_block( + apply.region.block, + InsertPoint.at_start(new_apply.region.block), + new_apply.region.block.args, + ) + + # We swap the return's operand order, to make sure the order still matches destinations + # after bufferization + old_return = new_apply.region.block.last_op + assert isinstance(old_return, ReturnOp) + uf = old_return.unroll_factor + new_return_args = list( + old_return.arg[: uf * temp_index] + + old_return.arg[uf * (temp_index + 1) :] + + old_return.arg[uf * temp_index : uf * (temp_index + 1)] * len(stores) + ) + new_return = ReturnOp.create( + operands=new_return_args, + properties=old_return.properties.copy(), + attributes=old_return.attributes.copy(), + ) + rewriter.replace_op(old_return, new_return) + + # Create a load of a destination, for any other user of the result + load = LoadOp.get(stores[0].field, bounds.lb, bounds.ub) - rewriter.replace_op(apply, new_apply) - rewriter.replace_op(load, new_load) - 
rewriter.erase_op(op) + rewriter.replace_matched_op( + [new_apply, load], + new_apply.results[:temp_index] + + (load.res,) + + new_apply.results[temp_index:], + ) + for store in stores: + rewriter.erase_op(store) + return @dataclass(frozen=True) @@ -505,8 +517,9 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: BufferAlloc(), CombineStoreFold(), LoadBufferFoldPattern(), - ApplyLoadStoreFoldPattern(), + ApplyStoreFoldPattern(), RemoveUnusedOperations(), + ApplyUnusedResults(), ] ) )