Skip to content

Commit

Permalink
Merge pull request #2406 from devitocodes/tens-fix
Browse files Browse the repository at this point in the history
misc: minor miscelanous fixes
  • Loading branch information
mloubout authored Jul 12, 2024
2 parents 0124871 + de4837d commit 3024584
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def _eval_matrix_mul(self, other):
new_mat[i] = sum(vec)

# Get new class and return product
newcls = self.classof_prod(other, new_mat)
newcls = self.classof_prod(other, other.cols)
return newcls._new(self.rows, other.cols, new_mat, copy=False)


Expand Down
6 changes: 4 additions & 2 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,10 @@ def _C_get_field(self, region, dim, side=None):

def _halo_exchange(self):
"""Perform the halo exchange with the neighboring processes."""
if not MPI.Is_initialized() or MPI.COMM_WORLD.size == 1 or \
not configuration['mpi']:
if not MPI.Is_initialized() or \
MPI.COMM_WORLD.size == 1 or \
not configuration['mpi'] or \
self.grid is None:
# Nothing to do
return
if MPI.COMM_WORLD.size > 1 and self._distributor is None:
Expand Down
7 changes: 2 additions & 5 deletions devito/types/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,8 @@ def new_from_mat(self, mat):
func = tens_func(self)
return func._new(self.rows, self.cols, mat)

def classof_prod(self, other, mat):
try:
is_mat = len(mat[0]) > 1
except TypeError:
is_mat = False
def classof_prod(self, other, cols):
is_mat = cols > 1
is_time = (getattr(self, '_is_TimeDependent', False) or
getattr(other, '_is_TimeDependent', False))
return mat_time_dict[(is_time, is_mat)]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def test_tensor_matmul(func1, func2, out_type):


@pytest.mark.parametrize('func1, func2, out_type', [
(VectorFunction, TensorFunction, VectorFunction),
(VectorTimeFunction, TensorFunction, VectorTimeFunction),
(VectorFunction, TensorTimeFunction, VectorTimeFunction),
(VectorTimeFunction, TensorTimeFunction, VectorTimeFunction)])
(VectorFunction, TensorFunction, TensorFunction),
(VectorTimeFunction, TensorFunction, TensorTimeFunction),
(VectorFunction, TensorTimeFunction, TensorTimeFunction),
(VectorTimeFunction, TensorTimeFunction, TensorTimeFunction)])
def test_tensor_matmul_T(func1, func2, out_type):
grid = Grid(tuple([5]*3))
f1 = func1(name="f1", grid=grid)
Expand Down

0 comments on commit 3024584

Please sign in to comment.