Skip to content

Commit

Permalink
transformations (csl): Add prefetch lowering (#3584)
Browse files Browse the repository at this point in the history
Add a prefetch lowering that rewrites `csl_stencil.prefetch` into a
communicate-only `csl_stencil.apply`.

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
emmau678 and n-io authored Dec 6, 2024
1 parent f05a23b commit 3874827
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 222 deletions.
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/csl/csl-stencil-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
builtin.module {
func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
- %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (tensor<4x510xf32>)
+ %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "num_chunks" = 2 : i64}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (tensor<4x510xf32>)
%1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 = %pref : tensor<4x510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
%4 = arith.constant 1.666600e-01 : f32
%5 = csl_stencil.access %3[1, 0] : tensor<4x510xf32>
Expand Down Expand Up @@ -33,7 +33,7 @@ builtin.module {
// CHECK: builtin.module {
// CHECK-NEXT: func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
- // CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32>
+ // CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "num_chunks" = 2 : i64}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32>
// CHECK-NEXT: %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 = %pref : tensor<4x510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
// CHECK-NEXT: %4 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %5 = csl_stencil.access %3[1, 0] : tensor<4x510xf32>
Expand Down Expand Up @@ -63,7 +63,7 @@ builtin.module {
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel_func", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({
// CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
- // CHECK-GENERIC-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32>
+ // CHECK-GENERIC-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "num_chunks" = 2 : i64}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32>
// CHECK-GENERIC-NEXT: %1 = "stencil.apply"(%0, %pref) <{"operandSegmentSizes" = array<i32: 2, 0>}> ({
// CHECK-GENERIC-NEXT: ^1(%2 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 : tensor<4x510xf32>):
// CHECK-GENERIC-NEXT: %4 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32
Expand Down
Loading

0 comments on commit 3874827

Please sign in to comment.