From c5733c70d68564857627ec1fdb153ca14b80f035 Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 25 Sep 2023 10:17:32 -0400 Subject: [PATCH] misc: rework multituple fir easier use --- devito/passes/iet/languages/openacc.py | 10 +-- devito/tools/data_structures.py | 117 +++++++++++++++---------- 2 files changed, 72 insertions(+), 55 deletions(-) diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index 939a68f3047..1eb8f377fc5 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -30,7 +30,8 @@ def _make_clauses(cls, ncollapsed=0, reduction=None, tile=None, **kwargs): clauses = [] if tile: - clauses.append('tile(%s)' % ','.join(str(i) for i in tile)) + n = max(len(tile), ncollapsed) + clauses.append('tile(%s)' % ','.join(str(tile.next()) for _ in range(n))) elif ncollapsed > 1: clauses.append('collapse(%d)' % ncollapsed) @@ -164,13 +165,8 @@ def _make_partree(self, candidates, nthreads=None): if self._is_offloadable(root) and \ all(i.is_Affine for i in [root] + collapsable) and \ self.par_tile: - tile = self.par_tile.next() + tile = tuple(self.par_tile.next().next() for n in range(ncollapsable)) assert isinstance(tile, tuple) - nremainder = (ncollapsable + 1) - len(tile) - if nremainder >= 0: - tile += (tile[-1],)*nremainder - else: - tile = tile[:ncollapsable + 1] body = self.DeviceIteration(gpu_fit=self.gpu_fit, tile=tile, ncollapsed=ncollapsable, **root.args) diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index 95f9f65d0e9..828c7dd665f 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -552,7 +552,65 @@ def __hash__(self): return self._hash -class UnboundedMultiTuple(object): +class UnboundTuple(object): + """ + An UnboundedTuple is a tuple that can be + infinitely iterated over. + + Examples + -------- + >>> ub = UnboundTuple((1,2,3)) + >>> ub + UnboundTuple(UnboundTuple(1, 2), UnboundTuple(3, 4)) + >>> ub.next() + UnboundTuple(1, 2) + >>> ub.next() + UnboundTuple(3, 4) + >>> ub.next() + UnboundTuple(3, 4) + """ + + def __init__(self, *items): + nitems = [] + for i in as_tuple(items): + if isinstance(i, Iterable): + nitems.append(UnboundTuple(*i)) + elif i is not None: + nitems.append(i) + + self.items = tuple(nitems) + self.last = len(self.items) + self.current = 0 + + @property + def default(self): + return self.items[0] + + def next(self): + if self.last == 0: + return None + item = self.items[self.current] + if self.current == self.last-1 or self.current == -1: + self.current = -1 + else: + self.current += 1 + return item + + def __len__(self): + return self.last + + def __repr__(self): + sitems = [s.__repr__() for s in self.items] + return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems)) + + def __getitem__(self, i): + if i > self.last: + return self.items[self.last] + else: + return self.items[i] + + +class UnboundedMultiTuple(UnboundTuple): """ An UnboundedMultiTuple is an ordered collection of tuples that can be @@ -562,10 +620,10 @@ class UnboundedMultiTuple(object): -------- >>> ub = UnboundedMultiTuple([1, 2], [3, 4]) >>> ub - UnboundedMultiTuple((1, 2), (3, 4)) + UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4)) >>> ub.iter() >>> ub - UnboundedMultiTuple(*(1, 2), (3, 4)) + UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4)) >>> ub.next() 1 >>> ub.next() @@ -574,7 +632,7 @@ class UnboundedMultiTuple(object): >>> ub.iter() # No effect, tip has reached the last tuple >>> ub.iter() # No effect, tip has reached the last tuple >>> ub - UnboundedMultiTuple((1, 2), *(3, 4)) + UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4)) >>> ub.next() 3 >>> ub.next() @@ -585,52 +643,15 @@ class UnboundedMultiTuple(object): """ def __init__(self, *items): - # Normalize input - nitems = [] - for i in as_tuple(items): - if isinstance(i, Iterable): - nitems.append(tuple(i)) - else: - raise ValueError("Expected sequence, got %s" % type(i)) - - self.items = tuple(nitems) - self.tip = -1 - self.curiter = None - - def __repr__(self): - items = [str(i) for i in self.items] - if self.curiter is not None: - items[self.tip] = "*%s" % items[self.tip] - return "%s(%s)" % (self.__class__.__name__, ", ".join(items)) + super().__init__(*items) + self.current = -1 def iter(self): - if not self.items: - raise ValueError("No tuples available") - self.tip = min(self.tip + 1, max(len(self.items) - 1, 0)) - self.curiter = iter(self.items[self.tip]) + self.current = min(self.current + 1, self.last - 1) + self.items[self.current].current = 0 + return def next(self): - if self.curiter is None: + if self.items[self.current].current == -1: raise StopIteration - return next(self.curiter) - - -class UnboundTuple(object): - """ - A simple data structure that returns the last element forever once reached - """ - - def __init__(self, items): - self.items = as_tuple(items) - self.last = len(self.items) - self.current = 0 - - def next(self): - if self.last == 0: - return None - item = self.items[self.current] - self.current = min(self.last - 1, self.current+1) - return item - - def __len__(self): - return self.last + return self.items[self.current].next()