Skip to content

Commit

Permalink
template matching: only look at the inner dims of the schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 29, 2025
1 parent 6c31364 commit b26d720
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
2 changes: 2 additions & 0 deletions compiler/ir/dart/access_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def matches(self, sp: SchedulePattern):
Check if a given schedule pattern matches this
template pattern.
"""
if sp.num_dims > self.num_dims:
sp = sp.inner_dims(self.num_dims)
if sp.num_dims != self.num_dims:
return False
if sp.pattern != self.pattern:
Expand Down
24 changes: 19 additions & 5 deletions tests/ir/dart/test_access_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,34 @@ def test_template_pattern_matches():
)
bounds = (10, 20)
tp = TemplatePattern(bounds, pattern)

# test matching pattern
sp_matching = SchedulePattern(bounds, pattern)
assert tp.matches(sp_matching)

# test non matching pattern
sp_non_matching_pattern = SchedulePattern(
bounds,
AffineMap(
num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0))
),
)
assert tp.matches(sp_non_matching_pattern) is False

# check pattern with wrong bounds (should be irellevant for template check)
sp_non_matching_bounds = SchedulePattern((5, 15), pattern)
assert tp.matches(sp_non_matching_bounds)

assert tp.matches(sp_matching) is True
assert tp.matches(sp_non_matching_pattern) is False
assert (
tp.matches(sp_non_matching_bounds) is True
) # Bounds are not checked in matches
# if the schedule has higher dimensionality than the template, only the innermost
# are considered
larger_pattern = AffineMap(
num_dims=3,
num_symbols=0,
results=(AffineDimExpr(1) + AffineDimExpr(0), AffineDimExpr(2)),
)
larger_bounds = (44, 10, 20)
sp_matching_larger = SchedulePattern(larger_bounds, larger_pattern)
assert tp.matches(sp_matching_larger)


def test_schedule_rotate():
Expand Down

0 comments on commit b26d720

Please sign in to comment.