diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index d574d8c235d..13c7c712612 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -14,6 +14,7 @@ from sympy import IndexedBase from sympy.core.function import Application +from devito.parameters import configuration from devito.exceptions import VisitorException from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -177,7 +178,7 @@ class CGen(Visitor): def __init__(self, *args, compiler=None, **kwargs): super().__init__(*args, **kwargs) - self._compiler = compiler + self._compiler = compiler or configuration['compiler'] # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 5ec6207977a..7fb18bb9101 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -200,17 +200,26 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h') + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + # Constant I headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))} + # Mix arithmetic definitions + dest = compiler.get_jit_dir() + hfile = dest.joinpath('stdcomplex_arith.h') + if not hfile.is_file(): + with open(str(hfile), 'w') as ff: + ff.write(str(_stdcomplex_defs)) + lib += (str(hfile),) for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,), 'headers': headers} + return iet, {'includes': lib, 'headers': headers} except TypeError: pass @@ -295,3 +304,23 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} + + +_stdcomplex_defs = """ +#include <complex> + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template<typename _Tp, typename _Ti> +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} +""" diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 35038d868f2..9b6b9ab815f 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -39,6 +39,10 @@ def dtype(self): def compiler(self): return self._settings['compiler'] + @property + def cpp(self): + return self.compiler._cpp + def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): return "(%s)" % self._print(item) @@ -101,7 +105,7 @@ def _print_math_func(self, expr, nest=False, known=None): return super()._print_math_func(expr, nest=nest, known=known) dtype = sympy_dtype(expr) - if np.issubdtype(dtype, np.complexfloating): + if np.issubdtype(dtype, np.complexfloating) and not self.cpp: cname = 'c%s' % cname dtype = self.dtype(0).real.dtype.type @@ -255,7 +259,7 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) dtype = self.dtype - if np.issubdtype(dtype, np.complexfloating): + if np.issubdtype(dtype, np.complexfloating) and not self.cpp: func_name = 'c%s' % func_name dtype = self.dtype(0).real.dtype.type if dtype == np.float32: