From a96a8036648418738a7df65f836812e85ac9ad90 Mon Sep 17 00:00:00 2001 From: FabioLuporini Date: Thu, 21 Sep 2023 07:40:51 +0000 Subject: [PATCH] compiler: Patch unexpansion with double precision --- devito/passes/clusters/derivatives.py | 2 +- tests/test_unexpansion.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 222b665af46..897fe875f8b 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -74,7 +74,7 @@ def _core(expr, c, weights, mapper, sregistry): try: w = weights[k] except KeyError: - w = weights[k] = w0._rebuild(name=name) + w = weights[k] = w0._rebuild(name=name, dtype=expr.dtype) expr = uxreplace(expr, {w0.indexed: w.indexed}) dims = retrieve_dimensions(expr, deep=True) diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index fa076096a3a..fd8c9305f15 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -280,6 +280,20 @@ def test_v1(self): op.cfunction +class TestMisc(object): + + def test_double_precision(self): + grid = Grid(shape=(10, 10, 10), dtype=np.float64) + + u = TimeFunction(name='u', grid=grid, space_order=4) + + eqns = Eq(u.forward, u.laplace + 1.) + + op = Operator(eqns, opt=('advanced', {'expand': False})) + + op.cfunction + + def tti_sa_eqns(grid): t = grid.stepping_dim x, y, z = grid.dimensions