diff --git a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir index 662b86b8ab..24cc1c68c0 100644 --- a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir +++ b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir @@ -4,7 +4,7 @@ builtin.module { func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (tensor<4x510xf32>) + %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "num_chunks" = 2 : i64}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (tensor<4x510xf32>) %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 = %pref : tensor<4x510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { %4 = arith.constant 1.666600e-01 : f32 %5 = csl_stencil.access %3[1, 0] : tensor<4x510xf32> @@ -33,7 +33,7 @@ builtin.module { // CHECK: builtin.module { // CHECK-NEXT: func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32> +// CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "num_chunks" = 2 : i64}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32> // CHECK-NEXT: %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 = %pref : tensor<4x510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { // CHECK-NEXT: %4 = arith.constant 1.666600e-01 : f32 // CHECK-NEXT: %5 = csl_stencil.access %3[1, 0] : tensor<4x510xf32> @@ -63,7 +63,7 @@ builtin.module { // CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel_func", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({ // CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>): // CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-GENERIC-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32> +// CHECK-GENERIC-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "num_chunks" = 2 : i64}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<4x510xf32> // CHECK-GENERIC-NEXT: %1 = "stencil.apply"(%0, %pref) <{"operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^1(%2 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 : tensor<4x510xf32>): // CHECK-GENERIC-NEXT: %4 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32 diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index f6164c2488..9b5012e3ef 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -3,217 +3,6 @@ builtin.module { // CHECK-NEXT: builtin.module { - func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { - %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %24 = "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) - %1 = stencil.apply(%2 = %24 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { - %3 = arith.constant 1.666600e-01 : f32 - %4 = stencil.access %2[1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %5 = "tensor.extract_slice"(%4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %6 = stencil.access %2[-1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %7 = "tensor.extract_slice"(%6) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %8 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %10 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %12 = stencil.access %2[0, 1] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %13 = "tensor.extract_slice"(%12) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %14 = stencil.access %2[0, -1] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %16 = arith.addf %15, %13 : tensor<510xf32> - %17 = arith.addf %16, %11 : tensor<510xf32> - %18 = arith.addf %17, %9 : tensor<510xf32> - %19 = arith.addf %18, %7 : tensor<510xf32> - %20 = arith.addf %19, %5 : tensor<510xf32> - %21 = tensor.empty() : tensor<510xf32> - %22 = linalg.fill ins(%3 : f32) outs(%21 : tensor<510xf32>) -> tensor<510xf32> - %23 = arith.mulf %20, %22 : tensor<510xf32> - stencil.return %23 : tensor<510xf32> - } - stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - func.return - } - -// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { -// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>): -// CHECK-NEXT: %6 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: %7 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %8 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %9 = csl_stencil.access %3[0, 1] : tensor<4x255xf32> -// CHECK-NEXT: %10 = csl_stencil.access %3[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %11 = arith.addf %10, %9 : tensor<255xf32> -// CHECK-NEXT: %12 = arith.addf %11, %8 : tensor<255xf32> -// CHECK-NEXT: %13 = arith.addf %12, %7 : tensor<255xf32> -// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %14 : tensor<510xf32> -// CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%15 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %16 : tensor<510xf32>): -// CHECK-NEXT: %17 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %20 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %22 = arith.addf %16, %21 : tensor<510xf32> -// CHECK-NEXT: %23 = arith.addf %22, %19 : tensor<510xf32> -// CHECK-NEXT: %24 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %25 = linalg.fill ins(%17 : f32) outs(%24 : tensor<510xf32>) -> tensor<510xf32> -// CHECK-NEXT: %26 = arith.mulf %23, %25 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %26 : tensor<510xf32> -// CHECK-NEXT: }) -// CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: func.return -// CHECK-NEXT: } - - - func.func @bufferized(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { - "dmp.swap"(%a) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> () - %0 = stencil.apply(%1 = %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { - %2 = arith.constant dense<1.666600e-01> : tensor<510xf32> - %3 = stencil.access %1[1, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - %4 = "tensor.extract_slice"(%3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %5 = stencil.access %1[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - %6 = "tensor.extract_slice"(%5) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %7 = stencil.access %1[0, -1] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %9 = arith.addf %8, %6 : tensor<510xf32> - %10 = arith.addf %9, %4 : tensor<510xf32> - %11 = arith.mulf %10, %2 : tensor<510xf32> - stencil.return %11 : tensor<510xf32> - } to <[0, 0], [1, 1]> - stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - func.return - } - -// CHECK-NEXT: func.func @bufferized(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { -// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): -// CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32> -// CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %7 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %8 = arith.addf %7, %6 : tensor<255xf32> -// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> -// CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%10 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %11 : tensor<510xf32>): -// CHECK-NEXT: %12 = arith.constant dense<1.666600e-01> : tensor<510xf32> -// CHECK-NEXT: %13 = csl_stencil.access %10[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %15 = arith.addf %11, %14 : tensor<510xf32> -// CHECK-NEXT: %16 = arith.mulf %15, %12 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32> -// CHECK-NEXT: }) to <[0, 0], [1, 1]> -// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: func.return -// CHECK-NEXT: } - - func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { - "dmp.swap"(%a) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> () - %0 = stencil.apply(%1 = %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { - %2 = arith.constant dense<1.234500e-01> : tensor<510xf32> - %3 = arith.constant dense<2.345678e-01> : tensor<510xf32> - %4 = arith.constant dense<3.141500e-01> : tensor<510xf32> - %5 = stencil.access %1[1, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - %6 = "tensor.extract_slice"(%5) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %7 = stencil.access %1[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %9 = stencil.access %1[0, -1] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - %10 = "tensor.extract_slice"(%9) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> - %11 = arith.mulf %6, %3 : tensor<510xf32> - %12 = arith.mulf %10, %4 : tensor<510xf32> - %13 = arith.addf %12, %8 : tensor<510xf32> - %14 = arith.addf %13, %11 : tensor<510xf32> - %15 = arith.mulf %14, %2 : tensor<510xf32> - stencil.return %15 : tensor<510xf32> - } to <[0, 0], [1, 1]> - stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> - func.return - } - -// CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { -// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({ -// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): -// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32> -// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32> -// CHECK-NEXT: %7 = arith.constant dense<3.141500e-01> : tensor<510xf32> -// CHECK-NEXT: %8 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %9 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32> -// CHECK-NEXT: %11 = "tensor.insert_slice"(%10, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32> -// CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%12 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %13 : tensor<510xf32>): -// CHECK-NEXT: %14 = arith.constant dense<1.234500e-01> : tensor<510xf32> -// CHECK-NEXT: %15 = csl_stencil.access %12[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %17 = arith.addf %13, %16 : tensor<510xf32> -// CHECK-NEXT: %18 = arith.mulf %17, %14 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %18 : tensor<510xf32> -// CHECK-NEXT: }) to <[0, 0], [1, 1]> -// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: func.return -// CHECK-NEXT: } - - func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { - %0 = arith.constant 41 : index - %1 = arith.constant 0 : index - %2 = arith.constant 1 : index - %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { - "dmp.swap"(%arg3) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<600x600>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) -> () - stencil.apply(%arg5 = %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { - %5 = arith.constant dense<1.287158e+09> : tensor<600xf32> - %6 = arith.constant dense<1.196003e+05> : tensor<600xf32> - %7 = stencil.access %arg5[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %9 = arith.mulf %8, %5 : tensor<600xf32> - %10 = stencil.access %arg5[-1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %12 = arith.mulf %11, %6 : tensor<600xf32> - %13 = stencil.access %arg5[1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %15 = arith.mulf %14, %6 : tensor<600xf32> - %16 = arith.addf %12, %9 : tensor<600xf32> - %17 = arith.addf %16, %15 : tensor<600xf32> - stencil.return %17 : tensor<600xf32> - } to <[0, 0], [1, 1]> - scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - } - func.return -} - -// CHECK-NEXT: func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { -// CHECK-NEXT: %0 = arith.constant 41 : index -// CHECK-NEXT: %1 = arith.constant 0 : index -// CHECK-NEXT: %2 = arith.constant 1 : index -// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { -// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> -// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({ -// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): -// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32> -// CHECK-NEXT: %11 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32> -// CHECK-NEXT: %12 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> -// CHECK-NEXT: %13 = arith.addf %11, %12 : tensor<300xf32> -// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %14 : tensor<600xf32> -// CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%15 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %16 : tensor<600xf32>): -// CHECK-NEXT: %17 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> -// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %20 = arith.mulf %19, %17 : tensor<600xf32> -// CHECK-NEXT: %21 = arith.addf %16, %20 : tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %21 : tensor<600xf32> -// CHECK-NEXT: }) to <[0, 0], [1, 1]> -// CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> -// CHECK-NEXT: } -// CHECK-NEXT: func.return -// CHECK-NEXT: } - func.func @uvbke(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) { "dmp.swap"(%arg0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<64x64>, false>, "swaps" = [#dmp.exchange]} : (!stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) -> () "dmp.swap"(%arg1) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<64x64>, false>, "swaps" = [#dmp.exchange]} : (!stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) -> () @@ -235,7 +24,16 @@ builtin.module { } // CHECK-NEXT: func.func @uvbke(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) { -// CHECK-NEXT: %0 = "csl_stencil.prefetch"(%arg1) <{"topo" = #dmp.topo<64x64>, "swaps" = [#csl_stencil.exchange]}> : (!stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) -> tensor<1x64xf32> +// CHECK-NEXT: %0 = tensor.empty() : tensor<1x64xf32> +// CHECK-NEXT: csl_stencil.apply(%arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) -> () <{"swaps" = [#csl_stencil.exchange], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^0(%1 : tensor<1x32xf32>, %2 : index, %3 : tensor<1x64xf32>): +// CHECK-NEXT: %4 = csl_stencil.access %3[-1, 0] : tensor<1x64xf32> +// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32> +// CHECK-NEXT: csl_stencil.yield %5 : tensor<1x64xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%6 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %7 : tensor<1x64xf32>): +// CHECK-NEXT: csl_stencil.yield %7 : tensor<1x64xf32> +// CHECK-NEXT: }) // CHECK-NEXT: %1 = tensor.empty() : tensor<64xf32> // CHECK-NEXT: csl_stencil.apply(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %1 : tensor<64xf32>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) <{"swaps" = [#csl_stencil.exchange], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%2 : tensor<1x32xf32>, %3 : index, %4 : tensor<64xf32>): diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index bd1168d94f..a7eabf9750 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -140,12 +140,15 @@ class PrefetchOp(IRDLOperation): topo = prop_def(dmp.RankTopoAttr) + num_chunks = prop_def(AnyIntegerAttr) + result = result_def(AnyMemRefTypeConstr | AnyTensorTypeConstr) def __init__( self, input_stencil: SSAValue | Operation, topo: dmp.RankTopoAttr, + num_chunks: AnyIntegerAttr, swaps: Sequence[ExchangeDeclarationAttr], result_type: memref.MemRefType[Attribute] | TensorType[Attribute] | None = None, ): @@ -154,6 +157,7 @@ def __init__( properties={ "topo": topo, "swaps": builtin.ArrayAttr(swaps), + "num_chunks": num_chunks, }, result_types=[result_type], ) @@ -368,7 +372,7 @@ def verify_(self) -> None: self.accumulator.type.get_element_type(), ( len(self.swaps), - self.accumulator.type.get_shape()[0] // self.num_chunks.value.data, + self.accumulator.type.get_shape()[-1] // self.num_chunks.value.data, ), ), IndexType(), @@ -393,16 +397,18 @@ def verify_(self) -> None: f"Unexpected block argument type of done_exchange, got {arg.type} != {expected_type} at index {arg.index}" ) - if (len(self.res) == 0) == (len(self.dest) == 0): + if (len(self.res) > 0) and (len(self.dest) > 0): raise VerifyException( - "Expected stencil.apply to have either results or dest specified" + "Cannot specify both results and dest on stencil.apply" ) def get_rank(self) -> int: if self.dest: res_type = self.dest[0].type - else: + elif self.res: res_type = self.res[0].type + else: + return 2 if isattr(res_type, stencil.StencilTypeConstr): return res_type.get_num_dims() elif self.bounds: diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 531bceb932..76981530d4 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -2,8 +2,9 @@ from dataclasses import dataclass from math import prod +from xdsl.builder import ImplicitBuilder from xdsl.context import MLContext -from xdsl.dialects import arith, stencil, tensor, varith +from xdsl.dialects import arith, builtin, memref, stencil, tensor, varith from xdsl.dialects.builtin import ( AnyFloatAttr, AnyMemRefTypeConstr, @@ -145,12 +146,14 @@ def match_and_rewrite(self, op: stencil.AccessOp, rewriter: PatternRewriter, /): rewriter.replace_op(use, [], new_results=[new_access_op.result]) -@dataclass(frozen=True) +@dataclass class ConvertSwapToPrefetchPattern(RewritePattern): """ Translates dmp.swap to csl_stencil.prefetch """ + num_chunks: int = 1 + @op_type_rewrite_pattern def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): # remove op if it contains no swaps @@ -187,6 +190,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): prefetch_op = csl_stencil.PrefetchOp( input_stencil=op.input_stencil, topo=op.strategy.comm_layout(), + num_chunks=IntegerAttr(self.num_chunks, 64), swaps=[ csl_stencil.ExchangeDeclarationAttr(swap.neighbor[:2]) for swap in op.swaps @@ -577,6 +581,63 @@ def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter, rewriter.replace_op(mulf, [], new_results=[op.result]) +class TransformPrefetch(RewritePattern): + """ + Rewrites a prefetch into a communicate-only csl_stencil.apply + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: csl_stencil.PrefetchOp, rewriter: PatternRewriter, / + ): + a_buf = tensor.EmptyOp((), op.result.type) + # because we are building a set of offsets, we are not retaining offset mappings + offsets = [swap.neighbor for swap in op.swaps] + + assert isa(op.result.type, AnyTensorType) + chunk_buf_t = TensorType( + op.result.type.get_element_type(), + ( + len(op.swaps), + op.result.type.get_shape()[1] // op.num_chunks.value.data, + ), + ) + chunk_t = TensorType(chunk_buf_t.element_type, chunk_buf_t.get_shape()[1:]) + + block = Block(arg_types=[chunk_buf_t, builtin.IndexType(), op.result.type]) + block2 = Block(arg_types=[op.input_stencil.type, op.result.type]) + block2.add_op(csl_stencil.YieldOp(block2.args[1])) + + with ImplicitBuilder(block) as (_, offset, acc): + dest = acc + for i, acc_offset in enumerate(offsets): + ac_op = csl_stencil.AccessOp( + dest, stencil.IndexAttr.get(*acc_offset), chunk_t + ) + assert isa(ac_op.result.type, AnyTensorType) + dest = tensor.InsertSliceOp.get( + source=ac_op.result, + dest=dest, + static_sizes=ac_op.result.type.get_shape(), + static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX], + offsets=[offset], + ).result + csl_stencil.YieldOp(dest) + + apply_op = csl_stencil.ApplyOp( + operands=[op.input_stencil, a_buf, [], [], []], + regions=[Region(block), Region(block2)], + properties={ + "swaps": op.swaps, + "topo": op.topo, + "num_chunks": op.num_chunks, + }, + result_types=[[]], + ) + + rewriter.replace_matched_op([a_buf, apply_op], new_results=[a_buf.tensor]) + + @dataclass(frozen=True) class ConvertStencilToCslStencilPass(ModulePass): name = "convert-stencil-to-csl-stencil" @@ -589,7 +650,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: module_pass = PatternRewriteWalker( GreedyRewritePatternApplier( [ - ConvertSwapToPrefetchPattern(), + ConvertSwapToPrefetchPattern(num_chunks=self.num_chunks), ConvertAccessOpPattern(), ] ), @@ -602,9 +663,11 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: [ ConvertApplyOpPattern(num_chunks=self.num_chunks), PromoteCoefficients(), + TransformPrefetch(), ] ), apply_recursively=False, + walk_reverse=True, ).rewrite_module(op) ConvertVarithToArithPass().apply(ctx, op)