Skip to content

Commit

Permalink
add optional output stationary check
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 28, 2025
1 parent 3603842 commit ebb1424
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 6 deletions.
53 changes: 48 additions & 5 deletions compiler/ir/dart/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections.abc import Generator
from collections.abc import Callable, Generator, Sequence

import numpy as np

from compiler.ir.dart.access_pattern import Schedule, Template

Expand All @@ -7,6 +9,7 @@ def scheduler_backtrack(
template: Template,
schedule: Schedule,
inner_dims: int = 1,
extra_checks: Sequence[Callable[[Template, Schedule], bool]] = [],
) -> Generator[Schedule]:
"""
Backtracking method to find all possible mappings of the schedule on the template
Expand Down Expand Up @@ -58,10 +61,15 @@ def scheduler_backtrack(
# not possible, consider next option
continue

# check 2: apply extra checks
if not all(check(template_check, schedule_check) for check in extra_checks):
# not a valid schedule, consider next option
continue

# checks passed, we have a candidate schedule now
candidate_schedule = schedule

# check 2: check for valid iteration bounds
# check 3: check for valid iteration bounds
template_bound = (
template[0].bounds[-inner_dims] if inner_dims <= template.num_dims else None
)
Expand All @@ -81,10 +89,45 @@ def scheduler_backtrack(
)

# continue with candidate schedule, with an extra inner dim:
yield from scheduler_backtrack(template, candidate_schedule, inner_dims + 1)
yield from scheduler_backtrack(
template, candidate_schedule, inner_dims + 1, extra_checks
)


def is_pure_output_stationary(template: Template, schedule: Schedule):
"""
Checks whether a schedule, outside of the template, is fully output
stationary. This is determined by making sure all parallel dimensions
preced the reduction dimensions in the output operand (last operand).
"""
# fetch the pattern of the last operand
output_schedule = schedule[-1].pattern.A
# do not consider template dims
output_schedule = output_schedule[:, : -template.num_dims]

# check whether there are any non-zero elements in every column
# create iteration_types list with False for reduction, True for parallel
iteration_types: list[bool] = list(
map(lambda x: bool(x), np.any(output_schedule != 0, axis=0).tolist())
)
# the first zero should come after the last 1 for output stationary

# if only reduction, or only parallel, pure otuput stationary is guaranteed
if not (True in iteration_types and False in iteration_types):
return True

first_reduction_idx = iteration_types.index(False)
last_parallel_idx = len(iteration_types) - 1 - iteration_types[::-1].index(True)

# last parallel index should come before first reduction idx for pure output stationarity
return first_reduction_idx > last_parallel_idx

def scheduler(template: Template, schedule: Schedule) -> Schedule:

def scheduler(
template: Template,
schedule: Schedule,
extra_checks: Sequence[Callable[[Template, Schedule], bool]] = [],
) -> Schedule:
# for now just return the first result of the backtracking
result = next(scheduler_backtrack(template, schedule))
result = next(scheduler_backtrack(template, schedule, extra_checks=extra_checks))
return result
70 changes: 69 additions & 1 deletion tests/ir/dart/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
Template,
TemplatePattern,
)
from compiler.ir.dart.scheduler import scheduler
from compiler.ir.dart.scheduler import (
is_pure_output_stationary,
scheduler,
scheduler_backtrack,
)


def test_matching_1o():
Expand Down Expand Up @@ -143,3 +147,67 @@ def test_tiling_1o_1d2():
result = scheduler(template, schedule)

assert result == expected


def test_pure_output_stationary_check():
template_pattern = AffineMap.from_callable(lambda e: (e, e, e))

schedule_checks: list[tuple[AffineMap, bool]] = [
# only template dim: is valid
(AffineMap.from_callable(lambda e: (e, e, e)), True),
# only parallel dims: is valid
(AffineMap.from_callable(lambda c, d, e: (c + e, d + e, e)), True),
# only reduction dims: is valid
(AffineMap.from_callable(lambda _c, _d, e: (e, e, e)), True),
# parallel dim before reduction dim: valid
(AffineMap.from_callable(lambda c, _, e: (e, e, 2 * c + e)), True),
# reduction dim before parallel dim: invalid
(AffineMap.from_callable(lambda _, d, e: (e, e, 2 * d + e)), False),
# some more complex mixtures of parallel dim / reduction dim
(AffineMap.from_callable(lambda a, b, _c, _d, e: (e, b + e, 2 * a + e)), True),
(AffineMap.from_callable(lambda _a, _b, c, d, e: (e, c + e, 2 * d + e)), False),
(AffineMap.from_callable(lambda a, _b, c, _d, e: (e, a + e, 2 * c + e)), False),
(AffineMap.from_callable(lambda _a, b, _c, d, e: (e, b + e, 2 * d + e)), False),
]

template = Template(
(TemplatePattern([1] * template_pattern.num_dims, template_pattern),)
)

for schedule_pattern, expected_result in schedule_checks:
schedule = Schedule(
(SchedulePattern([1] * schedule_pattern.num_dims, schedule_pattern),)
)
assert is_pure_output_stationary(template, schedule) is expected_result


def test_pure_output_stationary_scheduler():
template_pattern = AffineMap.from_callable(lambda y: (y,))
template = Template([TemplatePattern([4], template_pattern)])

schedule_pattern = AffineMap.from_callable(lambda x, y: (y,))
schedule = Schedule([SchedulePattern([8, 8], schedule_pattern)])

# the expected output stationary schedule
output_stationary_pattern = AffineMap.from_callable(
lambda y0, x, y1: (4 * y0 + y1,)
)
schedule_output_stationary = Schedule(
[SchedulePattern([2, 8, 4], output_stationary_pattern)]
)

# if we run the scheduler without constraints, there are 2 valid schedules:
result = list(scheduler_backtrack(template, schedule, extra_checks=[]))
assert len(result) == 2
# one of which is the output stationary one:
assert schedule_output_stationary in result

# if we run the scheduler with the pure output stationary constraint, there is 1:
result = list(
scheduler_backtrack(
template, schedule, extra_checks=[is_pure_output_stationary]
)
)
assert len(result) == 1
# that one result being the output stationary one
assert result[0] == schedule_output_stationary

0 comments on commit ebb1424

Please sign in to comment.