Skip to content

Commit

Permalink
compiler: add std::complex arithmetic defs for unsupported types
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 30, 2024
1 parent 5a6b169 commit 014ef2c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
3 changes: 2 additions & 1 deletion devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sympy import IndexedBase
from sympy.core.function import Application

from devito.parameters import configuration
from devito.exceptions import VisitorException
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
Expand Down Expand Up @@ -177,7 +178,7 @@ class CGen(Visitor):

def __init__(self, *args, compiler=None, **kwargs):
super().__init__(*args, **kwargs)
self._compiler = compiler
self._compiler = compiler or configuration['compiler']

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand Down
33 changes: 31 additions & 2 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,26 @@ def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
"""
lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h')
lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)

headers = {}

# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
if compiler._cpp:
# Constant I
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}
# Mix arithmetic definitions
dest = compiler.get_jit_dir()
hfile = dest.joinpath('stdcomplex_arith.h')
if not hfile.is_file():
with open(str(hfile), 'w') as ff:
ff.write(str(_stdcomplex_defs))
lib += (str(hfile),)

for f in FindSymbols().visit(iet):
try:
if np.issubdtype(f.dtype, np.complexfloating):
return iet, {'includes': (lib,), 'headers': headers}
return iet, {'includes': lib, 'headers': headers}
except TypeError:
pass

Expand Down Expand Up @@ -295,3 +304,23 @@ def _rename_subdims(target, dimensions):
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}


_stdcomplex_defs = """
#include <complex>
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() * a, b.imag() * a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() / a, b.imag() / a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() + a, b.imag());
}
"""
8 changes: 6 additions & 2 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def dtype(self):
def compiler(self):
return self._settings['compiler']

@property
def cpp(self):
return self.compiler._cpp

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
return "(%s)" % self._print(item)
Expand Down Expand Up @@ -101,7 +105,7 @@ def _print_math_func(self, expr, nest=False, known=None):
return super()._print_math_func(expr, nest=nest, known=known)

dtype = sympy_dtype(expr)
if np.issubdtype(dtype, np.complexfloating):
if np.issubdtype(dtype, np.complexfloating) and not self.cpp:
cname = 'c%s' % cname
dtype = self.dtype(0).real.dtype.type

Expand Down Expand Up @@ -255,7 +259,7 @@ def _print_ComponentAccess(self, expr):
def _print_TrigonometricFunction(self, expr):
func_name = str(expr.func)
dtype = self.dtype
if np.issubdtype(dtype, np.complexfloating):
if np.issubdtype(dtype, np.complexfloating) and not self.cpp:
func_name = 'c%s' % func_name
dtype = self.dtype(0).real.dtype.type
if dtype == np.float32:
Expand Down

0 comments on commit 014ef2c

Please sign in to comment.