diff --git a/compiler/accelerators/snax_alu.py b/compiler/accelerators/snax_alu.py index a2859136..ba706379 100644 --- a/compiler/accelerators/snax_alu.py +++ b/compiler/accelerators/snax_alu.py @@ -195,6 +195,6 @@ def generate_acc_op(self) -> accfg.AcceleratorOp: @staticmethod def get_template(op: dart.StreamingRegionOpBase): - template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3 - template_bounds = (None, 4) + template = [AffineMap.from_callable(lambda y: (y,))] * 3 + template_bounds = (4,) return Template(TemplatePattern(template_bounds, tp) for tp in template) diff --git a/compiler/accelerators/snax_gemmx.py b/compiler/accelerators/snax_gemmx.py index e0bf9411..4b64133e 100644 --- a/compiler/accelerators/snax_gemmx.py +++ b/compiler/accelerators/snax_gemmx.py @@ -275,13 +275,13 @@ def get_template(op: dart.StreamingRegionOpBase) -> Template: assert isinstance(generic_op := op.body.block.first_op, dart.GenericOp) if isinstance(generic_op.body.block.first_op, kernel.QMacOp): # matmul - M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6)) + m, n, k = (AffineDimExpr(i) for i in range(3)) template = [ - AffineMap(6, 0, (M * 8 + m, K * 8 + k)), - AffineMap(6, 0, (K * 8 + k, N * 8 + n)), - AffineMap(6, 0, (M * 8 + m, N * 8 + n)), + AffineMap(3, 0, (m, k)), + AffineMap(3, 0, (k, n)), + AffineMap(3, 0, (m, n)), ] - template_bounds = (None, None, None, 8, 8, 8) + template_bounds = (8, 8, 8) if isinstance(generic_op.next_op, dart.GenericOp): generic_op = generic_op.next_op @@ -292,12 +292,12 @@ def get_template(op: dart.StreamingRegionOpBase) -> Template: raise RuntimeError("unsupported kernel") else: # rescale only function of gemmx - M, K, m, k = (AffineDimExpr(i) for i in range(4)) + m, k = (AffineDimExpr(i) for i in range(2)) template = [ - AffineMap(4, 0, (M * 8 + m, K * 8 + k)), - AffineMap(4, 0, (M * 8 + m, K * 8 + k)), + AffineMap(2, 0, (m, k)), + AffineMap(2, 0, (m, k)), ] - template_bounds = (None, None, 8, 8) + template_bounds = (8, 8) if not isinstance(generic_op.next_op, dart.YieldOp): raise RuntimeError("unsupported kernel") diff --git a/compiler/ir/dart/scheduler.py b/compiler/ir/dart/scheduler.py index c1fec37a..abd3b9f2 100644 --- a/compiler/ir/dart/scheduler.py +++ b/compiler/ir/dart/scheduler.py @@ -125,7 +125,10 @@ def is_pure_output_stationary(template: Template, schedule: Schedule): def scheduler( template: Template, schedule: Schedule, - extra_checks: Sequence[Callable[[Template, Schedule], bool]] = [], + extra_checks: Sequence[Callable[[Template, Schedule], bool]] = [ + # defaulting to pure output stationary schedules for now + is_pure_output_stationary + ], ) -> Schedule: # for now just return the first result of the backtracking result = next(scheduler_backtrack(template, schedule, extra_checks=extra_checks)) diff --git a/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir b/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir index 130da462..c86e3ad9 100644 --- a/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir +++ b/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir @@ -42,7 +42,7 @@ func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, st func.return } -// CHECK: "snax_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{stride_patterns = [#snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern], accelerator = "snax_gemmx", operandSegmentSizes = array}> ({ +// CHECK: "snax_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{stride_patterns = [#snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern], accelerator = "snax_gemmx", operandSegmentSizes = array}> ({ // CHECK-NEXT: ^0(%{{.*}} : !dart.stream, %{{.*}} : !dart.stream, %{{.*}} : !dart.stream): // CHECK-NEXT: %{{.*}} = "dart.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{library_call = "snax_gemmx"}> ({ // CHECK-NEXT: ^1(%{{.*}} : i8, %{{.*}} : i8, %{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32): diff --git a/tests/filecheck/transforms/dart/dart-scheduler.mlir b/tests/filecheck/transforms/dart/dart-scheduler.mlir index 656e28e0..e2dbdca7 100644 --- a/tests/filecheck/transforms/dart/dart-scheduler.mlir +++ b/tests/filecheck/transforms/dart/dart-scheduler.mlir @@ -14,7 +14,7 @@ func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, st func.return } -// CHECK: "dart.schedule"(%arg0, %arg1, %arg2) <{patterns = [affine_map<(d0, d1, d2, d3, d4, d5) -> (((d0 * 8) + d3), ((d2 * 8) + d5))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d2 * 8) + d5), ((d1 * 8) + d4))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d0 * 8) + d3), ((d1 * 8) + d4))>], accelerator = "snax_gemmx", tiles = [[]], bounds = [2 : index, 2 : index, 2 : index, 8 : index, 8 : index, 8 : index], operandSegmentSizes = array}> ({ +// CHECK: "dart.schedule"(%arg0, %arg1, %arg2) <{patterns = [affine_map<(d0, d1, d2, d3, d4, d5) -> (((d1 * 8) + d3), ((d2 * 8) + d5))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d2 * 8) + d5), ((d0 * 8) + d4))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d1 * 8) + d3), ((d0 * 8) + d4))>], accelerator = "snax_gemmx", tiles = [[]], bounds = [2 : index, 2 : index, 2 : index, 8 : index, 8 : index, 8 : index], operandSegmentSizes = array}> ( // CHECK-NEXT: ^0(%1 : !dart.stream, %2 : !dart.stream, %3 : !dart.stream): // CHECK-NEXT: %4 = "dart.generic"(%1, %2, %0, %0) <{library_call = "snax_gemmx"}> ({ // CHECK-NEXT: ^1(%arg3 : i8, %arg4 : i8, %arg5 : i32, %arg6 : i32, %arg7 : i32):