From b26d720af778915e212b540a43edbdbd6279d234 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Mon, 27 Jan 2025 15:32:28 +0100 Subject: [PATCH] template matching: only look at the inner dims of the schedule --- compiler/ir/dart/access_pattern.py | 2 ++ tests/ir/dart/test_access_pattern.py | 24 +++++++++++++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/compiler/ir/dart/access_pattern.py b/compiler/ir/dart/access_pattern.py index 001689ed..b06610b0 100644 --- a/compiler/ir/dart/access_pattern.py +++ b/compiler/ir/dart/access_pattern.py @@ -196,6 +196,8 @@ def matches(self, sp: SchedulePattern): Check if a given schedule pattern matches this template pattern. """ + if sp.num_dims > self.num_dims: + sp = sp.inner_dims(self.num_dims) if sp.num_dims != self.num_dims: return False if sp.pattern != self.pattern: diff --git a/tests/ir/dart/test_access_pattern.py b/tests/ir/dart/test_access_pattern.py index d62bbde7..0bbd6567 100644 --- a/tests/ir/dart/test_access_pattern.py +++ b/tests/ir/dart/test_access_pattern.py @@ -226,20 +226,34 @@ def test_template_pattern_matches(): ) bounds = (10, 20) tp = TemplatePattern(bounds, pattern) + + # test matching pattern sp_matching = SchedulePattern(bounds, pattern) + assert tp.matches(sp_matching) + + # test non matching pattern sp_non_matching_pattern = SchedulePattern( bounds, AffineMap( num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0)) ), ) + assert tp.matches(sp_non_matching_pattern) is False + + # check pattern with wrong bounds (should be irellevant for template check) sp_non_matching_bounds = SchedulePattern((5, 15), pattern) + assert tp.matches(sp_non_matching_bounds) - assert tp.matches(sp_matching) is True - assert tp.matches(sp_non_matching_pattern) is False - assert ( - tp.matches(sp_non_matching_bounds) is True - ) # Bounds are not checked in matches + # if the schedule has higher dimensionality than the template, only the innermost + # are considered + larger_pattern = AffineMap( + num_dims=3, + num_symbols=0, + results=(AffineDimExpr(1) + AffineDimExpr(0), AffineDimExpr(2)), + ) + larger_bounds = (44, 10, 20) + sp_matching_larger = SchedulePattern(larger_bounds, larger_pattern) + assert tp.matches(sp_matching_larger) def test_schedule_rotate():