Skip to content

Commit

Permalink
compiler: rework dtype lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jun 21, 2024
1 parent 3b0d7e5 commit baaa76b
Show file tree
Hide file tree
Showing 18 changed files with 330 additions and 175 deletions.
18 changes: 0 additions & 18 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,6 @@ def version(self):

return version

@property
def _complex_ctype(self):
"""
Type definition for complex numbers. These two cases cover 99% of the cases since
- Hip is now using std::complex
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
- Sycl supports std::complex
- C's _Complex is part of C99
"""
if self._cpp:
return lambda dtype: 'std::complex<%s>' % str(dtype)
else:
return lambda dtype: '%s _Complex' % str(dtype)

def get_version(self):
result, stdout, stderr = call_capture_output((self.cc, "--version"))
if result != 0:
Expand Down Expand Up @@ -713,10 +699,6 @@ def __lookup_cmds__(self):
self.MPICC = 'nvcc'
self.MPICXX = 'nvcc'

@property
def _complex_ctype(self):
return lambda dtype: 'thrust::complex<%s>' % str(dtype)


class HipCompiler(Compiler):

Expand Down
2 changes: 1 addition & 1 deletion devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
'blocking', 'tasking', 'streaming', 'factorize', 'fission', 'fuse', 'lift',
'cire-sops', 'cse', 'opt-pows', 'topofuse',
# IET
'orchestrate', 'pthreadify', 'parallel', 'mpi', 'linearize', 'prodders'
'orchestrate', 'pthreadify', 'parallel', 'mpi', 'linearize', 'prodders', 'dtypes'
)
_known_passes_disabled = ('denormals', 'simd')
assert not (set(_known_passes) & set(_known_passes_disabled))
Expand Down
6 changes: 1 addition & 5 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, include_complex)
error_mapper)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -466,10 +466,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Complex header if needed. Needs to be done before specialization
# as some specific cases require complex to be loaded first
include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler'])

# Specialize
graph = cls._specialize_iet(graph, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
from .complex import * # noqa
from .dtypes import * # noqa
12 changes: 11 additions & 1 deletion devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -73,10 +74,12 @@ class DataManager:
The language used to express data allocations, deletions, and host-device transfers.
"""

def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
def __init__(self, rcompile=None, sregistry=None, platform=None,
compiler=None, **kwargs):
self.rcompile = rcompile
self.sregistry = sregistry
self.platform = platform
self.compiler = compiler

def _alloc_object_on_low_lat_mem(self, site, obj, storage):
"""
Expand Down Expand Up @@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs):

return iet, {}

@iet_pass
def make_langtypes(self, iet):
iet, metadata = lower_complex(iet, self.lang, self.compiler)
return iet, metadata

def process(self, graph):
"""
Apply the `place_definitions` and `place_casts` passes.
"""
self.place_definitions(graph, globs=set())
self.place_casts(graph)
self.make_langtypes(graph)


class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -564,6 +573,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)


def make_zero_init(obj):
Expand Down
55 changes: 55 additions & 0 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np

from devito.ir import FindSymbols, Uxreplace

__all__ = ['lower_complex']


def lower_complex(iet, lang, compiler):
"""
Add headers for complex arithmetic
"""
# Check if there is complex numbers that always take dtype precedence
types = [f.dtype for f in FindSymbols().visit(iet) if isinstance(f.dtype, np.dtype)]
if not types:
return iet, {}

max_dtype = np.result_type(*types)
if not np.issubdtype(max_dtype, np.complexfloating):
return iet, {}

lib = (lang['header-complex'],)
headers = lang.get('I-def')

# Some languges such as c++11 need some extra arithmetic definitions
if lang.get('def-complex'):
dest = compiler.get_jit_dir()
hfile = dest.joinpath('complex_arith.h')
with open(str(hfile), 'w') as ff:
ff.write(str(lang['def-complex']))
lib += (str(hfile),)

iet = _complex_dtypes(iet, lang)

return iet, {'includes': lib, 'headers': headers}


def _complex_dtypes(iet, lang):
"""
Lower dtypes to language specific types
"""
mapper = {}

for s in FindSymbols('indexeds').visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

for s in FindSymbols().visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

body = Uxreplace(mapper).visit(iet.body)
params = Uxreplace(mapper).visit(iet.parameters)
iet = iet._rebuild(body=body, parameters=params)

return iet
11 changes: 11 additions & 0 deletions devito/passes/iet/langbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __getitem__(self, k):
raise NotImplementedError("Missing required mapping for `%s`" % k)
return self.mapper[k]

def get(self, k):
return self.mapper.get(k)


class LangBB(metaclass=LangMeta):

Expand Down Expand Up @@ -200,6 +203,14 @@ def initialize(self, iet, options=None):
"""
return iet, {}

@iet_pass
def make_langtypes(self, iet):
"""
An `iet_pass` which transforms an IET such that the target language
types are introduced.
"""
return iet, {}

@property
def Region(self):
return self.lang.Region
Expand Down
13 changes: 12 additions & 1 deletion devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.tools import CustomNpType

__all__ = ['CBB', 'CDataManager', 'COrchestrator']


CCFloat = CustomNpType('_Complex float', np.complex64)
CCDouble = CustomNpType('_Complex double', np.complex128)


class CBB(LangBB):

mapper = {
Expand All @@ -19,7 +26,11 @@ class CBB(LangBB):
'host-free-pin': lambda i:
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k))
Call('memcpy', (i, j, k)),
# Complex
'header-complex': 'complex.h',
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
'I-def': None
}


Expand Down
68 changes: 68 additions & 0 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np

from devito.ir import Call
from devito.passes.iet.langbase import LangBB
from devito.tools import CustomNpType

__all__ = ['CXXBB']


std_arith = """
#include <complex>
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() * a, b.imag() * a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
return std::complex<_Tp>(b.real() * a, b.imag() * a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
_Tp denom = b.real() * b.real () + b.imag() * b.imag()
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
return std::complex<_Tp>(b.real() / a, b.imag() / a);
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
return std::complex<_Tp>(b.real() + a, b.imag());
}
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
return std::complex<_Tp>(b.real() + a, b.imag());
}
"""

CXXCFloat = CustomNpType('std::complex', np.complex64, template='float')
CXXCDouble = CustomNpType('std::complex', np.complex128, template='double')


class CXXBB(LangBB):

mapper = {
'header-memcpy': 'string.h',
'host-alloc': lambda i, j, k:
Call('posix_memalign', (i, j, k)),
'host-alloc-pin': lambda i, j, k:
Call('posix_memalign', (i, j, k)),
'host-free': lambda i:
Call('free', (i,)),
'host-free-pin': lambda i:
Call('free', (i,)),
'alloc-global-symbol': lambda i, j, k:
Call('memcpy', (i, j, k)),
# Complex
'header-complex': '<complex>',
'I-def': (('_Complex_I', ('std::complex<float>(0.0, 1.0)')),),
'def-complex': std_arith,
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
}
5 changes: 3 additions & 2 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB,
PragmaIteration, PragmaTransfer)
from devito.passes.iet.languages.C import CBB
from devito.passes.iet.languages.CXX import CXXBB
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
from devito.tools import filter_ordered, UnboundTuple
Expand Down Expand Up @@ -118,7 +118,8 @@ class AccBB(PragmaLangBB):
'device-free': lambda i, *a:
Call('acc_free', (i,))
}
mapper.update(CBB.mapper)

mapper.update(CXXBB.mapper)

Region = OmpRegion
HostIteration = OmpIteration # Host parallelism still goes via OpenMP
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
ccode)
from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr
from devito.tools import Bunch, as_mapper, filter_ordered, split
from devito.types import FIndexed

__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
Expand Down
1 change: 1 addition & 0 deletions devito/symbolics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from devito.symbolics.extended_sympy import * # noqa
from devito.symbolics.extended_dtypes import * # noqa
from devito.symbolics.queries import * # noqa
from devito.symbolics.search import * # noqa
from devito.symbolics.printer import * # noqa
Expand Down
Loading

0 comments on commit baaa76b

Please sign in to comment.