Skip to content

Commit

Permalink
compiler: fix internal language specific types and cast
Browse files Browse the repository at this point in the history
wip
  • Loading branch information
mloubout committed Jun 21, 2024
1 parent abf483c commit 3b0d7e5
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 78 deletions.
3 changes: 2 additions & 1 deletion devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def version(self):
@property
def _complex_ctype(self):
"""
Type definition for complex numbers. THese two cases cover 99% of the cases since
Type definition for complex numbers. These two cases cover 99% of the cases since
- Hip is now using std::complex
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
- Sycl supports std::complex
Expand Down Expand Up @@ -996,6 +996,7 @@ def __new_with__(self, **kwargs):
'nvc++': NvidiaCompiler,
'nvidia': NvidiaCompiler,
'cuda': CudaCompiler,
'nvcc': CudaCompiler,
'osx': ClangCompiler,
'intel': OneapiCompiler,
'icx': OneapiCompiler,
Expand Down
4 changes: 2 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, complex_include)
error_mapper, include_complex)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -468,7 +468,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Complex header if needed. Needs to be done before specialization
# as some specific cases require complex to be loaded first
complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler'])
include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler'])

# Specialize
graph = cls._specialize_iet(graph, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
from .complex import * # noqa
71 changes: 1 addition & 70 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from devito.types import FIndexed

__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
'generate_macros', 'minimize_symbols', 'complex_include']
'generate_macros', 'minimize_symbols']


@iet_pass
Expand Down Expand Up @@ -240,39 +240,6 @@ def minimize_symbols(iet):
return iet, {}


_complex_lib = {'cuda': 'thrust/complex.h'}


@iet_pass
def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
"""
# Check if there is complex numbers that always take dtype precedence
max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)])
if not np.issubdtype(max_dtype, np.complexfloating):
return iet, {}

lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)

headers = {}

# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
if compiler._cpp:
c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type)
# Constant I
headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))}
# Mix arithmetic definitions
dest = compiler.get_jit_dir()
hfile = dest.joinpath('stdcomplex_arith.h')
if not hfile.is_file():
with open(str(hfile), 'w') as ff:
ff.write(str(_stdcomplex_defs))
lib += (str(hfile),)

return iet, {'includes': lib, 'headers': headers}


def remove_redundant_moddims(iet):
key = lambda d: d.is_Modulo and d.origin is not None
mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)]
Expand Down Expand Up @@ -351,39 +318,3 @@ def _rename_subdims(target, dimensions):
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}


_stdcomplex_defs = """
#include <complex>
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() * a, b.imag() * a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
return std::complex<_Tp>(b.real() * a, b.imag() * a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
_Tp denom = b.real() * b.real () + b.imag() * b.imag()
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
return std::complex<_Tp>(b.real() / a, b.imag() / a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() + a, b.imag());
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
return std::complex<_Tp>(b.real() + a, b.imag());
}
"""
29 changes: 28 additions & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sympy import Expr, Function, Number, Tuple, sympify
from sympy.core.decorators import call_highest_priority

from devito import configuration
from devito.finite_differences.elementary import Min, Max
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
float3, float4, double2, double3, double4, int2, int3,
Expand Down Expand Up @@ -811,6 +812,20 @@ class VOID(Cast):
_base_typ = 'void'


class CFLOAT(Cast):

@property
def _base_typ(self):
return configuration['compiler']._complex_ctype('float')


class CDOUBLE(Cast):

@property
def _base_typ(self):
return configuration['compiler']._complex_ctype('double')


class CHARP(CastStar):
base = CHAR

Expand All @@ -827,6 +842,14 @@ class USHORTP(CastStar):
base = USHORT


class CFLOATP(CastStar):
base = CFLOAT


class CDOUBLEP(CastStar):
base = CDOUBLE


cast_mapper = {
np.int8: CHAR,
np.uint8: UCHAR,
Expand All @@ -839,6 +862,8 @@ class USHORTP(CastStar):
np.float32: FLOAT, # noqa
float: DOUBLE, # noqa
np.float64: DOUBLE, # noqa
np.complex64: CFLOAT, # noqa
np.complex128: CDOUBLE, # noqa

(np.int8, '*'): CHARP,
(np.uint8, '*'): UCHARP,
Expand All @@ -849,7 +874,9 @@ class USHORTP(CastStar):
(np.int64, '*'): INTP, # noqa
(np.float32, '*'): FLOATP, # noqa
(float, '*'): DOUBLEP, # noqa
(np.float64, '*'): DOUBLEP # noqa
(np.float64, '*'): DOUBLEP, # noqa
(np.complex64, '*'): CFLOATP, # noqa
(np.complex128, '*'): CDOUBLEP, # noqa
}

for base_name in ['int', 'float', 'double']:
Expand Down
2 changes: 0 additions & 2 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def test_complex(self, dtype):
u = Function(name="u", grid=grid, dtype=dtype)

eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
print(op)
op()

# Check against numpy
Expand Down
2 changes: 0 additions & 2 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,7 @@ def test_complex(self, dtype):
u = Function(name="u", grid=grid, dtype=dtype)

eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
# print(op)
op()

# Check against numpy
Expand Down

0 comments on commit 3b0d7e5

Please sign in to comment.