diff --git a/tests/filecheck/transforms/set-memory-layout.mlir b/tests/filecheck/transforms/set-memory-layout.mlir index 3c0216fc..dff07501 100644 --- a/tests/filecheck/transforms/set-memory-layout.mlir +++ b/tests/filecheck/transforms/set-memory-layout.mlir @@ -1,19 +1,19 @@ // RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-layout --print-op-generic | filecheck %s // RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-layout{gemm_layout=banked} --print-op-generic | filecheck %s --check-prefix=BANKED -func.func @gemm(%arg0 : memref<16x16xi8, "L1">, %arg1 : memref<16x16xi8, "L1">, %arg2 : memref<16x16xi32, "L1">) -> () { - %0 = arith.constant 0 : i32 - %1 = arith.constant 0 : i32 - "dart.operation"(%arg0, %arg1, %arg2) <{"accelerator" = "snax_gemmx", "operandSegmentSizes" = array, "patterns" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]}> ({ - ^0(%arg3 : !dart.stream, %arg4 : !dart.stream, %arg5 : !dart.stream): - %5 = "dart.generic"(%arg3, %arg4, %0, %1) <{"library_call" = "snax_gemmx"}> ({ - ^1(%arg6 : i8, %arg7 : i8, %arg8 : i32, %arg9 : i32, %arg10 : i32): - %6 = kernel.qmac %arg6, %arg7 zp_lhs : %arg8 zp_rhs : %arg9 : i8, i8, i32, i32 -> i32 - dart.yield %6 : i32 - }) : (!dart.stream, !dart.stream, i32, i32) -> !dart.stream - dart.yield %5 : !dart.stream - }) : (memref<16x16xi8, "L1">, memref<16x16xi8, "L1">, memref<16x16xi32, "L1">) -> () - func.return +func.func @gemm(%arg0 : memref<16x16xi8, "L1">, %arg1 : memref<16x16xi8, "L1">, %arg2 : memref<16x16xi32, "L1">) { + %0 = arith.constant 0 : i32 + %1 = arith.constant 0 : i32 + "dart.schedule"(%arg0, %arg1, %arg2) <{patterns = [affine_map<(d0, d1, d2, d3, d4, d5) -> (((d0 * 8) + d3), ((d2 * 8) + d5))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d2 * 8) + d5), ((d1 * 8) + d4))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d0 * 8) + d3), ((d1 * 8) + d4))>], accelerator = "snax_gemmx", tiles = [[]], bounds = [2 : index, 2 : index, 2 : index, 8 : index, 8 : index, 8 : index], operandSegmentSizes = array}> ({ + ^0(%arg3 : !dart.stream, %arg4 : !dart.stream, %arg5 : !dart.stream): + %2 = "dart.generic"(%arg3, %arg4, %0, %1) <{library_call = "snax_gemmx"}> ({ + ^1(%arg6 : i8, %arg7 : i8, %arg8 : i32, %arg9 : i32, %arg10 : i32): + %3 = kernel.qmac %arg6, %arg7 zp_lhs : %arg8 zp_rhs : %arg9 : i8, i8, i32, i32 -> i32 + dart.yield %3 : i32 + }) : (!dart.stream, !dart.stream, i32, i32) -> !dart.stream + dart.yield %2 : !dart.stream + }) : (memref<16x16xi8, "L1">, memref<16x16xi8, "L1">, memref<16x16xi32, "L1">) -> () + func.return } // CHECK: %2 = "snax.layout_cast"(%arg0) : (memref<16x16xi8, "L1">) -> memref<16x16xi8, #tsl.tsl<[2, 8] -> (128, 8), [2, 8] -> (64, 1)>, "L1">