-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c5029aa
commit f7dd619
Showing
1 changed file
with
83 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,90 @@ | ||
from collections.abc import Generator | ||
|
||
from compiler.ir.dart.access_pattern import Schedule, Template | ||
|
||
|
||
def scheduler(template: Template, schedule: Schedule) -> Schedule: | ||
for i in range(template.num_dims): | ||
# i = 0: look at the last dimension | ||
# i = 1: look at the second to last dimension | ||
template_dim = template.num_dims - i - 1 | ||
schedule_dim = schedule.num_dims - i - 1 | ||
match = False | ||
|
||
# maximum number of rotations | ||
for _ in range(schedule_dim + 1): | ||
# check if there is a match | ||
template_check = template.disable_dims(template_dim) | ||
schedule_check = schedule.disable_dims(schedule_dim) | ||
|
||
if template_check.matches(schedule_check): | ||
match = True | ||
break | ||
|
||
# else rotate the for loops | ||
schedule = schedule.rotate(schedule_dim + 1) | ||
|
||
if not match: | ||
raise RuntimeError("failed to match template and schedule") | ||
|
||
# now, check bounds and design potential transformation map | ||
if not (template_bound := template[0].bounds[template_dim]): | ||
# nothing to worry about, continue to next dim | ||
def scheduler_backtrack( | ||
template: Template, | ||
schedule: Schedule, | ||
inner_dims: int = 1, | ||
) -> Generator[Schedule]: | ||
""" | ||
Backtracking method to find all possible mappings of the schedule on the template | ||
`template` (Template): the accelerator template | ||
`schedule` (Schedule): the partially scheduled operation | ||
`inner_dims` (int): current number of innermost dimensions being handled | ||
`pure_output_stationary` (bool): | ||
""" | ||
|
||
""" | ||
Explanation of dimensions: | ||
In case we are handling 6 dimensions and `dim` = 3: | ||
There are 3 innermost dimensions that are being checked with the template. | ||
The other outermost dimensions are not considered. | ||
The current dimension under foces (d3) is the most outermost dim of the innermost dims. | ||
When we apply a tiling, this will happen to this dimension d3. | ||
When we apply a rotation, we will rotate outermost dims + focused dim (d0 - d4) | ||
+---- `dim` innermost dims | ||
| | ||
-----------+ | ||
d0, d1, d2, d3, d4, d5 | ||
--+ | ||
-----------+ +-------------- focused dim | ||
| | ||
+----------------- `outermost` dims | ||
""" | ||
|
||
# exit condition for the algorithm: if all dimensions are considered | ||
if inner_dims == schedule.num_dims: | ||
yield schedule | ||
|
||
# This for loop rotates the outermost + focused dims loops in all ways. | ||
# There are thus `schedule.num_dims - inner_dims + 1` different rotations possible. | ||
for _ in range(schedule.num_dims - inner_dims + 1): | ||
# apply rotation: | ||
schedule = schedule.rotate(schedule.num_dims - inner_dims + 1) | ||
|
||
# use innermost dimensions for template check | ||
schedule_check = schedule.inner_dims(inner_dims) | ||
template_check = template.inner_dims(inner_dims) | ||
|
||
# check 1: check for valid transformation | ||
if not template_check.matches(schedule_check): | ||
# not possible, consider next option | ||
continue | ||
|
||
schedule_bound = schedule[0].bounds[schedule_dim] | ||
# checks passed, we have a candidate schedule now | ||
candidate_schedule = schedule | ||
|
||
if schedule_bound < template_bound: | ||
# need to apply padding | ||
raise NotImplementedError("padding not supported") | ||
elif schedule_bound >= template_bound: | ||
# need to split up the schedule | ||
assert schedule_bound % template_bound == 0 | ||
schedule = schedule.tile_dim(schedule_dim, template_bound) | ||
# check 2: check for valid iteration bounds | ||
template_bound = ( | ||
template[0].bounds[-inner_dims] if inner_dims <= template.num_dims else None | ||
) | ||
schedule_bound = candidate_schedule[0].bounds[-inner_dims] | ||
|
||
return schedule | ||
if template_bound: | ||
if schedule_bound < template_bound: | ||
# TODO: underutilized array, apply padding | ||
continue | ||
elif schedule_bound % template_bound != 0: | ||
# TODO: imperfect factorization | ||
continue | ||
else: # >= | ||
# tile schedule | ||
candidate_schedule = candidate_schedule.tile_dim( | ||
schedule.num_dims - inner_dims, template_bound | ||
) | ||
|
||
# continue with candidate schedule, with an extra inner dim: | ||
yield from scheduler_backtrack(template, candidate_schedule, inner_dims + 1) | ||
|
||
|
||
def scheduler(template: Template, schedule: Schedule) -> Schedule: | ||
# for now just return the first result of the backtracking | ||
result = next(scheduler_backtrack(template, schedule)) | ||
return result |