Skip to content

Commit

Permalink
misc: rework multituple fir easier use
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 25, 2023
1 parent ccd4f72 commit c5733c7
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 55 deletions.
10 changes: 3 additions & 7 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def _make_clauses(cls, ncollapsed=0, reduction=None, tile=None, **kwargs):
clauses = []

if tile:
clauses.append('tile(%s)' % ','.join(str(i) for i in tile))
n = max(len(tile), ncollapsed)
clauses.append('tile(%s)' % ','.join(str(tile.next()) for _ in range(n)))
elif ncollapsed > 1:
clauses.append('collapse(%d)' % ncollapsed)

Expand Down Expand Up @@ -164,13 +165,8 @@ def _make_partree(self, candidates, nthreads=None):
if self._is_offloadable(root) and \
all(i.is_Affine for i in [root] + collapsable) and \
self.par_tile:
tile = self.par_tile.next()
tile = tuple(self.par_tile.next().next() for n in range(ncollapsable))
assert isinstance(tile, tuple)
nremainder = (ncollapsable + 1) - len(tile)
if nremainder >= 0:
tile += (tile[-1],)*nremainder
else:
tile = tile[:ncollapsable + 1]

body = self.DeviceIteration(gpu_fit=self.gpu_fit, tile=tile,
ncollapsed=ncollapsable, **root.args)
Expand Down
117 changes: 69 additions & 48 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,65 @@ def __hash__(self):
return self._hash


class UnboundedMultiTuple(object):
class UnboundTuple(object):
"""
An UnboundedTuple is a tuple that can be
infinitely iterated over.
Examples
--------
>>> ub = UnboundTuple((1,2,3))
>>> ub
UnboundTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.next()
UnboundTuple(1, 2)
>>> ub.next()
UnboundTuple(3, 4)
>>> ub.next()
UnboundTuple(3, 4)
"""

def __init__(self, *items):
nitems = []
for i in as_tuple(items):
if isinstance(i, Iterable):
nitems.append(UnboundTuple(*i))
elif i is not None:
nitems.append(i)

self.items = tuple(nitems)
self.last = len(self.items)
self.current = 0

@property
def default(self):
return self.items[0]

def next(self):
if self.last == 0:
return None
item = self.items[self.current]
if self.current == self.last-1 or self.current == -1:
self.current = -1
else:
self.current += 1
return item

def __len__(self):
return self.last

def __repr__(self):
sitems = [s.__repr__() for s in self.items]
return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems))

def __getitem__(self, i):
if i > self.last:
return self.items[self.last]
else:
return self.items[i]


class UnboundedMultiTuple(UnboundTuple):

"""
An UnboundedMultiTuple is an ordered collection of tuples that can be
Expand All @@ -562,10 +620,10 @@ class UnboundedMultiTuple(object):
--------
>>> ub = UnboundedMultiTuple([1, 2], [3, 4])
>>> ub
UnboundedMultiTuple((1, 2), (3, 4))
UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.iter()
>>> ub
UnboundedMultiTuple(*(1, 2), (3, 4))
UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.next()
1
>>> ub.next()
Expand All @@ -574,7 +632,7 @@ class UnboundedMultiTuple(object):
>>> ub.iter() # No effect, tip has reached the last tuple
>>> ub.iter() # No effect, tip has reached the last tuple
>>> ub
UnboundedMultiTuple((1, 2), *(3, 4))
UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.next()
3
>>> ub.next()
Expand All @@ -585,52 +643,15 @@ class UnboundedMultiTuple(object):
"""

def __init__(self, *items):
# Normalize input
nitems = []
for i in as_tuple(items):
if isinstance(i, Iterable):
nitems.append(tuple(i))
else:
raise ValueError("Expected sequence, got %s" % type(i))

self.items = tuple(nitems)
self.tip = -1
self.curiter = None

def __repr__(self):
items = [str(i) for i in self.items]
if self.curiter is not None:
items[self.tip] = "*%s" % items[self.tip]
return "%s(%s)" % (self.__class__.__name__, ", ".join(items))
super().__init__(*items)
self.current = -1

def iter(self):
if not self.items:
raise ValueError("No tuples available")
self.tip = min(self.tip + 1, max(len(self.items) - 1, 0))
self.curiter = iter(self.items[self.tip])
self.current = min(self.current + 1, self.last - 1)
self.items[self.current].current = 0
return

def next(self):
if self.curiter is None:
if self.items[self.current].current == -1:
raise StopIteration
return next(self.curiter)


class UnboundTuple(object):
"""
A simple data structure that returns the last element forever once reached
"""

def __init__(self, items):
self.items = as_tuple(items)
self.last = len(self.items)
self.current = 0

def next(self):
if self.last == 0:
return None
item = self.items[self.current]
self.current = min(self.last - 1, self.current+1)
return item

def __len__(self):
return self.last
return self.items[self.current].next()

0 comments on commit c5733c7

Please sign in to comment.