Skip to content

Commit

Permalink
simplify accelerator templates (#359)
Browse files: browse the repository at this point in the history.
* simplify accelerator templates

* fix filecheck
  • Loading branch information
jorendumoulin authored Feb 5, 2025
1 parent 2a0442e commit cfa8283
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 14 deletions.
4 changes: 2 additions & 2 deletions compiler/accelerators/snax_alu.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,6 @@ def generate_acc_op(self) -> accfg.AcceleratorOp:

@staticmethod
def get_template(op: dart.StreamingRegionOpBase):
template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3
template_bounds = (None, 4)
template = [AffineMap.from_callable(lambda y: (y,))] * 3
template_bounds = (4,)
return Template(TemplatePattern(template_bounds, tp) for tp in template)
18 changes: 9 additions & 9 deletions compiler/accelerators/snax_gemmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ def get_template(op: dart.StreamingRegionOpBase) -> Template:
assert isinstance(generic_op := op.body.block.first_op, dart.GenericOp)
if isinstance(generic_op.body.block.first_op, kernel.QMacOp):
# matmul
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
m, n, k = (AffineDimExpr(i) for i in range(3))
template = [
AffineMap(6, 0, (M * 8 + m, K * 8 + k)),
AffineMap(6, 0, (K * 8 + k, N * 8 + n)),
AffineMap(6, 0, (M * 8 + m, N * 8 + n)),
AffineMap(3, 0, (m, k)),
AffineMap(3, 0, (k, n)),
AffineMap(3, 0, (m, n)),
]
template_bounds = (None, None, None, 8, 8, 8)
template_bounds = (8, 8, 8)

if isinstance(generic_op.next_op, dart.GenericOp):
generic_op = generic_op.next_op
Expand All @@ -292,12 +292,12 @@ def get_template(op: dart.StreamingRegionOpBase) -> Template:
raise RuntimeError("unsupported kernel")
else:
# rescale only function of gemmx
M, K, m, k = (AffineDimExpr(i) for i in range(4))
m, k = (AffineDimExpr(i) for i in range(2))
template = [
AffineMap(4, 0, (M * 8 + m, K * 8 + k)),
AffineMap(4, 0, (M * 8 + m, K * 8 + k)),
AffineMap(2, 0, (m, k)),
AffineMap(2, 0, (m, k)),
]
template_bounds = (None, None, 8, 8)
template_bounds = (8, 8)

if not isinstance(generic_op.next_op, dart.YieldOp):
raise RuntimeError("unsupported kernel")
Expand Down
5 changes: 4 additions & 1 deletion compiler/ir/dart/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ def is_pure_output_stationary(template: Template, schedule: Schedule):
def scheduler(
template: Template,
schedule: Schedule,
extra_checks: Sequence[Callable[[Template, Schedule], bool]] = [],
extra_checks: Sequence[Callable[[Template, Schedule], bool]] = [
# defaulting to pure output stationary schedules for now
is_pure_output_stationary
],
) -> Schedule:
# for now just return the first result of the backtracking
result = next(scheduler_backtrack(template, schedule, extra_checks=extra_checks))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, st
func.return
}

// CHECK: "snax_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{stride_patterns = [#snax_stream.stride_pattern<ub = [2, 2, 2], ts = [8, 0, 128], ss = [8]>, #snax_stream.stride_pattern<ub = [2, 2, 2], ts = [8, 128, 0], ss = [8]>, #snax_stream.stride_pattern<ub = [0, 0, 0], ts = [0, 0, 0], ss = [0]>, #snax_stream.stride_pattern<ub = [2, 2, 2], ts = [0, 0, 0], ss = [8, 64]>, #snax_stream.stride_pattern<ub = [2, 2, 2], ts = [0, 32, 512], ss = [8, 64]>], accelerator = "snax_gemmx", operandSegmentSizes = array<i32: 4, 1>}> ({
// CHECK: "snax_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{stride_patterns = [#snax_stream.stride_pattern<ub = [2, 2, 2], ts = [8, 128, 0], ss = [8]>, #snax_stream.stride_pattern<ub = [2, 2, 2], ts = [8, 0, 128], ss = [8]>, #snax_stream.stride_pattern<ub = [0, 0, 0], ts = [0, 0, 0], ss = [0]>, #snax_stream.stride_pattern<ub = [2, 2, 2], ts = [0, 0, 0], ss = [8, 64]>, #snax_stream.stride_pattern<ub = [2, 2, 2], ts = [0, 512, 32], ss = [8, 64]>], accelerator = "snax_gemmx", operandSegmentSizes = array<i32: 4, 1>}> ({
// CHECK-NEXT: ^0(%{{.*}} : !dart.stream<i8>, %{{.*}} : !dart.stream<i8>, %{{.*}} : !dart.stream<i32>):
// CHECK-NEXT: %{{.*}} = "dart.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{library_call = "snax_gemmx"}> ({
// CHECK-NEXT: ^1(%{{.*}} : i8, %{{.*}} : i8, %{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32):
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/transforms/dart/dart-scheduler.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, st
func.return
}

// CHECK: "dart.schedule"(%arg0, %arg1, %arg2) <{patterns = [affine_map<(d0, d1, d2, d3, d4, d5) -> (((d0 * 8) + d3), ((d2 * 8) + d5))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d2 * 8) + d5), ((d1 * 8) + d4))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d0 * 8) + d3), ((d1 * 8) + d4))>], accelerator = "snax_gemmx", tiles = [[]], bounds = [2 : index, 2 : index, 2 : index, 8 : index, 8 : index, 8 : index], operandSegmentSizes = array<i32: 2, 1>}> ({
// CHECK: "dart.schedule"(%arg0, %arg1, %arg2) <{patterns = [affine_map<(d0, d1, d2, d3, d4, d5) -> (((d1 * 8) + d3), ((d2 * 8) + d5))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d2 * 8) + d5), ((d0 * 8) + d4))>, affine_map<(d0, d1, d2, d3, d4, d5) -> (((d1 * 8) + d3), ((d0 * 8) + d4))>], accelerator = "snax_gemmx", tiles = [[]], bounds = [2 : index, 2 : index, 2 : index, 8 : index, 8 : index, 8 : index], operandSegmentSizes = array<i32: 2, 1>}> (
// CHECK-NEXT: ^0(%1 : !dart.stream<i8>, %2 : !dart.stream<i8>, %3 : !dart.stream<i32>):
// CHECK-NEXT: %4 = "dart.generic"(%1, %2, %0, %0) <{library_call = "snax_gemmx"}> ({
// CHECK-NEXT: ^1(%arg3 : i8, %arg4 : i8, %arg5 : i32, %arg6 : i32, %arg7 : i32):
Expand Down

0 comments on commit cfa8283

Please sign in to comment.