From e65765e4114f2eccb101dbcb865fe740d5008186 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Wed, 16 Oct 2024 14:58:27 +0200 Subject: [PATCH] initialize classes for the scheduling mechanism (#288) * clean up scheduling mechanism * formatting * fix filechecks * rename to stream * more better cleaner * typo * only implement tile and rotate for schedules * fix tests --- compiler/accelerators/snax.py | 6 +- compiler/accelerators/snax_alu.py | 3 +- compiler/accelerators/snax_gemm.py | 7 +- compiler/accelerators/snax_gemmx.py | 7 +- compiler/ir/stream/__init__.py | 2 + compiler/ir/stream/access_pattern.py | 236 +++++++++++++++ compiler/ir/stream/scheduler.py | 43 +++ .../convert_stream_to_snax_stream.py | 17 +- compiler/transforms/schedule_memref_linalg.py | 156 ---------- tests/ir/stream/test_access_pattern.py | 276 ++++++++++++++++++ 10 files changed, 577 insertions(+), 176 deletions(-) create mode 100644 compiler/ir/stream/__init__.py create mode 100644 compiler/ir/stream/access_pattern.py create mode 100644 compiler/ir/stream/scheduler.py delete mode 100644 compiler/transforms/schedule_memref_linalg.py create mode 100644 tests/ir/stream/test_access_pattern.py diff --git a/compiler/accelerators/snax.py b/compiler/accelerators/snax.py index be8c418b..7f31ad20 100644 --- a/compiler/accelerators/snax.py +++ b/compiler/accelerators/snax.py @@ -6,13 +6,13 @@ from xdsl.dialects.builtin import IntAttr, i32 from xdsl.dialects.scf import Condition, While, Yield from xdsl.ir import Operation, OpResult, SSAValue -from xdsl.ir.affine import AffineMap from compiler.accelerators.accelerator import Accelerator from compiler.accelerators.streamers import StreamerConfiguration from compiler.accelerators.streamers.streamers import StreamerFlag, StreamerOpts from compiler.dialects import accfg, stream from compiler.dialects.snax_stream import StreamerConfigurationAttr, StreamingRegionOp +from compiler.ir.stream import Template c0_attr = builtin.IntegerAttr(0, builtin.IndexType()) @@ -261,9 +261,7 @@ def get_streamer_launch_dict(self, base_addr) -> tuple[int, dict[str, int]]: @staticmethod @abstractmethod - def get_template( - op: stream.StreamingRegionOp, - ) -> tuple[Sequence[AffineMap], Sequence[int | None]]: + def get_template(op: stream.StreamingRegionOp) -> Template: """ Get the template for this acelerator to schedule a given stream.streaming_region operation. diff --git a/compiler/accelerators/snax_alu.py b/compiler/accelerators/snax_alu.py index 1403e289..93fe5d64 100644 --- a/compiler/accelerators/snax_alu.py +++ b/compiler/accelerators/snax_alu.py @@ -18,6 +18,7 @@ StreamerType, ) from compiler.dialects import accfg, snax_stream, stream +from compiler.ir.stream import Template, TemplatePattern default_streamer = StreamerConfiguration( [ @@ -194,4 +195,4 @@ def generate_acc_op(self) -> accfg.AcceleratorOp: def get_template(op: stream.StreamingRegionOp): template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3 template_bounds = (None, 4) - return template, template_bounds + return Template(TemplatePattern(template_bounds, tp) for tp in template) diff --git a/compiler/accelerators/snax_gemm.py b/compiler/accelerators/snax_gemm.py index 53b107b7..5e6cb0af 100644 --- a/compiler/accelerators/snax_gemm.py +++ b/compiler/accelerators/snax_gemm.py @@ -14,6 +14,7 @@ StreamerType, ) from compiler.dialects import accfg, snax_stream, stream +from compiler.ir.stream import Template, TemplatePattern default_streamer = StreamerConfiguration( [ @@ -163,9 +164,7 @@ def lower_acc_await(acc_op: accfg.AcceleratorOp) -> Sequence[Operation]: ] @staticmethod - def get_template( - op: stream.StreamingRegionOp, - ) -> tuple[Sequence[AffineMap], Sequence[int | None]]: + def get_template(op: stream.StreamingRegionOp) -> Template: M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6)) template = [ AffineMap(6, 0, (M * 8 + m, K * 8 + k)), @@ -173,4 +172,4 @@ def get_template( AffineMap(6, 0, (M * 8 + m, N * 8 + n)), ] template_bounds = (None, None, None, 8, 8, 8) - return template, template_bounds + return Template(TemplatePattern(template_bounds, tp) for tp in template) diff --git a/compiler/accelerators/snax_gemmx.py b/compiler/accelerators/snax_gemmx.py index 0b4243d3..9a3d997e 100644 --- a/compiler/accelerators/snax_gemmx.py +++ b/compiler/accelerators/snax_gemmx.py @@ -19,6 +19,7 @@ ) from compiler.accelerators.streamers.streamers import StreamerOpts from compiler.dialects import accfg, kernel, snax_stream, stream +from compiler.ir.stream import Template, TemplatePattern from compiler.util.pack_bitlist import pack_bitlist default_streamer = StreamerConfiguration( @@ -262,9 +263,7 @@ def _generate_setup_vals( ] @staticmethod - def get_template( - op: stream.StreamingRegionOp, - ) -> tuple[Sequence[AffineMap], Sequence[int | None]]: + def get_template(op: stream.StreamingRegionOp) -> Template: assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp) if isinstance(generic_op.body.block.first_op, kernel.QMacOp): # matmul @@ -295,4 +294,4 @@ def get_template( if not isinstance(generic_op.next_op, stream.YieldOp): raise RuntimeError("unsupported kernel") - return template, template_bounds + return Template(TemplatePattern(template_bounds, tp) for tp in template) diff --git a/compiler/ir/stream/__init__.py b/compiler/ir/stream/__init__.py new file mode 100644 index 00000000..a7086edf --- /dev/null +++ b/compiler/ir/stream/__init__.py @@ -0,0 +1,2 @@ +from .access_pattern import * +from .scheduler import * diff --git a/compiler/ir/stream/access_pattern.py b/compiler/ir/stream/access_pattern.py new file mode 100644 index 00000000..3a2172d3 --- /dev/null +++ b/compiler/ir/stream/access_pattern.py @@ -0,0 +1,236 @@ +from abc import ABC +from collections.abc import Iterable, Iterator, Sequence +from dataclasses import dataclass +from typing import Generic + +from typing_extensions import Self, TypeVar, deprecated, overload +from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap + +from compiler.util.canonicalize_affine import canonicalize_map + + +@dataclass(frozen=True) +class AccessPattern(ABC): + """ + Class specifying an access pattern for a single operand. + This is specified by bounds and a pattern given as an AffineMap. + """ + + bounds: tuple[int | None, ...] + pattern: AffineMap + + def __init__(self, bounds: Sequence[int | None], pattern: AffineMap): + # Convert bounds to a tuple + bounds = tuple(bounds) + + # Perform validations + if len(bounds) != pattern.num_dims: + raise ValueError( + "The number of bounds should be equal to the dimension of the pattern" + ) + + if pattern.num_symbols > 0: + raise ValueError("Symbols in the pattern are not supported") + + # Canonicalize the pattern + new_pattern = canonicalize_map(pattern) + + # Assign attributes using object.__setattr__ due to frozen=True + object.__setattr__(self, "bounds", bounds) + object.__setattr__(self, "pattern", new_pattern) + + @property + def num_dims(self): + return len(self.bounds) + + def disable_dims(self, dim: int) -> Self: + """ + Returns an affine map with the leftmost `dim` dimensions set to 0 + + For example: + (d0, d1, d2) -> d0 + d1 + d2 + For `dim` = 1, will return: + (d1, d2) -> d1 + d2 + For `dim` = 2, will return: + (d2) -> d2 + """ + new_pattern = self.pattern.replace_dims_and_symbols( + tuple(AffineConstantExpr(0) for _ in range(dim)) + + tuple(AffineDimExpr(i) for i in range(self.num_dims - dim)), + [], + self.num_dims - dim, + 0, + ) + return type(self)(self.bounds[dim:], new_pattern) + + +@dataclass(frozen=True) +class SchedulePattern(AccessPattern): + """ + A schedule pattern is a pattern for a schedule of an operation. + + Schedule patterns are constrained to have static bounds. + """ + + # constrain bounds to only be int + bounds: tuple[int, ...] + + def __init__(self, bounds: Sequence[int], pattern: AffineMap): + if any(bound is None or bound <= 0 for bound in bounds): + raise ValueError( + "All bounds must be static, strictly positive integers for a schedule" + ) + + super().__init__(bounds, pattern) + + def rotate(self, dim: int) -> Self: + """ + Returns a new schedule with the leftmost `dim` dimensions rotated + + For example: + (d0, d1, d2) -> 1 * d0 + 2 * d1 + 3 * d2 + For `dim` = 3, will return: + (d0, d1, d2) -> 3 * d0 + 1 * d1 + 2 * d2 + For `dim` = 2, will return: + (d0, d1, d2) -> 2 * d0 + 1 * d1 + 3 * d2 + return AccessPattern() + """ + + # Rotate in the following manner: + # (0, 1, 2, 3, ..., dim-1, dim, dim+1, ..., num_dims - 1) + # --> (1, 2, 3, ..., dim-1, 0, dim, dim+1, ..., num_dims - 1) + + new_dims = tuple(AffineDimExpr(i) for i in range(self.num_dims)) + new_dims = new_dims[1:dim] + new_dims[:1] + new_dims[dim:] + new_bounds = self.bounds[1:dim] + self.bounds[:1] + self.bounds[dim:] + + new_pattern = self.pattern.replace_dims_and_symbols( + new_dims, [], self.num_dims, 0 + ) + return type(self)(new_bounds, new_pattern) + + def tile_dim(self, dim: int, template_bound: int) -> Self: + """ + Returns a new access pattern with the `dim` dimension split up into two + This translates to creating two for loops with adjusted bounds from one for loop + + + For example: + (d0, d1, d2) -> d0 + d1 + d2 + For `dim` = 1, `template_bound` = 2: + (d0, d1, d2, d3) -> d0 + 2 * d1 + d2 + d3 + + The bounds are split in similar fashion: + For example: + [2, 8, 2] + For `dim` = 1, `template_bound` = 2: + [2, 4, 2, 2] + + """ + transform_map = AffineMap( + num_dims=self.num_dims + 1, + num_symbols=0, + # (d0, d1, d2, ..., dim-1) -> (d0, d1, d2, ..., dim-1) + results=tuple(AffineDimExpr(i) for i in range(dim)) + # (dim) -> (template_bound * dim + dim + 1) + + (AffineDimExpr(dim) * template_bound + AffineDimExpr(dim + 1),) + # (dim + 1, dim + 2, ...) -> (dim + 2, dim + 3, dim + 3) + + tuple(AffineDimExpr(i + 1) for i in range(dim + 1, self.num_dims)), + ) + new_pattern = self.pattern.compose(transform_map) + bound_to_tile = self.bounds[dim] + tiled_bound = bound_to_tile // template_bound + new_bounds = ( + self.bounds[:dim] + (tiled_bound, template_bound) + self.bounds[dim + 1 :] + ) + + return type(self)(new_bounds, new_pattern) + + +@dataclass(frozen=True) +class TemplatePattern(AccessPattern): + """ + Template pattern is a pattern for an accelerator template. + + Templates should not be transformed through either tiling/rotating/others. + """ + + def __init__(self, bounds: Sequence[int], pattern: AffineMap): + super().__init__(bounds, pattern) + + def matches(self, sp: SchedulePattern): + """ + Check if a given schedule pattern matches this + template pattern. + """ + if sp.num_dims != self.num_dims: + return False + if sp.pattern != self.pattern: + return False + return True + + +P = TypeVar("P", bound=AccessPattern) + + +class PatternCollection(Sequence[P], Generic[P], ABC): + """ + Abstract base class for collections of AccessPatterns. + Provides common methods and properties for Schedule and Template classes. + """ + + def __init__(self, patterns: Iterable[P]): + self._patterns = tuple(patterns) + + @overload + def __getitem__(self, index: int) -> P: + ... + + @overload + def __getitem__(self, index: slice) -> tuple[P]: + ... + + def __getitem__(self, index: int | slice): + return self._patterns[index] + + def __len__(self) -> int: + return len(self._patterns) + + def __iter__(self) -> Iterator[P]: + return iter(self._patterns) + + @property + @deprecated("only valid in trivial cases") + def num_dims(self) -> int: + return self[0].num_dims + + def rotate(self, dim: int) -> Self: + return type(self)(sp.rotate(dim) for sp in self) + + def disable_dims(self, dim: int) -> Self: + return type(self)(sp.disable_dims(dim) for sp in self) + + def tile_dim(self, dim: int, template_bound: int) -> Self: + return type(self)(sp.tile_dim(dim, template_bound) for sp in self) + + +class Schedule(PatternCollection[SchedulePattern]): + """ + A schedule consisting of multiple SchedulePatterns for different operands. + """ + + ... + + +class Template(PatternCollection[TemplatePattern]): + """ + A template consisting of multiple TemplatePatterns for different operands. + """ + + def matches(self, schedule: Schedule): + if len(schedule) != len(self): + return False + for sp, tp in zip(schedule, self): + if not tp.matches(sp): + return False + return True diff --git a/compiler/ir/stream/scheduler.py b/compiler/ir/stream/scheduler.py new file mode 100644 index 00000000..60f4192f --- /dev/null +++ b/compiler/ir/stream/scheduler.py @@ -0,0 +1,43 @@ +from compiler.ir.stream.access_pattern import Schedule, Template + + +def scheduler(template: Template, schedule: Schedule) -> Schedule: + for i in range(template.num_dims): + # i = 0: look at the last dimension + # i = 1: look at the second to last dimension + template_dim = template.num_dims - i - 1 + schedule_dim = schedule.num_dims - i - 1 + match = False + + # maximum number of rotations + for _ in range(schedule_dim + 1): + # check if there is a match + template_check = template.disable_dims(template_dim) + schedule_check = schedule.disable_dims(schedule_dim) + + if template_check.matches(schedule_check): + match = True + break + + # else rotate the for loops + schedule = schedule.rotate(schedule_dim + 1) + + if not match: + raise RuntimeError("failed to match template and schedule") + + # now, check bounds and design potential transformation map + if not (template_bound := template[0].bounds[template_dim]): + # nothing to worry about, continue to next dim + continue + + schedule_bound = schedule[0].bounds[schedule_dim] + + if schedule_bound < template_bound: + # need to apply padding + raise NotImplementedError("padding not supported") + elif schedule_bound >= template_bound: + # need to split up the schedule + assert schedule_bound % template_bound == 0 + schedule = schedule.tile_dim(schedule_dim, template_bound) + + return schedule diff --git a/compiler/transforms/convert_stream_to_snax_stream.py b/compiler/transforms/convert_stream_to_snax_stream.py index a72f2d8d..e533f984 100644 --- a/compiler/transforms/convert_stream_to_snax_stream.py +++ b/compiler/transforms/convert_stream_to_snax_stream.py @@ -18,7 +18,7 @@ from compiler.accelerators.snax import SNAXStreamer from compiler.dialects import snax_stream, stream from compiler.dialects.snax import StreamerConfigurationAttr -from compiler.transforms.schedule_memref_linalg import schedule_memref_linalg +from compiler.ir.stream import Schedule, SchedulePattern, scheduler @dataclass @@ -62,7 +62,7 @@ def match_and_rewrite( accelerator_type = AcceleratorRegistry().get_acc_info(acc_op) assert issubclass(accelerator_type, SNAXStreamer) - template, template_bounds = accelerator_type.get_template(op) + template = accelerator_type.get_template(op) # Make sure the operands are memrefs for memref_operand in op.operands: @@ -70,9 +70,12 @@ def match_and_rewrite( return # First, run the stream scheduling algorithm - schedule, schedule_bounds = schedule_memref_linalg( - op, template, template_bounds + schedule_bounds = tuple(op.get_static_pattern_bounds()) + schedule = Schedule( + SchedulePattern(schedule_bounds, pattern.data) + for pattern in op.patterns.data ) + schedule = scheduler(template, schedule) # We are now ready to convert the stream access patterns into snax stride patterns # construct the strided patterns for SNAX Streamers @@ -93,7 +96,7 @@ def generate_one_list(n: int, i: int): data_mem_map: AffineMap = memref_type.get_affine_map_in_bytes() # Mapping from access to data: - access_data_map: AffineMap = schedule[operand] + access_data_map: AffineMap = schedule[operand].pattern # Mapping from access to memory: access_mem_map: AffineMap = data_mem_map.compose(access_data_map) @@ -110,7 +113,7 @@ def generate_one_list(n: int, i: int): access_mem_map.eval( generate_one_list(access_mem_map.num_dims, i), () )[0], - schedule_bounds[i], + schedule[operand].bounds[i], ) for i in reversed(range(access_mem_map.num_dims)) ) @@ -123,7 +126,7 @@ def generate_one_list(n: int, i: int): upper_bounds: list[int] = [] # fill up all spatial strides - for _ in [x for x in template_bounds if x is not None]: + for _ in [x for x in template[0].bounds if x is not None]: # FIXME: provide more general solution for new spatial streamer_config # configuration, this works in all current cases and layouts but is far from generally correct. spatial_strides = [8] diff --git a/compiler/transforms/schedule_memref_linalg.py b/compiler/transforms/schedule_memref_linalg.py deleted file mode 100644 index a3725337..00000000 --- a/compiler/transforms/schedule_memref_linalg.py +++ /dev/null @@ -1,156 +0,0 @@ -from collections.abc import Sequence - -from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap - -from compiler.dialects import stream -from compiler.util.canonicalize_affine import canonicalize_map - - -def disable_dims(map: AffineMap, dim: int) -> AffineMap: - """ - Returns an affine map with the leftmost `dim` dimensions set to 0 - - For example: - (d0, d1, d2) -> d0 + d1 + d2 - For `dim` = 1, will return: - (d1, d2) -> d1 + d2 - For `dim` = 2, will return: - (d2) -> d2 - """ - return canonicalize_map( - map.replace_dims_and_symbols( - [AffineConstantExpr(0) for _ in range(dim)] - + [AffineDimExpr(i) for i in range(map.num_dims - dim)], - [], - map.num_dims - dim, - 0, - ) - ) - - -def rotate_dims(map: AffineMap, dim: int) -> AffineMap: - """ - Returns an affine map with the leftmost `dim` dimensions rotated - - For example: - (d0, d1, d2) -> 1 * d0 + 2 * d1 + 3 * d2 - For `dim` = 3, will return: - (d0, d1, d2) -> 3 * d0 + 1 * d1 + 2 * d2 - For `dim` = 2, will return: - (d0, d1, d2) -> 2 * d0 + 1 * d1 + 3 * d2 - """ - new_dims = [AffineDimExpr(i) for i in range(dim)] - # rotate dims by popping first and appending - new_dims.append(new_dims.pop(0)) - # keep remaining dims - new_dims = new_dims + [AffineDimExpr(i) for i in range(dim, map.num_dims)] - return canonicalize_map( - map.replace_dims_and_symbols(new_dims, [], len(new_dims), 0) - ) - - -def rotate_bounds(bounds: list[int | None], dim: int) -> list[int | None]: - """ - Returns the bounds after rotating dims, as in rotate_dims - """ - bounds = bounds.copy() - bounds.insert(0, bounds.pop(dim - 1)) - return bounds - - -def tile_dim(map: AffineMap, dim: int, template_bound: int) -> AffineMap: - """ - Returns the bounds and a new affine map with the `dim` dimension split up into two - This translates to creating two for loops with adjusted bounds from one for loop - - - For example: - (d0, d1, d2) -> d0 + d1 + d2 - For `dim` = 1, `template_bound` = 2: - (d0, d1, d2, d3) -> d0 + 2 * d1 + d2 + d3 - """ - # 1 extra dimension - # create result map (d0, d1, ... dn) - new_results: list[AffineExpr] = [AffineDimExpr(i) for i in range(map.num_dims + 1)] - # pop the result at dim - dim_sum = new_results.pop(dim) - # add it to dim multiplied by original bound // max_bound - new_results[dim] = new_results[dim] + dim_sum * template_bound - transform_map = AffineMap(map.num_dims + 1, 0, tuple(new_results)) - - result = canonicalize_map(map.compose(transform_map)) - - return result - - -def tile_bounds( - bounds: list[int | None], dim: int, template_bound: int -) -> list[int | None]: - """ - Returns the bounds after applying `tile_dim` in similar fashion. - - For example: - [2, 8, 2] - For `dim` = 1, `template_bound` = 2: - [2, 4, 2, 2] - """ - bounds = bounds.copy() - bound = bounds[dim] - bounds[dim] = template_bound - bounds.insert(dim, bound // template_bound if bound else None) - return bounds - - -def schedule_memref_linalg( - op: stream.StreamingRegionOp, - template: Sequence[AffineMap], - template_bounds: Sequence[int | None], -) -> tuple[tuple[AffineMap, ...], tuple[int, ...]]: - schedule = list(pattern.data for pattern in op.patterns.data) - schedule_bounds: list[int] = list(op.get_static_pattern_bounds()) - - for i in range(template[0].num_dims): - # i = 0: look at the last dimension - # i = 1: look at the second to last dimension - template_dim = template[0].num_dims - i - 1 - schedule_dim = schedule[0].num_dims - i - 1 - match = False - - for it in range(schedule_dim + 1): - # keep rotating the remaining dimensions until we have a match - - template_check = tuple(disable_dims(map, template_dim) for map in template) - schedule_check = tuple(disable_dims(map, schedule_dim) for map in schedule) - - if template_check == schedule_check: - match = True - break - - # else rotate the for loops - schedule = tuple(rotate_dims(map, schedule_dim + 1) for map in schedule) - schedule_bounds = rotate_bounds(schedule_bounds, schedule_dim + 1) - - if not match: - raise RuntimeError("failed to match template and schedule") - - # now, check bounds and design potential transfomration map - if not (template_bound := template_bounds[template_dim]): - # nothing to worry about, continue to next dim - continue - - schedule_bound = schedule_bounds[schedule_dim] - - if schedule_bound < template_bound: - # need to apply padding - raise NotImplementedError("padding not supported") - elif schedule_bound >= template_bound: - # need to split up the schedule - assert schedule_bound % template_bound == 0 - schedule = [ - tile_dim(schedule_map, schedule_dim, template_bound) - for schedule_map in schedule - ] - schedule_bounds = tile_bounds(schedule_bounds, schedule_dim, template_bound) - pass - - return tuple(schedule), tuple(schedule_bounds) diff --git a/tests/ir/stream/test_access_pattern.py b/tests/ir/stream/test_access_pattern.py new file mode 100644 index 00000000..a596d421 --- /dev/null +++ b/tests/ir/stream/test_access_pattern.py @@ -0,0 +1,276 @@ +import pytest +from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap + +from compiler.ir.stream import ( + AccessPattern, + Schedule, + SchedulePattern, + Template, + TemplatePattern, +) + + +# Pytest tests +def test_access_pattern_creation(): + pattern = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + ) + bounds = (10, 20, 30) + access_pattern = AccessPattern(bounds, pattern) + assert access_pattern.bounds == bounds + assert access_pattern.pattern == pattern + assert access_pattern.num_dims == 3 + + +def test_schedule_pattern_rotate(): + pattern = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + ) + bounds = (10, 20, 30) + access_pattern = SchedulePattern(bounds, pattern) + + # test 1: 3 dims, rotate 2 + rotated_pattern = access_pattern.rotate(2) + expected_bounds = (20, 10, 30) + expected_results = (AffineDimExpr(1), AffineDimExpr(0), AffineDimExpr(2)) + assert rotated_pattern.bounds == expected_bounds + assert rotated_pattern.pattern.results == expected_results + assert isinstance(rotated_pattern, AccessPattern) + + # test 2: 3 dims, rotate 3 + rotated_pattern = access_pattern.rotate(3) + expected_bounds = (20, 30, 10) + expected_results = (AffineDimExpr(1), AffineDimExpr(2), AffineDimExpr(0)) + assert rotated_pattern.bounds == expected_bounds + assert rotated_pattern.pattern.results == expected_results + assert isinstance(rotated_pattern, AccessPattern) + + # test 3: 3 dims, rotate 1 + rotated_pattern = access_pattern.rotate(1) + expected_bounds = (10, 20, 30) + expected_results = (AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)) + assert rotated_pattern.bounds == expected_bounds + assert rotated_pattern.pattern.results == expected_results + assert isinstance(rotated_pattern, AccessPattern) + + +def test_access_pattern_disable_dims(): + pattern = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + ) + bounds = (10, 20, 30) + access_pattern = AccessPattern(bounds, pattern) + + # test 1: disable 0 dims (none) + disabled_pattern = access_pattern.disable_dims(0) + expected_bounds = (10, 20, 30) + expected_results = (AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)) + assert disabled_pattern.bounds == expected_bounds + assert disabled_pattern.pattern.results == expected_results + assert isinstance(disabled_pattern, AccessPattern) + + # test 2: disable 1 dims + disabled_pattern = access_pattern.disable_dims(1) + expected_bounds = (20, 30) + expected_results = (AffineConstantExpr(0), AffineDimExpr(0), AffineDimExpr(1)) + assert disabled_pattern.bounds == expected_bounds + assert disabled_pattern.pattern.results == expected_results + assert isinstance(disabled_pattern, AccessPattern) + + # test 3: disable 2 dims + disabled_pattern = access_pattern.disable_dims(2) + expected_bounds = (30,) + expected_results = (AffineConstantExpr(0), AffineConstantExpr(0), AffineDimExpr(0)) + assert disabled_pattern.bounds == expected_bounds + assert disabled_pattern.pattern.results == expected_results + assert isinstance(disabled_pattern, AccessPattern) + + # test 4: disable 3 dims (all) + disabled_pattern = access_pattern.disable_dims(3) + expected_bounds = tuple() + expected_results = ( + AffineConstantExpr(0), + AffineConstantExpr(0), + AffineConstantExpr(0), + ) + assert disabled_pattern.bounds == expected_bounds + assert disabled_pattern.pattern.results == expected_results + assert isinstance(disabled_pattern, AccessPattern) + + +def test_schedule_pattern_tile_dim(): + pattern = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + ) + bounds = (10, 20, 30) + access_pattern = SchedulePattern(bounds, pattern) + tiled_pattern = access_pattern.tile_dim(1, 5) + expected_bounds = (10, 4, 5, 30) + expected_results = ( + AffineDimExpr(0), + AffineDimExpr(1) * 5 + AffineDimExpr(2), + AffineDimExpr(3), + ) + assert tiled_pattern.bounds == expected_bounds + assert tiled_pattern.pattern.results == expected_results + assert isinstance(tiled_pattern, AccessPattern) + + +def test_template_pattern_creation(): + pattern = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + bounds = (5, 10) + template_pattern = TemplatePattern(bounds, pattern) + assert template_pattern.bounds == bounds + assert template_pattern.pattern == pattern + assert template_pattern.num_dims == 2 + + +def test_schedule_pattern_creation(): + pattern = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + bounds = (15, 25) + schedule_pattern = SchedulePattern(bounds, pattern) + assert schedule_pattern.bounds == bounds + assert schedule_pattern.pattern == pattern + assert schedule_pattern.num_dims == 2 + assert isinstance(schedule_pattern.bounds, tuple) + assert all(isinstance(b, int) for b in schedule_pattern.bounds) + + +def test_schedule_pattern_invalid_bounds(): + pattern = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + with pytest.raises( + ValueError, + match="All bounds must be static, strictly positive integers for a schedule", + ): + SchedulePattern((10, None), pattern) # pyright: ignore + with pytest.raises( + ValueError, + match="All bounds must be static, strictly positive integers for a schedule", + ): + SchedulePattern((10, -5), pattern) + + +def test_template_pattern_matches(): + pattern = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + bounds = (10, 20) + tp = TemplatePattern(bounds, pattern) + sp_matching = SchedulePattern(bounds, pattern) + sp_non_matching_pattern = SchedulePattern( + bounds, + AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0)) + ), + ) + sp_non_matching_bounds = SchedulePattern((5, 15), pattern) + + assert tp.matches(sp_matching) is True + assert tp.matches(sp_non_matching_pattern) is False + assert ( + tp.matches(sp_non_matching_bounds) is True + ) # Bounds are not checked in matches + + +def test_schedule_rotate(): + pattern1 = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + pattern2 = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0)) + ) + sp1 = SchedulePattern((10, 20), pattern1) + sp2 = SchedulePattern((30, 40), pattern2) + schedule = Schedule([sp1, sp2]) + rotated_schedule = schedule.rotate(1) + assert isinstance(rotated_schedule, Schedule) + assert rotated_schedule[0].bounds == sp1.rotate(1).bounds + assert rotated_schedule[1].bounds == sp2.rotate(1).bounds + + +def test_schedule_disable_dims(): + pattern1 = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + ) + sp1 = SchedulePattern((10, 20, 30), pattern1) + schedule = Schedule([sp1]) + disabled_schedule = schedule.disable_dims(2) + assert isinstance(disabled_schedule, Schedule) + assert disabled_schedule[0].bounds == sp1.disable_dims(2).bounds + + +def test_schedule_tile_dim(): + pattern1 = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + sp1 = SchedulePattern((100, 200), pattern1) + schedule = Schedule([sp1]) + tiled_schedule = schedule.tile_dim(0, 10) + assert isinstance(tiled_schedule, Schedule) + expected_bounds = sp1.tile_dim(0, 10).bounds + assert tiled_schedule[0].bounds == expected_bounds + + +def test_template_disable_dims(): + pattern1 = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + ) + tp1 = TemplatePattern((10, 20, 30), pattern1) + template = Template([tp1]) + disabled_template = template.disable_dims(1) + assert isinstance(disabled_template, Template) + assert disabled_template[0].bounds == tp1.disable_dims(1).bounds + + +def test_template_matches_schedule(): + pattern1 = AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) + ) + tp1 = TemplatePattern((10, 20), pattern1) + tp2 = TemplatePattern((30, 40), pattern1) + template = Template([tp1, tp2]) + + sp1 = SchedulePattern((10, 20), pattern1) + sp2 = SchedulePattern((30, 40), pattern1) + schedule_matching = Schedule([sp1, sp2]) + + sp3 = SchedulePattern( + (10, 20), + AffineMap( + num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0)) + ), + ) + schedule_non_matching = Schedule([sp1, sp3]) + + assert template.matches(schedule_matching) is True + assert template.matches(schedule_non_matching) is False + + +def test_template_matches_schedule_length_mismatch(): + pattern1 = AffineMap(num_dims=1, num_symbols=0, results=(AffineDimExpr(0),)) + tp1 = TemplatePattern((10,), pattern1) + template = Template([tp1]) + + sp1 = SchedulePattern((10,), pattern1) + sp2 = SchedulePattern((20,), pattern1) + schedule = Schedule([sp1, sp2]) + + assert template.matches(schedule) is False