Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename stream/autoflow related stuff to dart (decoupled acceleration runtime tools) #344

Merged
merged 5 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/dense_matmul/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ else
REMOVE_MEMREF_COPY=
endif

SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},convert-stream-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space
SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-dart,dart-fuse-operations,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},convert-dart-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space

GEN_DATA_OPTS += --m=${SIZE_M}
GEN_DATA_OPTS += --n=${SIZE_N}
Expand Down
4 changes: 2 additions & 2 deletions compiler/accelerators/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from compiler.accelerators.streamers import StreamerConfiguration
from compiler.accelerators.streamers.streamers import StreamerFlag, StreamerOpts
from compiler.dialects import accfg
from compiler.dialects.dart import StreamingRegionOpBase
from compiler.dialects.snax_stream import StreamerConfigurationAttr, StreamingRegionOp
from compiler.dialects.stream import StreamingRegionOpBase
from compiler.ir.stream import Template
from compiler.ir.dart.access_pattern import Template

c0_attr = builtin.IntegerAttr(0, builtin.IndexType())

Expand Down
6 changes: 3 additions & 3 deletions compiler/accelerators/snax_alu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
StreamerConfiguration,
StreamerType,
)
from compiler.dialects import accfg, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.dialects import accfg, dart, snax_stream
from compiler.ir.dart.access_pattern import Template, TemplatePattern

default_streamer = StreamerConfiguration(
[
Expand Down Expand Up @@ -194,7 +194,7 @@ def generate_acc_op(self) -> accfg.AcceleratorOp:
return op

@staticmethod
def get_template(op: stream.StreamingRegionOpBase):
def get_template(op: dart.StreamingRegionOpBase):
template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3
template_bounds = (None, 4)
return Template(TemplatePattern(template_bounds, tp) for tp in template)
6 changes: 3 additions & 3 deletions compiler/accelerators/snax_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
StreamerConfiguration,
StreamerType,
)
from compiler.dialects import accfg, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.dialects import accfg, dart, snax_stream
from compiler.ir.dart.access_pattern import Template, TemplatePattern

default_streamer = StreamerConfiguration(
[
Expand Down Expand Up @@ -166,7 +166,7 @@ def lower_acc_await(acc_op: accfg.AcceleratorOp) -> Sequence[Operation]:
]

@staticmethod
def get_template(op: stream.StreamingRegionOpBase) -> Template:
def get_template(op: dart.StreamingRegionOpBase) -> Template:
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
template = [
AffineMap(6, 0, (M * 8 + m, K * 8 + k)),
Expand Down
14 changes: 7 additions & 7 deletions compiler/accelerators/snax_gemmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
StreamerType,
)
from compiler.accelerators.streamers.streamers import StreamerOpts
from compiler.dialects import accfg, kernel, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.dialects import accfg, dart, kernel, snax_stream
from compiler.ir.dart.access_pattern import Template, TemplatePattern
from compiler.util.pack_bitlist import pack_bitlist

default_streamer = StreamerConfiguration(
Expand Down Expand Up @@ -179,7 +179,7 @@ def _generate_setup_vals(

ops_to_add: list[Operation] = []

assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
assert isinstance(generic_op := op.body.block.first_op, dart.GenericOp)

if isinstance(qmac := generic_op.body.block.first_op, kernel.QMacOp):
# gemm
Expand Down Expand Up @@ -271,8 +271,8 @@ def _generate_setup_vals(
]

@staticmethod
def get_template(op: stream.StreamingRegionOpBase) -> Template:
assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
def get_template(op: dart.StreamingRegionOpBase) -> Template:
assert isinstance(generic_op := op.body.block.first_op, dart.GenericOp)
if isinstance(generic_op.body.block.first_op, kernel.QMacOp):
# matmul
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
Expand All @@ -283,7 +283,7 @@ def get_template(op: stream.StreamingRegionOpBase) -> Template:
]
template_bounds = (None, None, None, 8, 8, 8)

if isinstance(generic_op.next_op, stream.GenericOp):
if isinstance(generic_op.next_op, dart.GenericOp):
generic_op = generic_op.next_op
if isinstance(generic_op.body.block.first_op, kernel.AddOp):
# gemm, add c pattern that is equal to output pattern
Expand All @@ -299,7 +299,7 @@ def get_template(op: stream.StreamingRegionOpBase) -> Template:
]
template_bounds = (None, None, 8, 8)

if not isinstance(generic_op.next_op, stream.YieldOp):
if not isinstance(generic_op.next_op, dart.YieldOp):
raise RuntimeError("unsupported kernel")

return Template(TemplatePattern(template_bounds, tp) for tp in template)
23 changes: 12 additions & 11 deletions compiler/dialects/stream.py → compiler/dialects/dart.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
)

"""
Custom `stream` dialect, to simplify things in a more principled approach, including:
Custom `dart` dialect, heavily inspired by xDSL `stream` dialect, to simplify
things in a (hopefully) more principled approach, including:
- inherent support for tensors
- streams with value semantics
- no specified static bounds in access patterns: they are just affine maps
Expand All @@ -61,7 +62,7 @@ class StreamType(
Streams can only be read from, there is no distinction between readable/writeable streams.
"""

name = "stream.stream"
name = "dart.stream"

element_type: ParameterDef[_StreamTypeElement]

Expand Down Expand Up @@ -117,13 +118,13 @@ def __init__(


@irdl_op_definition
class StreamingRegionOp(StreamingRegionOpBase):
class OperationOp(StreamingRegionOpBase):
"""
A streaming region op that represents an unscheduled operation,
with streams mapping the iteration space to the operand indexing space.
"""

name = "stream.streaming_region"
name = "dart.operation"

def get_pattern_bounds_to_shapes_map(self) -> AffineMap:
"""
Expand Down Expand Up @@ -171,7 +172,7 @@ class ScheduleOp(StreamingRegionOpBase):
transformation took place.
"""

name = "stream.schedule"
name = "dart.schedule"

# The bounds of the iteration space of the schedule
bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])
Expand Down Expand Up @@ -219,7 +220,7 @@ class AccessPatternOp(StreamingRegionOpBase):
layout resolution, with streams mapping the iteration space to memory.
"""

name = "stream.access_pattern"
name = "dart.access_pattern"

# The bounds of the iteration space of the schedule
bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])
Expand Down Expand Up @@ -251,7 +252,7 @@ def __init__(

@irdl_op_definition
class YieldOp(AbstractYieldOperation[Attribute]):
name = "stream.yield"
name = "dart.yield"

traits = traits_def(IsTerminator())

Expand All @@ -263,7 +264,7 @@ class GenericOp(IRDLOperation):
Indexing maps / iterators are not relevant, so they are not included.
"""

name = "stream.generic"
name = "dart.generic"

# inputs can be streams or integers
inputs = var_operand_def()
Expand Down Expand Up @@ -294,10 +295,10 @@ def __init__(
)


Stream = Dialect(
"stream",
Dart = Dialect(
"dart",
[
StreamingRegionOp,
OperationOp,
ScheduleOp,
AccessPatternOp,
GenericOp,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self, TypeVar, overload
from xdsl.ir.affine import AffineDimExpr, AffineMap

from compiler.ir.autoflow.affine_transform import AffineTransform
from compiler.ir.dart.affine_transform import AffineTransform


@dataclass(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from compiler.ir.stream import Schedule, Template
from compiler.ir.dart.access_pattern import Schedule, Template


def scheduler(template: Template, schedule: Schedule) -> Schedule:
Expand Down
2 changes: 0 additions & 2 deletions compiler/ir/stream/__init__.py

This file was deleted.

18 changes: 10 additions & 8 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from xdsl.xdsl_opt_main import xDSLOptMain

from compiler.dialects.accfg import ACCFG
from compiler.dialects.dart import Dart
from compiler.dialects.kernel import Kernel
from compiler.dialects.snax import Snax
from compiler.dialects.snax_stream import SnaxStream
from compiler.dialects.stream import Stream
from compiler.dialects.test.debug import Debug
from compiler.dialects.tsl import TSL
from compiler.transforms.accfg_config_overlap import AccfgConfigOverlapPass
Expand All @@ -18,20 +18,20 @@
from compiler.transforms.backend.postprocess_mlir import PostprocessPass
from compiler.transforms.clear_memory_space import ClearMemorySpace
from compiler.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass
from compiler.transforms.convert_dart_to_snax_stream import ConvertDartToSnaxStream
from compiler.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg
from compiler.transforms.convert_linalg_to_accfg import (
ConvertLinalgToAccPass,
TraceStatesPass,
)
from compiler.transforms.convert_linalg_to_kernel import ConvertLinalgToKernel
from compiler.transforms.convert_linalg_to_stream import ConvertLinalgToStream
from compiler.transforms.convert_stream_to_snax_stream import ConvertStreamToSnaxStream
from compiler.transforms.convert_tosa_to_kernel import ConvertTosaToKernelPass
from compiler.transforms.dart.convert_linalg_to_dart import ConvertLinalgToDart
from compiler.transforms.dart.dart_fuse_operations import DartFuseOperationsPass
from compiler.transforms.dispatch_kernels import DispatchKernels
from compiler.transforms.dispatch_regions import DispatchRegions
from compiler.transforms.frontend.preprocess_mlir import PreprocessPass
from compiler.transforms.frontend.preprocess_mlperf_tiny import PreprocessMLPerfTiny
from compiler.transforms.fuse_streaming_regions import FuseStreamingRegions
from compiler.transforms.insert_accfg_op import InsertAccOp
from compiler.transforms.insert_sync_barrier import InsertSyncBarrier
from compiler.transforms.memref_to_snax import MemrefToSNAX
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
self.ctx.load_dialect(ACCFG)
self.ctx.load_dialect(SnaxStream)
self.ctx.load_dialect(Debug)
self.ctx.load_dialect(Stream)
self.ctx.load_dialect(Dart)
super().register_pass(DispatchKernels.name, lambda: DispatchKernels)
super().register_pass(SetMemorySpace.name, lambda: SetMemorySpace)
super().register_pass(SetMemoryLayout.name, lambda: SetMemoryLayout)
Expand All @@ -104,7 +104,7 @@ def __init__(
AccfgConfigOverlapPass.name, lambda: AccfgConfigOverlapPass
)
super().register_pass(
ConvertStreamToSnaxStream.name, lambda: ConvertStreamToSnaxStream
ConvertDartToSnaxStream.name, lambda: ConvertDartToSnaxStream
)
super().register_pass(ReuseMemrefAllocs.name, lambda: ReuseMemrefAllocs)
super().register_pass(RemoveMemrefCopyPass.name, lambda: RemoveMemrefCopyPass)
Expand All @@ -120,9 +120,11 @@ def __init__(
super().register_pass(DebugToFuncPass.name, lambda: DebugToFuncPass)
super().register_pass(PreprocessMLPerfTiny.name, lambda: PreprocessMLPerfTiny)
super().register_pass(AddMcycleAroundLaunch.name, lambda: AddMcycleAroundLaunch)
super().register_pass(ConvertLinalgToStream.name, lambda: ConvertLinalgToStream)
super().register_pass(ConvertLinalgToDart.name, lambda: ConvertLinalgToDart)
super().register_pass(SnaxBufferize.name, lambda: SnaxBufferize)
super().register_pass(FuseStreamingRegions.name, lambda: FuseStreamingRegions)
super().register_pass(
DartFuseOperationsPass.name, lambda: DartFuseOperationsPass
)
super().register_pass(AllocToGlobalPass.name, lambda: AllocToGlobalPass)
super().register_pass(PreprocessPass.name, lambda: PreprocessPass)
super().register_pass(PostprocessPass.name, lambda: PostprocessPass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from compiler.accelerators.registry import AcceleratorRegistry
from compiler.accelerators.snax import SNAXStreamer
from compiler.accelerators.util import find_accelerator_op
from compiler.dialects import snax_stream, stream
from compiler.ir.autoflow.affine_transform import AffineTransform
from compiler.ir.stream import Schedule, SchedulePattern, scheduler
from compiler.ir.stream.access_pattern import Template
from compiler.dialects import dart, snax_stream
from compiler.ir.dart.access_pattern import Schedule, SchedulePattern, Template
from compiler.ir.dart.affine_transform import AffineTransform
from compiler.ir.dart.scheduler import scheduler


def get_accelerator_info(op: stream.StreamingRegionOpBase) -> Template:
def get_accelerator_info(op: dart.StreamingRegionOpBase) -> Template:
assert op.accelerator is not None

# Go and fetch the accelerator op
Expand Down Expand Up @@ -51,9 +51,7 @@ class AutoflowScheduler(RewritePattern):
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
):
def match_and_rewrite(self, op: dart.OperationOp, rewriter: PatternRewriter):
template = get_accelerator_info(op)

# Make sure the operands are memrefs
Expand All @@ -69,7 +67,7 @@ def match_and_rewrite(
)
schedule = scheduler(template, schedule)

schedule_op = stream.ScheduleOp(
schedule_op = dart.ScheduleOp(
op.inputs,
op.outputs,
ArrayAttr([AffineMapAttr(s.pattern.to_affine_map()) for s in schedule]),
Expand All @@ -92,7 +90,7 @@ class LayoutResolution(RewritePattern):
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stream.ScheduleOp, rewriter: PatternRewriter):
def match_and_rewrite(self, op: dart.ScheduleOp, rewriter: PatternRewriter):
bounds = [x.value.data for x in op.bounds.data]
schedule = Schedule(
SchedulePattern(bounds, pattern.data) for pattern in op.patterns
Expand Down Expand Up @@ -147,7 +145,7 @@ def generate_one_list(n: int, i: int):

new_patterns = ArrayAttr([AffineMapAttr(map) for map in access_patterns])

access_pattern_op = stream.AccessPatternOp(
access_pattern_op = dart.AccessPatternOp(
new_inputs,
new_outputs,
new_patterns,
Expand All @@ -170,7 +168,7 @@ class ConvertStreamToSnaxStreamPattern(RewritePattern):
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stream.AccessPatternOp, rewriter: PatternRewriter):
def match_and_rewrite(self, op: dart.AccessPatternOp, rewriter: PatternRewriter):
template = get_accelerator_info(op)

snax_stride_patterns: list[snax_stream.StridePattern] = []
Expand Down Expand Up @@ -324,8 +322,8 @@ def match_and_rewrite(self, op: stream.AccessPatternOp, rewriter: PatternRewrite


@dataclass(frozen=True)
class ConvertStreamToSnaxStream(ModulePass):
name = "convert-stream-to-snax-stream"
class ConvertDartToSnaxStream(ModulePass):
name = "convert-dart-to-snax-stream"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(AutoflowScheduler()).rewrite_module(op)
Expand Down
Empty file.
Loading
Loading