Skip to content

Commit

Permalink
use backtracking scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 27, 2025
1 parent c5029aa commit f7dd619
Showing 1 changed file with 83 additions and 36 deletions.
119 changes: 83 additions & 36 deletions compiler/ir/dart/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,90 @@
from collections.abc import Generator

from compiler.ir.dart.access_pattern import Schedule, Template


def scheduler(template: Template, schedule: Schedule) -> Schedule:
for i in range(template.num_dims):
# i = 0: look at the last dimension
# i = 1: look at the second to last dimension
template_dim = template.num_dims - i - 1
schedule_dim = schedule.num_dims - i - 1
match = False

# maximum number of rotations
for _ in range(schedule_dim + 1):
# check if there is a match
template_check = template.disable_dims(template_dim)
schedule_check = schedule.disable_dims(schedule_dim)

if template_check.matches(schedule_check):
match = True
break

# else rotate the for loops
schedule = schedule.rotate(schedule_dim + 1)

if not match:
raise RuntimeError("failed to match template and schedule")

# now, check bounds and design potential transformation map
if not (template_bound := template[0].bounds[template_dim]):
# nothing to worry about, continue to next dim
def scheduler_backtrack(
template: Template,
schedule: Schedule,
inner_dims: int = 1,
) -> Generator[Schedule]:
"""
Backtracking method to find all possible mappings of the schedule on the template
`template` (Template): the accelerator template
`schedule` (Schedule): the partially scheduled operation
`inner_dims` (int): current number of innermost dimensions being handled
`pure_output_stationary` (bool):
"""

"""
Explanation of dimensions:
In case we are handling 6 dimensions and `dim` = 3:
There are 3 innermost dimensions that are being checked with the template.
The other outermost dimensions are not considered.
The current dimension under foces (d3) is the most outermost dim of the innermost dims.
When we apply a tiling, this will happen to this dimension d3.
When we apply a rotation, we will rotate outermost dims + focused dim (d0 - d4)
+---- `dim` innermost dims
|
-----------+
d0, d1, d2, d3, d4, d5
--+
-----------+ +-------------- focused dim
|
+----------------- `outermost` dims
"""

# exit condition for the algorithm: if all dimensions are considered
if inner_dims == schedule.num_dims:
yield schedule

# This for loop rotates the outermost + focused dims loops in all ways.
# There are thus `schedule.num_dims - inner_dims + 1` different rotations possible.
for _ in range(schedule.num_dims - inner_dims + 1):
# apply rotation:
schedule = schedule.rotate(schedule.num_dims - inner_dims + 1)

# use innermost dimensions for template check
schedule_check = schedule.inner_dims(inner_dims)
template_check = template.inner_dims(inner_dims)

# check 1: check for valid transformation
if not template_check.matches(schedule_check):
# not possible, consider next option
continue

schedule_bound = schedule[0].bounds[schedule_dim]
# checks passed, we have a candidate schedule now
candidate_schedule = schedule

if schedule_bound < template_bound:
# need to apply padding
raise NotImplementedError("padding not supported")
elif schedule_bound >= template_bound:
# need to split up the schedule
assert schedule_bound % template_bound == 0
schedule = schedule.tile_dim(schedule_dim, template_bound)
# check 2: check for valid iteration bounds
template_bound = (
template[0].bounds[-inner_dims] if inner_dims <= template.num_dims else None
)
schedule_bound = candidate_schedule[0].bounds[-inner_dims]

return schedule
if template_bound:
if schedule_bound < template_bound:
# TODO: underutilized array, apply padding
continue
elif schedule_bound % template_bound != 0:
# TODO: imperfect factorization
continue
else: # >=
# tile schedule
candidate_schedule = candidate_schedule.tile_dim(
schedule.num_dims - inner_dims, template_bound
)

# continue with candidate schedule, with an extra inner dim:
yield from scheduler_backtrack(template, candidate_schedule, inner_dims + 1)


def scheduler(template: Template, schedule: Schedule) -> Schedule:
# for now just return the first result of the backtracking
result = next(scheduler_backtrack(template, schedule))
return result

0 comments on commit f7dd619

Please sign in to comment.