Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use auto-generated cyclic layouts #361

Merged
merged 7 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/dense_matmul/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ else
REMOVE_MEMREF_COPY=
endif

SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-dart,dart-fuse-operations,snax-bufferize,alloc-to-global,set-memory-space,dart-scheduler,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},dart-layout-resolution,convert-dart-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space
SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-dart,dart-fuse-operations,snax-bufferize,alloc-to-global,set-memory-space,dart-scheduler,set-memory-layout,realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},dart-layout-resolution,convert-dart-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space

GEN_DATA_OPTS += --m=${SIZE_M}
GEN_DATA_OPTS += --n=${SIZE_N}
Expand Down
6 changes: 2 additions & 4 deletions benchmarks/dense_matmul/genbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def output_log(output_report) -> str:
result = "# Dense Matmul Benchmark Results\n\n"
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
result += f"This test was run at {dt_string}\n\n"
for layout, add_c in itertools.product(("cyclic", "banked"), (True, False)):
for layout, add_c in itertools.product(("cyclic",), (True, False)):
result += f"Results for a {layout} layout {'with add C' if add_c else ''} \n\n"
result += "| benchmark | layout | add C | M | N | K | plots | cycles | ideal | utilization |\n"
result += "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n"
Expand Down Expand Up @@ -149,9 +149,7 @@ def output_log_benchmark(

output_report: dict[str, dict] = {}

for size, layout, add_c in itertools.product(
sizes, ("cyclic", "banked"), (True, False)
):
for size, layout, add_c in itertools.product(sizes, ("cyclic",), (True, False)):
m, n, k = size

# plot:
Expand Down
304 changes: 83 additions & 221 deletions compiler/transforms/set_memory_layout.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this diff is not useful: the entire file has been replaced by a new implementation.

Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from dataclasses import dataclass
from math import ceil, prod

import numpy as np
from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.ir import Attribute
from xdsl.parser import MemRefType
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand All @@ -11,265 +14,124 @@
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.str_enum import StrEnum
from xdsl.utils.hints import isa

from compiler.dialects import dart
from compiler.dialects.kernel import AddOp, QMacOp, RescaleOp
from compiler.dialects.snax import LayoutCast
from compiler.dialects.tsl import TiledStridedLayoutAttr
from compiler.ir.dart.access_pattern import Schedule, SchedulePattern
from compiler.ir.tsl import Stride, TiledStride, TiledStridedLayout


class AddMemoryLayoutSIMD(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: dart.ScheduleOp, rewriter: PatternRewriter):
# check if operation is dispatched via library call, as set by e.g.
# the dispatch-kernels pass

if op.accelerator is None:
return
else:
library_call = op.accelerator.data

# check for library call
if library_call == "snax_gemmx":
if not isinstance(op.body.block.first_op.body.block.first_op, RescaleOp):
return

shaped_operands: list[MemRefType] = [
operand.type
for operand in op.operands
if isinstance(operand.type, builtin.MemRefType)
]

m = shaped_operands[0].get_shape()[0]
n = shaped_operands[0].get_shape()[1]

if m == -1:
m = None
if n == -1:
n = None

tsl_input = TiledStridedLayoutAttr(
TiledStridedLayout(
[
TiledStride(
[
Stride(
256 * n // 8 if n else None, m // 8 if m else None
),
Stride(8, 8),
]
),
TiledStride([Stride(256, n // 8 if n else None), Stride(1, 8)]),
]
)
)

tsl_output = TiledStridedLayoutAttr(
TiledStridedLayout(
[
TiledStride(
[
Stride(
256 * n // 8 if n else None, m // 8 if m else None
),
Stride(8, 8),
]
),
TiledStride([Stride(256, n // 8 if n else None), Stride(1, 8)]),
]
)
)

# insert layout_cast ops
new_input_a = LayoutCast.from_type_and_target_layout(
op.inputs[0], tsl_input
)

new_output = LayoutCast.from_type_and_target_layout(
op.outputs[0], tsl_output
)

rewriter.insert_op([new_input_a, new_output], InsertPoint.before(op))

op.operands[0] = new_input_a.dest
op.operands[1] = new_output.dest


class GemmLayout(StrEnum):
cyclic = "cyclic"
banked = "banked"


@dataclass
class AddMemoryLayout(RewritePattern):
class AddCyclicMemoryLayout(RewritePattern):
"""
This class represents a rewrite pattern for adding memory layout to a
linalg operation. The implementation is very naive. It imposes a specific
memory layout on the input and output of the linalg operation dispatched
to snax_gemm by inserting layout_cast ops. In the future, the memory
layout will be selected in a more automatic way.
Automatically generates cyclic memory layouts for the operands of a schedule operation.
The layout is determined based on access patterns in the schedule.

Note: currently, only snax_gemm is supported.
- Dimensions accessed in the innermost loop receive the lowest stride.
- Other dimensions are assigned increasing strides in a contiguous manner.
- This approach can yield variations of formats such as NCHW, NHWC, etc., optimizing for efficient memory access.

Additionally, the pass supports tiled layouts, where data is first tiled before assigning contiguous strides.
This enables efficient memory layouts such as tiled block formats.
"""

gemm_layout: GemmLayout = GemmLayout.cyclic
# allow for tiled layouts?
tiled_layout: bool = False

@op_type_rewrite_pattern
def match_and_rewrite(self, op: dart.ScheduleOp, rewriter: PatternRewriter):
# check if operation is dispatched via library call, as set by e.g.
# the dispatch-kernels pass

if op.accelerator is None:
return
else:
library_call = op.accelerator.data
# do not alter pre-existing layouts
for operand in op.operands:
if isa(operand.type, MemRefType[Attribute]) and isinstance(
operand.type.layout, TiledStridedLayoutAttr
):
return

has_add_c = False
# get schedule from op
bounds = [x.value.data for x in op.bounds]
schedule = Schedule(SchedulePattern(bounds, x.data) for x in op.patterns)

# check for library call
if library_call == "snax_gemmx" or library_call == "snax_gemmx_stream":
# only do so for qmac kernels
generic_op = op.body.block.first_op
assert isinstance(generic_op, dart.GenericOp)
if not isinstance(generic_op.body.block.first_op, QMacOp):
return
# list to store newly generated operands for op
new_operands: list[LayoutCast] = []

if isinstance(generic_op.next_op, dart.GenericOp):
if isinstance(generic_op.next_op.body.block.first_op, AddOp):
# gemm
has_add_c = True
# a separate layout is determined for every operand
for operand, schedule in zip(op.operands, schedule):
assert isa(memref_type := operand.type, MemRefType[Attribute])

# the layout should be as static as the memref is. no more, no less
# get m, n, k
# start assigning contiguous, starting from stride = 1
current_stride = 1

shaped_operands: list[tuple[int, MemRefType]] = [
(index, op.type)
for index, op in enumerate(op.operands)
if isinstance(op.type, builtin.MemRefType)
# create a list to keep strides for every dimension.
# for non-tiled layouts, every dimension will be assigned 1 stride
# for tiled layouts, every dimension can be assigned multiple strides
strides: list[list[Stride]] = [
[] for _ in range(memref_type.get_num_dims())
]

m = shaped_operands[0][1].get_shape()[0]
n = shaped_operands[1][1].get_shape()[1]
k = shaped_operands[0][1].get_shape()[1]
# iterate over the columns of the schedule pattern in reversed order, to find out
# which dimension is accessed in the innermost loop of the operation
for schedule_bound, accesses in zip(
schedule.bounds[::-1], np.flip(schedule.pattern.A, axis=1).T
):
# normalize accesses to binary list
# this list will now have a 1 at the index of the dimension that is accessed
accesses = tuple(0 if x == 0 else 1 for x in accesses)

if m == -1:
m = None
if n == -1:
n = None
if k == -1:
k = None
if 1 not in accesses:
continue

# determine tile_stride = stride between two gemm tiles
match self.gemm_layout:
case GemmLayout.banked:
tile_stride = 256
case GemmLayout.cyclic:
tile_stride = 64
# find operand dimension that is accessed
accessed_dim = accesses.index(1)

tsl_input_a = TiledStridedLayoutAttr(
TiledStridedLayout(
[
TiledStride(
[
Stride(
tile_stride * k // 8 if k else None,
m // 8 if m else None,
),
Stride(8, 8),
]
),
TiledStride(
[Stride(tile_stride, k // 8 if k else None), Stride(1, 8)]
),
]
)
)
# we have determined the dimension and step for this layout
# now we must determine the bound.
# For non-tiled layouts, this bound will be equal to the operand shape.
# For tiled layouts, the bound is set equal to the schedule bound.

## tsl b has an offset of 64 to not collide with the banks of
### a (not yet - need aligned allocation for this)
tsl_input_b = TiledStridedLayoutAttr(
TiledStridedLayout(
[
TiledStride(
[Stride(tile_stride, k // 8 if k else None), Stride(1, 8)]
),
TiledStride(
[
Stride(
tile_stride * k // 8 if k else None,
n // 8 if n else None,
),
Stride(8, 8),
]
),
],
# offset=64,
)
)
# the existing bound of the current layout
existing_bound = prod(s.bound for s in strides[accessed_dim] if s.bound)

tsl_output = TiledStridedLayoutAttr(
TiledStridedLayout(
[
TiledStride(
[
Stride(
64 * n // 8 if n else None, m // 8 if m else None
),
Stride(8, 8),
]
),
TiledStride([Stride(64, n // 8 if n else None), Stride(1, 8)]),
]
# the remaining size of the operand dimension
size_remaining = ceil(
memref_type.get_shape()[accessed_dim] // existing_bound
)
)

# insert layout_cast ops
new_input_a = LayoutCast.from_type_and_target_layout(
op.inputs[0], tsl_input_a
)
# can we further tile the layout according to the remaining size?
if self.tiled_layout:
layout_bound = schedule_bound
else:
layout_bound = size_remaining

new_input_b = LayoutCast.from_type_and_target_layout(
op.inputs[1], tsl_input_b
)
# assign this current stride to the relevant operand dimension
strides[accesses.index(1)].insert(
0, Stride(current_stride, layout_bound)
)

new_output = LayoutCast.from_type_and_target_layout(
op.outputs[0], tsl_output
)
# increase current stride
current_stride = current_stride * layout_bound

rewriter.insert_op(
(new_input_a, new_input_b, new_output), InsertPoint.before(op)
)
layout = TiledStridedLayout(
[TiledStride(s) for s in strides]
).canonicalize()
tsl = TiledStridedLayoutAttr(layout)

if has_add_c:
rewriter.insert_op(
new_input_c := LayoutCast.from_type_and_target_layout(
op.inputs[2], tsl_output
),
InsertPoint.before(op),
)
op.operands[shaped_operands[0][0]] = new_input_a.dest
op.operands[shaped_operands[1][0]] = new_input_b.dest
op.operands[shaped_operands[2][0]] = new_input_c.dest
op.operands[shaped_operands[3][0]] = new_output.dest
else:
op.operands[shaped_operands[0][0]] = new_input_a.dest
op.operands[shaped_operands[1][0]] = new_input_b.dest
op.operands[shaped_operands[2][0]] = new_output.dest
new_operands.append(LayoutCast.from_type_and_target_layout(operand, tsl))

rewriter.insert_op(new_operands, InsertPoint.before(op))
for i, new_operand in enumerate(new_operands):
op.operands[i] = new_operand.dest


@dataclass(frozen=True)
class SetMemoryLayout(ModulePass):
name = "set-memory-layout"

gemm_layout: str = "cyclic"
tiled: bool | None = True

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
AddMemoryLayoutSIMD(), apply_recursively=False
).rewrite_module(op)
PatternRewriteWalker(
AddMemoryLayout(gemm_layout=GemmLayout(self.gemm_layout)),
apply_recursively=False,
).rewrite_module(op)
tiled = self.tiled if self.tiled is not None else True
PatternRewriteWalker(AddCyclicMemoryLayout(tiled_layout=tiled)).rewrite_module(
op
)
Loading