Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io committed Aug 19, 2024
1 parent f791473 commit 6e3604b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
47 changes: 45 additions & 2 deletions tests/filecheck/transforms/stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: xdsl-opt %s -p "stencil-to-csl-stencil{num_chunks=2}" | filecheck %s

builtin.module {
// CHECK-NEXT: builtin.module {

func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
Expand Down Expand Up @@ -32,9 +33,7 @@ builtin.module {
stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}
}

// CHECK-NEXT: builtin.module {
// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32>
Expand Down Expand Up @@ -66,4 +65,48 @@ builtin.module {
// CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }


func.func @bufferized(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
"dmp.swap"(%a) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
%0 = stencil.apply(%1 = %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
%2 = arith.constant dense<1.666600e-01> : tensor<510xf32>
%3 = stencil.access %1[1, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
%4 = "tensor.extract_slice"(%3) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%5 = stencil.access %1[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
%6 = "tensor.extract_slice"(%5) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%7 = stencil.access %1[0, -1] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
%8 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%9 = arith.addf %8, %6 : tensor<510xf32>
%10 = arith.addf %9, %4 : tensor<510xf32>
%11 = arith.mulf %10, %2 : tensor<510xf32>
stencil.return %11 : tensor<510xf32>
} to <[0, 0], [1, 1]>
stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}

// CHECK-NEXT: func.func @bufferized(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0>}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>):
// CHECK-NEXT: %5 = csl_stencil.access %2[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %6 = csl_stencil.access %2[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %7 = arith.addf %6, %5 : tensor<255xf32>
// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%9 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %10 : tensor<510xf32>):
// CHECK-NEXT: %11 = csl_stencil.access %9[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %12 = arith.constant dense<1.666600e-01> : tensor<510xf32>
// CHECK-NEXT: %13 = "tensor.extract_slice"(%11) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %14 = arith.addf %10, %13 : tensor<510xf32>
// CHECK-NEXT: %15 = arith.mulf %14, %12 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %15 : tensor<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }

}
// CHECK-NEXT: }
8 changes: 6 additions & 2 deletions xdsl/transforms/stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,12 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
),
)

# a little hack to get around a check that prevents replacing a no-results op with an n-results op
rewriter.replace_matched_op(prefetch_op, new_results=[op.input_stencil])
# if the rewriter needs a result, use `input_stencil` as a drop-in replacement
# prefetch_op produces a result that needs to be handled separately
# note, that only un-bufferized dmp.swaps produce a result
rewriter.replace_matched_op(
prefetch_op, new_results=[op.input_stencil] if op.swapped_values else []
)

# uses have to be retrieved *before* the loop because of the rewriting happening inside the loop
uses = list(op.input_stencil.uses)
Expand Down

0 comments on commit 6e3604b

Please sign in to comment.