Skip to content

Commit

Permalink
add scheduler tests (#289)
Browse files Browse the repository at this point in the history
* add scheduler tests

sq

s

s

* fix tests

* remove prints

* capitalize
  • Loading branch information
jorendumoulin authored Oct 17, 2024
1 parent e65765e commit 310b2cc
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 80 deletions.
70 changes: 63 additions & 7 deletions compiler/ir/stream/access_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def rotate(self, dim: int) -> Self:
# --> (1, 2, 3, ..., dim-1, 0, dim, dim+1, ..., num_dims - 1)

new_dims = tuple(AffineDimExpr(i) for i in range(self.num_dims))
new_dims = new_dims[1:dim] + new_dims[:1] + new_dims[dim:]
new_dims = new_dims[dim - 1 : dim] + new_dims[: dim - 1] + new_dims[dim:]
new_bounds = self.bounds[1:dim] + self.bounds[:1] + self.bounds[dim:]

new_pattern = self.pattern.replace_dims_and_symbols(
Expand Down Expand Up @@ -146,6 +146,24 @@ def tile_dim(self, dim: int, template_bound: int) -> Self:

return type(self)(new_bounds, new_pattern)

def add_dim(self) -> Self:
"""
Returns a new schedule pattern with an extra empty dimension inserted.
For example:
(d0, d1) -> d0 + d1
Will result in:
(d0, d1, d2) -> d1 + d2
"""
new_pattern = self.pattern
transform_map = AffineMap(
num_dims=self.num_dims + 1,
num_symbols=0,
results=tuple(AffineDimExpr(i + 1) for i in range(self.num_dims)),
)
new_pattern = self.pattern.compose(transform_map)
new_bounds = (1,) + self.bounds
return type(self)(new_bounds, new_pattern)


@dataclass(frozen=True)
class TemplatePattern(AccessPattern):
Expand All @@ -155,7 +173,7 @@ class TemplatePattern(AccessPattern):
Templates should not be transformed through either tiling/rotating/others.
"""

def __init__(self, bounds: Sequence[int], pattern: AffineMap):
def __init__(self, bounds: Sequence[int | None], pattern: AffineMap):
super().__init__(bounds, pattern)

def matches(self, sp: SchedulePattern):
Expand Down Expand Up @@ -199,27 +217,65 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[P]:
return iter(self._patterns)

def __eq__(self, other: object) -> bool:
if not isinstance(other, PatternCollection):
return False
return self._patterns == other._patterns

@property
@deprecated("only valid in trivial cases")
def num_dims(self) -> int:
return self[0].num_dims

def rotate(self, dim: int) -> Self:
return type(self)(sp.rotate(dim) for sp in self)
@property
def max_dim(self) -> int:
return max(pattern.num_dims for pattern in self._patterns)

def disable_dims(self, dim: int) -> Self:
return type(self)(sp.disable_dims(dim) for sp in self)

def tile_dim(self, dim: int, template_bound: int) -> Self:
return type(self)(sp.tile_dim(dim, template_bound) for sp in self)
def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self:
"""
Returns a PatternCollection of which all dimensions that have bound 1 are cleared.
Optionally, specify custom bounds.
"""
if bounds is None:
pattern_bounds = self._patterns[0].bounds
else:
pattern_bounds = bounds
unused_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound == 1)
dim_substitutions = []
unused_counter = 0
for dim in range(self.num_dims):
if dim not in unused_dims:
dim_substitutions.append(AffineDimExpr(dim - unused_counter))
else:
dim_substitutions.append(AffineConstantExpr(0))
unused_counter += 1
return type(self)(
type(self._patterns[0])(
tuple(bound for bound in pattern_bounds if bound != 1),
sp.pattern.replace_dims_and_symbols(
dim_substitutions, [], self.num_dims - unused_counter, 0
),
)
for sp in self
)


class Schedule(PatternCollection[SchedulePattern]):
"""
A schedule consisting of multiple SchedulePatterns for different operands.
"""

...
def rotate(self, dim: int) -> Self:
return type(self)(sp.rotate(dim) for sp in self)

def tile_dim(self, dim: int, template_bound: int) -> Self:
return type(self)(sp.tile_dim(dim, template_bound) for sp in self)

def add_dim(self) -> Self:
return type(self)(sp.add_dim() for sp in self)


class Template(PatternCollection[TemplatePattern]):
Expand Down
2 changes: 1 addition & 1 deletion compiler/ir/stream/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from compiler.ir.stream.access_pattern import Schedule, Template
from compiler.ir.stream import Schedule, Template


def scheduler(template: Template, schedule: Schedule) -> Schedule:
Expand Down
3 changes: 2 additions & 1 deletion compiler/transforms/convert_stream_to_snax_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def generate_one_list(n: int, i: int):
3,
snax_stream.StridePattern(
upper_bounds=snax_stride_patterns[3].upper_bounds,
temporal_strides=[0] * 3,
temporal_strides=[0]
* len(snax_stride_patterns[3].upper_bounds),
spatial_strides=[8],
),
)
Expand Down
34 changes: 30 additions & 4 deletions compiler/util/canonicalize_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def canonicalize_addition(expr: AffineBinaryOpExpr) -> AffineExpr:
folding the op if both operands are constant
omitting a + 0
ordering the operands by their dimension
changing (a + b) + c into a + (b + c)
"""
# always put the constant on rhs
assert expr.kind is AffineBinaryOpKind.Add
Expand All @@ -56,7 +57,18 @@ def canonicalize_addition(expr: AffineBinaryOpExpr) -> AffineExpr:
dim_rhs = get_dim(expr.rhs)
if dim_rhs is not None:
if dim_lhs is None or dim_lhs > dim_rhs:
return expr.rhs + expr.lhs
new_expr = expr.rhs + expr.lhs
# TODO: make __add__ typing more specific in xdsl to avoid this
assert isinstance(new_expr, AffineBinaryOpExpr)
expr = new_expr
# turn (a + b) + c into a + (b + c)
if (
isinstance(expr.lhs, AffineBinaryOpExpr)
and expr.lhs.kind is AffineBinaryOpKind.Add
):
new_expr = expr.lhs.lhs + (expr.lhs.rhs + expr.rhs)
assert isinstance(new_expr, AffineBinaryOpExpr)
expr = new_expr
return expr


Expand All @@ -67,6 +79,7 @@ def canonicalize_multiplication(expr: AffineBinaryOpExpr) -> AffineExpr:
folding the op if both operands are constant
omitting a * 1
ordering the operands by their dimension
(a + b) * cst = (a * cst) + (b * cst)
"""
# always put the constant on rhs
assert expr.kind is AffineBinaryOpKind.Mul
Expand All @@ -77,12 +90,20 @@ def canonicalize_multiplication(expr: AffineBinaryOpExpr) -> AffineExpr:
return AffineConstantExpr(expr.lhs.value * expr.rhs.value)
else:
# move constant to rhs
return AffineBinaryOpExpr(expr.kind, expr.rhs, expr.lhs)
expr = AffineBinaryOpExpr(expr.kind, expr.rhs, expr.lhs)
if isinstance(expr.rhs, AffineConstantExpr):
# rhs is constant, lhs is not
# multiplication by 1 can be omitted
if expr.rhs.value == 1:
return expr.lhs
# turn (a + b) * cst into (a * cst) + (b * cst)
if (
isinstance(expr.lhs, AffineBinaryOpExpr)
and expr.lhs.kind is AffineBinaryOpKind.Add
):
new_expr = (expr.lhs.lhs * expr.rhs) + (expr.lhs.rhs * expr.rhs)
assert isinstance(new_expr, AffineBinaryOpExpr)
expr = new_expr
return expr


Expand Down Expand Up @@ -128,10 +149,15 @@ def canonicalize_binary_op(expr: AffineBinaryOpExpr) -> AffineExpr:


def canonicalize_expr(expr: AffineExpr) -> AffineExpr:
new_expr = expr

if isinstance(expr, AffineBinaryOpExpr):
return canonicalize_binary_op(expr)
new_expr = canonicalize_binary_op(expr)

return expr
if new_expr == expr:
return new_expr

return canonicalize_expr(new_expr)


# helper function to canonicalize affine maps
Expand Down
8 changes: 5 additions & 3 deletions tests/dialects/test_tsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def test_tsl_attr_constructor(example_tsl_attr):
def test_tsl_attr_get_affine(example_tsl_attr):
tsl = example_tsl_attr
map = canonicalize_map(tsl.get_affine_map())
assert map == AffineMap.from_callable(
lambda d0, d1: (
(((((d0 // 4) * 32) + ((d0 % 4) * 4)) + ((d1 // 4) * 16)) + (d1 % 4)),
assert map == canonicalize_map(
AffineMap.from_callable(
lambda d0, d1: (
(((((d0 // 4) * 32) + ((d0 % 4) * 4)) + ((d1 // 4) * 16)) + (d1 % 4)),
)
)
)
Loading

0 comments on commit 310b2cc

Please sign in to comment.