Skip to content

Commit

Permalink
compiler: generate std:complex for cpp compilers
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 28, 2024
1 parent d2dc9ee commit c824b53
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 20 deletions.
35 changes: 22 additions & 13 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ def __init__(self, *args, compiler=None, **kwargs):
}
_restrict_keyword = 'restrict'

def _complex_type(self, ctypestr):
if ctypestr not in ('float', 'double'):
return ctypestr
if self._compiler._cpp:
return 'std::complex<%s>' % ctypestr
else:
return '%s _Complex' % ctypestr

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand Down Expand Up @@ -243,10 +251,10 @@ def _gen_value(self, obj, mode=1, masked=()):
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = obj._C_typedata
strtype = self._complex_type(obj._C_typedata)
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
else:
strtype = ctypes_to_cstr(obj._C_ctype)
strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype))
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
Expand Down Expand Up @@ -376,10 +384,11 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = self._complex_type(i._C_typedata)

if f.is_PointerArray:
# lvalue
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
lvalue = c.Value(cstr, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -388,7 +397,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (i._C_typedata, v)
rvalue = '(%s**) %s' % (cstr, v)

else:
# lvalue
Expand All @@ -399,10 +408,10 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
lvalue = c.Value(cstr, '*%s' % v)
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

Expand All @@ -415,30 +424,30 @@ def visit_PointerCast(self, o):
else:
assert False

rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (cstr, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = self._complex_type(i._C_typedata)
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
"""
lib = _complex_lib.get(language, 'complex.h')
lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h')

headers = {}
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
if compiler._cpp:
headers = {('_Complex_I', ('1.0fi'))}
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}

for f in FindSymbols().visit(iet):
try:
Expand Down
7 changes: 2 additions & 5 deletions devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ def dtype_to_ctype(dtype):
# Complex data
if np.issubdtype(dtype, np.complexfloating):
rtype = dtype(0).real.__class__
ctname = '%s _Complex' % dtype_to_cstr(rtype)
ctype = dtype_to_ctype(rtype)
r = type(ctname, (ctype,), {})
return r
return dtype_to_ctype(rtype)

try:
return ctypes_vector_mapper[dtype]
Expand Down Expand Up @@ -217,7 +214,7 @@ class c_restrict_void_p(ctypes.c_void_p):
# *** ctypes lowering


def ctypes_to_cstr(ctype, toarray=None):
def ctypes_to_cstr(ctype, toarray=None, cpp=False):
"""Translate ctypes types into C strings."""
if ctype in ctypes_vector_mapper.values():
retval = ctype.__name__
Expand Down

0 comments on commit c824b53

Please sign in to comment.