Skip to content

Commit

Permalink
initialize classes for the scheduling mechanism (#288)
Browse files Browse the repository at this point in the history
* clean up scheduling mechanism

* formatting

* fix filechecks

* rename to stream

* more better cleaner

* typo

* only implement tile and rotate for schedules

* fix tests
  • Loading branch information
jorendumoulin authored Oct 16, 2024
1 parent b9f3d65 commit e65765e
Show file tree
Hide file tree
Showing 10 changed files with 577 additions and 176 deletions.
6 changes: 2 additions & 4 deletions compiler/accelerators/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from xdsl.dialects.builtin import IntAttr, i32
from xdsl.dialects.scf import Condition, While, Yield
from xdsl.ir import Operation, OpResult, SSAValue
from xdsl.ir.affine import AffineMap

from compiler.accelerators.accelerator import Accelerator
from compiler.accelerators.streamers import StreamerConfiguration
from compiler.accelerators.streamers.streamers import StreamerFlag, StreamerOpts
from compiler.dialects import accfg, stream
from compiler.dialects.snax_stream import StreamerConfigurationAttr, StreamingRegionOp
from compiler.ir.stream import Template

c0_attr = builtin.IntegerAttr(0, builtin.IndexType())

Expand Down Expand Up @@ -261,9 +261,7 @@ def get_streamer_launch_dict(self, base_addr) -> tuple[int, dict[str, int]]:

@staticmethod
@abstractmethod
def get_template(
op: stream.StreamingRegionOp,
) -> tuple[Sequence[AffineMap], Sequence[int | None]]:
def get_template(op: stream.StreamingRegionOp) -> Template:
"""
Get the template for this acelerator to schedule a given
stream.streaming_region operation.
Expand Down
3 changes: 2 additions & 1 deletion compiler/accelerators/snax_alu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
StreamerType,
)
from compiler.dialects import accfg, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern

default_streamer = StreamerConfiguration(
[
Expand Down Expand Up @@ -194,4 +195,4 @@ def generate_acc_op(self) -> accfg.AcceleratorOp:
def get_template(op: stream.StreamingRegionOp):
template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3
template_bounds = (None, 4)
return template, template_bounds
return Template(TemplatePattern(template_bounds, tp) for tp in template)
7 changes: 3 additions & 4 deletions compiler/accelerators/snax_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
StreamerType,
)
from compiler.dialects import accfg, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern

default_streamer = StreamerConfiguration(
[
Expand Down Expand Up @@ -163,14 +164,12 @@ def lower_acc_await(acc_op: accfg.AcceleratorOp) -> Sequence[Operation]:
]

@staticmethod
def get_template(
op: stream.StreamingRegionOp,
) -> tuple[Sequence[AffineMap], Sequence[int | None]]:
def get_template(op: stream.StreamingRegionOp) -> Template:
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
template = [
AffineMap(6, 0, (M * 8 + m, K * 8 + k)),
AffineMap(6, 0, (K * 8 + k, N * 8 + n)),
AffineMap(6, 0, (M * 8 + m, N * 8 + n)),
]
template_bounds = (None, None, None, 8, 8, 8)
return template, template_bounds
return Template(TemplatePattern(template_bounds, tp) for tp in template)
7 changes: 3 additions & 4 deletions compiler/accelerators/snax_gemmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from compiler.accelerators.streamers.streamers import StreamerOpts
from compiler.dialects import accfg, kernel, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.util.pack_bitlist import pack_bitlist

default_streamer = StreamerConfiguration(
Expand Down Expand Up @@ -262,9 +263,7 @@ def _generate_setup_vals(
]

@staticmethod
def get_template(
op: stream.StreamingRegionOp,
) -> tuple[Sequence[AffineMap], Sequence[int | None]]:
def get_template(op: stream.StreamingRegionOp) -> Template:
assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
if isinstance(generic_op.body.block.first_op, kernel.QMacOp):
# matmul
Expand Down Expand Up @@ -295,4 +294,4 @@ def get_template(
if not isinstance(generic_op.next_op, stream.YieldOp):
raise RuntimeError("unsupported kernel")

return template, template_bounds
return Template(TemplatePattern(template_bounds, tp) for tp in template)
2 changes: 2 additions & 0 deletions compiler/ir/stream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .access_pattern import *
from .scheduler import *
236 changes: 236 additions & 0 deletions compiler/ir/stream/access_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
from abc import ABC
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Generic

from typing_extensions import Self, TypeVar, deprecated, overload
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap

from compiler.util.canonicalize_affine import canonicalize_map


@dataclass(frozen=True)
class AccessPattern(ABC):
"""
Class specifying an access pattern for a single operand.
This is specified by bounds and a pattern given as an AffineMap.
"""

bounds: tuple[int | None, ...]
pattern: AffineMap

def __init__(self, bounds: Sequence[int | None], pattern: AffineMap):
# Convert bounds to a tuple
bounds = tuple(bounds)

# Perform validations
if len(bounds) != pattern.num_dims:
raise ValueError(
"The number of bounds should be equal to the dimension of the pattern"
)

if pattern.num_symbols > 0:
raise ValueError("Symbols in the pattern are not supported")

# Canonicalize the pattern
new_pattern = canonicalize_map(pattern)

# Assign attributes using object.__setattr__ due to frozen=True
object.__setattr__(self, "bounds", bounds)
object.__setattr__(self, "pattern", new_pattern)

@property
def num_dims(self):
return len(self.bounds)

def disable_dims(self, dim: int) -> Self:
"""
Returns an affine map with the leftmost `dim` dimensions set to 0
For example:
(d0, d1, d2) -> d0 + d1 + d2
For `dim` = 1, will return:
(d1, d2) -> d1 + d2
For `dim` = 2, will return:
(d2) -> d2
"""
new_pattern = self.pattern.replace_dims_and_symbols(
tuple(AffineConstantExpr(0) for _ in range(dim))
+ tuple(AffineDimExpr(i) for i in range(self.num_dims - dim)),
[],
self.num_dims - dim,
0,
)
return type(self)(self.bounds[dim:], new_pattern)


@dataclass(frozen=True)
class SchedulePattern(AccessPattern):
"""
A schedule pattern is a pattern for a schedule of an operation.
Schedule patterns are constrained to have static bounds.
"""

# constrain bounds to only be int
bounds: tuple[int, ...]

def __init__(self, bounds: Sequence[int], pattern: AffineMap):
if any(bound is None or bound <= 0 for bound in bounds):
raise ValueError(
"All bounds must be static, strictly positive integers for a schedule"
)

super().__init__(bounds, pattern)

def rotate(self, dim: int) -> Self:
"""
Returns a new schedule with the leftmost `dim` dimensions rotated
For example:
(d0, d1, d2) -> 1 * d0 + 2 * d1 + 3 * d2
For `dim` = 3, will return:
(d0, d1, d2) -> 3 * d0 + 1 * d1 + 2 * d2
For `dim` = 2, will return:
(d0, d1, d2) -> 2 * d0 + 1 * d1 + 3 * d2
return AccessPattern()
"""

# Rotate in the following manner:
# (0, 1, 2, 3, ..., dim-1, dim, dim+1, ..., num_dims - 1)
# --> (1, 2, 3, ..., dim-1, 0, dim, dim+1, ..., num_dims - 1)

new_dims = tuple(AffineDimExpr(i) for i in range(self.num_dims))
new_dims = new_dims[1:dim] + new_dims[:1] + new_dims[dim:]
new_bounds = self.bounds[1:dim] + self.bounds[:1] + self.bounds[dim:]

new_pattern = self.pattern.replace_dims_and_symbols(
new_dims, [], self.num_dims, 0
)
return type(self)(new_bounds, new_pattern)

def tile_dim(self, dim: int, template_bound: int) -> Self:
"""
Returns a new access pattern with the `dim` dimension split up into two
This translates to creating two for loops with adjusted bounds from one for loop
For example:
(d0, d1, d2) -> d0 + d1 + d2
For `dim` = 1, `template_bound` = 2:
(d0, d1, d2, d3) -> d0 + 2 * d1 + d2 + d3
The bounds are split in similar fashion:
For example:
[2, 8, 2]
For `dim` = 1, `template_bound` = 2:
[2, 4, 2, 2]
"""
transform_map = AffineMap(
num_dims=self.num_dims + 1,
num_symbols=0,
# (d0, d1, d2, ..., dim-1) -> (d0, d1, d2, ..., dim-1)
results=tuple(AffineDimExpr(i) for i in range(dim))
# (dim) -> (template_bound * dim + dim + 1)
+ (AffineDimExpr(dim) * template_bound + AffineDimExpr(dim + 1),)
# (dim + 1, dim + 2, ...) -> (dim + 2, dim + 3, dim + 3)
+ tuple(AffineDimExpr(i + 1) for i in range(dim + 1, self.num_dims)),
)
new_pattern = self.pattern.compose(transform_map)
bound_to_tile = self.bounds[dim]
tiled_bound = bound_to_tile // template_bound
new_bounds = (
self.bounds[:dim] + (tiled_bound, template_bound) + self.bounds[dim + 1 :]
)

return type(self)(new_bounds, new_pattern)


@dataclass(frozen=True)
class TemplatePattern(AccessPattern):
"""
Template pattern is a pattern for an accelerator template.
Templates should not be transformed through either tiling/rotating/others.
"""

def __init__(self, bounds: Sequence[int], pattern: AffineMap):
super().__init__(bounds, pattern)

def matches(self, sp: SchedulePattern):
"""
Check if a given schedule pattern matches this
template pattern.
"""
if sp.num_dims != self.num_dims:
return False
if sp.pattern != self.pattern:
return False
return True


P = TypeVar("P", bound=AccessPattern)


class PatternCollection(Sequence[P], Generic[P], ABC):
"""
Abstract base class for collections of AccessPatterns.
Provides common methods and properties for Schedule and Template classes.
"""

def __init__(self, patterns: Iterable[P]):
self._patterns = tuple(patterns)

@overload
def __getitem__(self, index: int) -> P:
...

@overload
def __getitem__(self, index: slice) -> tuple[P]:
...

def __getitem__(self, index: int | slice):
return self._patterns[index]

def __len__(self) -> int:
return len(self._patterns)

def __iter__(self) -> Iterator[P]:
return iter(self._patterns)

@property
@deprecated("only valid in trivial cases")
def num_dims(self) -> int:
return self[0].num_dims

def rotate(self, dim: int) -> Self:
return type(self)(sp.rotate(dim) for sp in self)

def disable_dims(self, dim: int) -> Self:
return type(self)(sp.disable_dims(dim) for sp in self)

def tile_dim(self, dim: int, template_bound: int) -> Self:
return type(self)(sp.tile_dim(dim, template_bound) for sp in self)


class Schedule(PatternCollection[SchedulePattern]):
"""
A schedule consisting of multiple SchedulePatterns for different operands.
"""

...


class Template(PatternCollection[TemplatePattern]):
"""
A template consisting of multiple TemplatePatterns for different operands.
"""

def matches(self, schedule: Schedule):
if len(schedule) != len(self):
return False
for sp, tp in zip(schedule, self):
if not tp.matches(sp):
return False
return True
43 changes: 43 additions & 0 deletions compiler/ir/stream/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from compiler.ir.stream.access_pattern import Schedule, Template


def scheduler(template: Template, schedule: Schedule) -> Schedule:
for i in range(template.num_dims):
# i = 0: look at the last dimension
# i = 1: look at the second to last dimension
template_dim = template.num_dims - i - 1
schedule_dim = schedule.num_dims - i - 1
match = False

# maximum number of rotations
for _ in range(schedule_dim + 1):
# check if there is a match
template_check = template.disable_dims(template_dim)
schedule_check = schedule.disable_dims(schedule_dim)

if template_check.matches(schedule_check):
match = True
break

# else rotate the for loops
schedule = schedule.rotate(schedule_dim + 1)

if not match:
raise RuntimeError("failed to match template and schedule")

# now, check bounds and design potential transformation map
if not (template_bound := template[0].bounds[template_dim]):
# nothing to worry about, continue to next dim
continue

schedule_bound = schedule[0].bounds[schedule_dim]

if schedule_bound < template_bound:
# need to apply padding
raise NotImplementedError("padding not supported")
elif schedule_bound >= template_bound:
# need to split up the schedule
assert schedule_bound % template_bound == 0
schedule = schedule.tile_dim(schedule_dim, template_bound)

return schedule
Loading

0 comments on commit e65765e

Please sign in to comment.