Skip to content

Commit

Permalink
compiler: generate std:complex for cpp compilers
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 30, 2024
1 parent 841e4e7 commit 5a6b169
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 25 deletions.
43 changes: 30 additions & 13 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ctypes

import cgen as c
import numpy as np
from sympy import IndexedBase
from sympy.core.function import Application

Expand Down Expand Up @@ -188,6 +189,21 @@ def __init__(self, *args, compiler=None, **kwargs):
}
_restrict_keyword = 'restrict'

def _complex_type(self, ctypestr, dtype):
# Not complex
try:
if not np.issubdtype(dtype, np.complexfloating):
return ctypestr
except TypeError:
return ctypestr
# Complex only supported for float and double
if ctypestr not in ('float', 'double'):
return ctypestr
if self._compiler._cpp:
return 'std::complex<%s>' % ctypestr
else:
return '%s _Complex' % ctypestr

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand Down Expand Up @@ -243,10 +259,10 @@ def _gen_value(self, obj, mode=1, masked=()):
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = obj._C_typedata
strtype = self._complex_type(obj._C_typedata, obj.dtype)
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
else:
strtype = ctypes_to_cstr(obj._C_ctype)
strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
Expand Down Expand Up @@ -376,10 +392,11 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = self._complex_type(i._C_typedata, i.dtype)

if f.is_PointerArray:
# lvalue
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
lvalue = c.Value(cstr, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -388,7 +405,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (i._C_typedata, v)
rvalue = '(%s**) %s' % (cstr, v)

else:
# lvalue
Expand All @@ -399,10 +416,10 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
lvalue = c.Value(cstr, '*%s' % v)
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

Expand All @@ -415,30 +432,30 @@ def visit_PointerCast(self, o):
else:
assert False

rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (cstr, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = self._complex_type(i._C_typedata, i.dtype)
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
"""
lib = _complex_lib.get(language, 'complex.h')
lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h')

headers = {}
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
if compiler._cpp:
headers = {('_Complex_I', ('1.0fi'))}
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}

for f in FindSymbols().visit(iet):
try:
Expand Down
7 changes: 4 additions & 3 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def _print_math_func(self, expr, nest=False, known=None):
dtype = sympy_dtype(expr)
if np.issubdtype(dtype, np.complexfloating):
cname = 'c%s' % cname
dtype = self.dtype(0).real.dtype
dtype = self.dtype(0).real.dtype.type

if dtype is np.float32:
cname = '%sf' % cname
args = ', '.join((self._print(arg) for arg in expr.args))
Expand Down Expand Up @@ -194,7 +195,7 @@ def _print_Float(self, expr):
elif rv.startswith('.0'):
rv = '0.' + rv[2:]

if self.dtype == np.float32:
if self.dtype == np.float32 or self.dtype == np.complex64:
rv = rv + 'F'

return rv
Expand Down Expand Up @@ -256,7 +257,7 @@ def _print_TrigonometricFunction(self, expr):
dtype = self.dtype
if np.issubdtype(dtype, np.complexfloating):
func_name = 'c%s' % func_name
dtype = self.dtype(0).real.dtype
dtype = self.dtype(0).real.dtype.type
if dtype == np.float32:
func_name = '%sf' % func_name
return '%s(%s)' % (func_name, self._print(*expr.args))
Expand Down
7 changes: 2 additions & 5 deletions devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ def dtype_to_ctype(dtype):
# Complex data
if np.issubdtype(dtype, np.complexfloating):
rtype = dtype(0).real.__class__
ctname = '%s _Complex' % dtype_to_cstr(rtype)
ctype = dtype_to_ctype(rtype)
r = type(ctname, (ctype,), {})
return r
return dtype_to_ctype(rtype)

try:
return ctypes_vector_mapper[dtype]
Expand Down Expand Up @@ -217,7 +214,7 @@ class c_restrict_void_p(ctypes.c_void_p):
# *** ctypes lowering


def ctypes_to_cstr(ctype, toarray=None):
def ctypes_to_cstr(ctype, toarray=None, cpp=False):
"""Translate ctypes types into C strings."""
if ctype in ctypes_vector_mapper.values():
retval = ctype.__name__
Expand Down
3 changes: 2 additions & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import numpy as np
import sympy
import scipy.sparse

from conftest import assert_structure
Expand Down Expand Up @@ -72,7 +73,7 @@ def test_complex(self):
# Float32 complex is called complex64 in numpy
u = Function(name="u", grid=grid, dtype=np.complex64)

eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
op()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def test_complex(self):
# Float32 complex is called complex64 in numpy
u = Function(name="u", grid=grid, dtype=np.complex64)

eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
op()
Expand Down

0 comments on commit 5a6b169

Please sign in to comment.