From 3b0d7e51aa4856eec0147fbf7741c80bab6b3e77 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 31 May 2024 09:58:54 -0400 Subject: [PATCH] compiler: fix internal language specific types and cast wip --- devito/arch/compiler.py | 3 +- devito/operator/operator.py | 4 +- devito/passes/iet/__init__.py | 1 + devito/passes/iet/misc.py | 71 +----------------------------- devito/symbolics/extended_sympy.py | 29 +++++++++++- tests/test_gpu_common.py | 2 - tests/test_operator.py | 2 - 7 files changed, 34 insertions(+), 78 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index e056ff82b7a..5df3891074d 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -248,7 +248,7 @@ def version(self): @property def _complex_ctype(self): """ - Type definition for complex numbers. THese two cases cover 99% of the cases since + Type definition for complex numbers. These two cases cover 99% of the cases since - Hip is now using std::complex https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - Sycl supports std::complex @@ -996,6 +996,7 @@ def __new_with__(self, **kwargs): 'nvc++': NvidiaCompiler, 'nvidia': NvidiaCompiler, 'cuda': CudaCompiler, + 'nvcc': CudaCompiler, 'osx': ClangCompiler, 'intel': OneapiCompiler, 'icx': OneapiCompiler, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 3e4e8c4e346..0e4b07379a7 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -22,7 +22,7 @@ from devito.parameters import configuration from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, - error_mapper, complex_include) + error_mapper, include_complex) from devito.symbolics import estimate_cost from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, @@ -468,7 +468,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done before specialization # as some specific cases require complex to be loaded first - complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) + include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index c09db00c9b9..6b4ada0b737 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,3 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa +from .complex import * # noqa diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 53ebe7d3e85..50511b6005f 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -16,7 +16,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols', 'complex_include'] + 'generate_macros', 'minimize_symbols'] @iet_pass @@ -240,39 +240,6 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'thrust/complex.h'} - - -@iet_pass -def complex_include(iet, language, compiler): - """ - Add headers for complex arithmetic - """ - # Check if there is complex numbers that always take dtype precedence - max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) - if not np.issubdtype(max_dtype, np.complexfloating): - return iet, {} - - lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) - - headers = {} - - # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise - if compiler._cpp: - c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) - # Constant I - headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} - # Mix arithmetic definitions - dest = compiler.get_jit_dir() - hfile = dest.joinpath('stdcomplex_arith.h') - if not hfile.is_file(): - with open(str(hfile), 'w') as ff: - ff.write(str(_stdcomplex_defs)) - lib += (str(hfile),) - - return iet, {'includes': lib, 'headers': headers} - - def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] @@ -351,39 +318,3 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} - - -_stdcomplex_defs = """ -#include - -template -std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ - _Tp denom = b.real() * b.real () + b.imag() * b.imag() - return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); -} - -template -std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() / a, b.imag() / a); -} - -template -std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} - -template -std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} -""" diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 7ed801d17a0..03fec7438af 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority +from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -811,6 +812,20 @@ class VOID(Cast): _base_typ = 'void' +class CFLOAT(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('float') + + +class CDOUBLE(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('double') + + class CHARP(CastStar): base = CHAR @@ -827,6 +842,14 @@ class USHORTP(CastStar): base = USHORT +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + cast_mapper = { np.int8: CHAR, np.uint8: UCHAR, @@ -839,6 +862,8 @@ class USHORTP(CastStar): np.float32: FLOAT, # noqa float: DOUBLE, # noqa np.float64: DOUBLE, # noqa + np.complex64: CFLOAT, # noqa + np.complex128: CDOUBLE, # noqa (np.int8, '*'): CHARP, (np.uint8, '*'): UCHARP, @@ -849,7 +874,9 @@ class USHORTP(CastStar): (np.int64, '*'): INTP, # noqa (np.float32, '*'): FLOATP, # noqa (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP # noqa + (np.float64, '*'): DOUBLEP, # noqa + (np.complex64, '*'): CFLOATP, # noqa + (np.complex128, '*'): CDOUBLEP, # noqa } for base_name in ['int', 'float', 'double']: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 1c3f574c8ce..846a9734621 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -74,9 +74,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index c1a88093796..283249aac16 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -647,9 +647,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - # print(op) op() # Check against numpy