diff --git a/devito/dle/__init__.py b/devito/dle/__init__.py index f7f213c87f..e1e708e079 100644 --- a/devito/dle/__init__.py +++ b/devito/dle/__init__.py @@ -1,4 +1,4 @@ from devito.dle.inspection import * # noqa from devito.dle.manipulation import * # noqa from devito.dle.transformer import * # noqa -from devito.dle.backends import init, make_grid # noqa +from devito.dle.backends import YaskGrid, init # noqa diff --git a/devito/dle/backends/yask.py b/devito/dle/backends/yask.py index 2de3d88367..1a690cda4a 100644 --- a/devito/dle/backends/yask.py +++ b/devito/dle/backends/yask.py @@ -14,7 +14,7 @@ from devito.visitors import FindSymbols from devito.tools import as_tuple -__all__ = ['YaskRewriter', 'init', 'make_grid'] +__all__ = ['YaskRewriter', 'init', 'YaskGrid'] YASK = None @@ -55,7 +55,7 @@ def grids(self): mapper[grid.get_name()] = grid return mapper - def setdefault(self, name, buffer=None): + def setdefault(self, name): """ Add and return a new grid ``name``. If a grid ``name`` already exists, then return it without performing any other actions. @@ -68,7 +68,7 @@ def setdefault(self, name, buffer=None): grid = self.hook_soln.new_grid(name, *self.dimensions) # Allocate memory self.hook_soln.prepare_solution() - return YaskGrid(name, grid, self.shape, self.dtype, buffer) + return grid class YaskGrid(object): @@ -84,11 +84,27 @@ class YaskGrid(object): # Force __rOP__ methods (OP={add,mul,...) to get arrays, not scalars, for efficiency __array_priority__ = 1000 - def __init__(self, name, grid, shape, dtype, buffer=None): + def __new__(cls, name, shape, dimensions, dtype, buffer=None): + """ + Create a new YASK Grid and attach it to a "fake" solution. + """ + # Init YASK if not initialized already + init(dimensions, shape, dtype) + # Only create a YaskGrid if the requested grid is a dense one + if tuple(i.name for i in dimensions) == YASK.dimensions: + obj = super(YaskGrid, cls).__new__(cls) + obj.__init__(name, shape, dimensions, dtype, buffer) + return obj + else: + return None + + def __init__(self, name, shape, dimensions, dtype, buffer=None): self.name = name self.shape = shape + self.dimensions = dimensions self.dtype = dtype - self.grid = grid + + self.grid = YASK.setdefault(name) # Always init the grid, at least with 0.0 self[:] = 0.0 if buffer is None else val @@ -161,6 +177,11 @@ def f(self, other): __mod__ = __meta_binop('__mod__') __rmod__ = __meta_binop('__mod__') + @property + def ndpointer(self): + # TODO: see corresponding comment in interfaces.py about CMemory + return self + class YaskRewriter(BasicRewriter): @@ -357,14 +378,6 @@ def init(dimensions, shape, dtype, architecture='hsw', isa='avx2'): dle("YASK backend successfully initialized!") -def make_grid(name, shape, dimensions, dtype): - """ - Create a new YASK Grid and attach it to a "fake" solution. - """ - init(dimensions, shape, dtype) - return YASK.setdefault(name) - - def _force_exit(emsg): """ Handle fatal errors. diff --git a/devito/interfaces.py b/devito/interfaces.py index f2063e32c1..892c69be99 100644 --- a/devito/interfaces.py +++ b/devito/interfaces.py @@ -305,15 +305,25 @@ def _allocate_memory(self): from devito.parameters import configuration if configuration['dle'] == 'yask': - from devito.dle import make_grid - self._data = make_grid(self.name, self.shape, self.indices, self.dtype) + # TODO: Use inheritance + # TODO: Refactor CMemory to be our _data_object, while _data will + # be the YaskGrid itself. + from devito.dle import YaskGrid + debug("Allocating YaskGrid for %s (%s)" % (self.name, str(self.shape))) + + self._data_object = YaskGrid(self.name, self.shape, self.indices, self.dtype) + if self._data_object is not None: + return + + debug("Failed. Reverting to plain allocation...") + + debug("Allocating memory for %s (%s)" % (self.name, str(self.shape))) + + self._data_object = CMemory(self.shape, dtype=self.dtype) + if self.numa: + first_touch(self) else: - debug("Allocating memory for %s (%s)" % (self.name, str(self.shape))) - self._data_object = CMemory(self.shape, dtype=self.dtype) - if self.numa: - first_touch(self) - else: - self.data.fill(0) + self.data.fill(0) @property def data(self):