From 62c4caa1303be73f0e78b6c2ad505c748d434f93 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Thu, 23 Nov 2023 18:28:53 +0000 Subject: [PATCH] compiler: regain correctness for indirections --- devito/ir/clusters/cluster.py | 7 ++++++- devito/ir/support/utils.py | 1 - devito/types/dimension.py | 2 +- examples/userapi/02_apply.ipynb | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 248e44f6688..7a0ab4e42e7 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -367,7 +367,12 @@ def dspace(self): for f, v in parts.items(): for i in v: if i.dim in oobs: - intervals = intervals.ceil(v[i.dim]) + try: + if intervals[i.dim].upper > v[i.dim].upper and \ + bool(i.dim in f.dimensions): + intervals = intervals.ceil(v[i.dim]) + except AttributeError: + pass return DataSpace(intervals, parts) diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index a2dc2b4c730..314e27e7923 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -127,7 +127,6 @@ def detect_accesses(exprs): mapper = defaultdict(Stencil) for e in retrieve_indexed(exprs, deep=True): f = e.function - for a, d0 in zip(e.indices, f.dimensions): if isinstance(a, Indirection): a = a.mapped diff --git a/devito/types/dimension.py b/devito/types/dimension.py index bcd3b9af9e5..839c60be609 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -342,7 +342,7 @@ def _arg_check(self, args, size, interval): # Autopadding causes non-integer upper limit from devito.symbolics import normalize_args upper = interval.upper.subs(normalize_args(args)) - if args[self.max_name] + upper > size: + if args[self.max_name] + upper >= size: raise InvalidArgument("OOB detected due to %s=%d" % (self.max_name, args[self.max_name])) diff --git a/examples/userapi/02_apply.ipynb b/examples/userapi/02_apply.ipynb index 80584051bda..f8f730dd522 100644 --- a/examples/userapi/02_apply.ipynb +++ b/examples/userapi/02_apply.ipynb @@ -250,14 +250,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "OOB detected due to time_M=3\n" + "OOB detected due to time_M=2\n" ] } ], "source": [ "from devito.exceptions import InvalidArgument\n", "try:\n", - " op.apply(time_M=3)\n", + " op.apply(time_M=2)\n", "except InvalidArgument as e:\n", " print(e)" ]