diff --git a/devito/dle/backends/yask.py b/devito/dle/backends/yask.py index 1a690cda4a..b11025395c 100644 --- a/devito/dle/backends/yask.py +++ b/devito/dle/backends/yask.py @@ -1,4 +1,3 @@ -import numbers import os import sys @@ -10,7 +9,7 @@ from devito.dle import retrieve_iteration_tree from devito.dle.backends import BasicRewriter, dle_pass from devito.exceptions import CompilationError, DLEException -from devito.logger import debug, dle, dle_warning, error +from devito.logger import debug, dle, dle_warning from devito.visitors import FindSymbols from devito.tools import as_tuple @@ -107,7 +106,7 @@ def __init__(self, name, shape, dimensions, dtype, buffer=None): self.grid = YASK.setdefault(name) # Always init the grid, at least with 0.0 - self[:] = 0.0 if buffer is None else val + self[:] = 0.0 if buffer is None else buffer def __getitem__(self, index): # TODO: ATM, no MPI support. diff --git a/tests/test_yask.py b/tests/test_yask.py index 46c920ecaa..0f7f98715f 100644 --- a/tests/test_yask.py +++ b/tests/test_yask.py @@ -39,28 +39,28 @@ def test_data_movement_nD(): u.data[0, 1, 1] = 1. assert u.data[0, 0, 0] == 0. assert u.data[0, 1, 1] == 1. - assert np.all(u.data == u.data[:,:,:]) + assert np.all(u.data == u.data[:, :, :]) assert 1. in u.data[0] assert 1. in u.data[0, 1] # Test negative indices assert u.data[0, -9, -9] == 1. - u.data[6,0,0] = 1. - assert u.data[-4,:,:].sum() == 1. + u.data[6, 0, 0] = 1. + assert u.data[-4, :, :].sum() == 1. # Test setting whole array to given value u.data[:] = 3. assert np.all(u.data == 3.) # Test insertion of single value into block - u.data[5,:,5] = 5. - assert np.all(u.data[5,:,5] == 5.) + u.data[5, :, 5] = 5. + assert np.all(u.data[5, :, 5] == 5.) # Test insertion of block into block block = np.ndarray(shape=(1, 10, 1), dtype=np.float32) block.fill(4.) - u.data[4,:,4] = block - assert np.all(u.data[4,:,4] == block) + u.data[4, :, 4] = block + assert np.all(u.data[4, :, 4] == block) def test_data_arithmetic_nD(): @@ -78,9 +78,9 @@ def test_data_arithmetic_nD(): # Increments and parital increments u.data[:] += 2. assert np.all(u.data == 3.) - u.data[9,:,:] += 1. - assert all(np.all(u.data[i,:,:] == 3.) for i in range(9)) - assert np.all(u.data[9,:,:] == 4.) + u.data[9, :, :] += 1. + assert all(np.all(u.data[i, :, :] == 3.) for i in range(9)) + assert np.all(u.data[9, :, :] == 4.) # Right operations __rOP__ u.data[:] = 1.