Skip to content

Commit

Permalink
arch: revamp compiler init for more robust custom compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 1, 2023
1 parent 2089885 commit 31a8728
Showing 1 changed file with 48 additions and 52 deletions.
100 changes: 48 additions & 52 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,20 @@ def __init__(self):
"""

fields = {'cc', 'ld'}
_cpp = False

def __init__(self, **kwargs):
super(Compiler, self).__init__(**kwargs)
super().__init__(**kwargs)

self.__lookup_cmds__()

self.suffix = kwargs.get('suffix')
if not kwargs.get('mpi'):
self.cc = self.CC if kwargs.get('cpp', False) is False else self.CXX
self.cc = self.CC if self._cpp is False else self.CXX
self.cc = self.cc if self.suffix is None else ('%s-%s' %
(self.cc, self.suffix))
else:
self.cc = self.MPICC if kwargs.get('cpp', False) is False else self.MPICXX
self.cc = self.MPICC if self._cpp is False else self.MPICXX
self.ld = self.cc # Wanted by the superclass

self.cflags = ['-O3', '-g', '-fPIC', '-Wall', '-std=c99']
Expand All @@ -196,7 +197,7 @@ def __init__(self, **kwargs):
self.defines = []
self.undefines = []

self.src_ext = 'c' if kwargs.get('cpp', False) is False else 'cpp'
self.src_ext = 'c' if self._cpp is False else 'cpp'

if platform.system() == "Linux":
self.so_ext = '.so'
Expand All @@ -216,6 +217,11 @@ def __init__(self, **kwargs):
# Knowing the version may still be useful to pick supported flags
self.version = sniff_compiler_version(self.CC)

self.__init_finalize__(**kwargs)

def __init_finalize__(self, **kwargs):
pass

def __new_with__(self, **kwargs):
"""
Create a new Compiler from an existing one, inherenting from it
Expand Down Expand Up @@ -394,9 +400,7 @@ def add_ldflags(self, flags):

class GNUCompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __init_finalize__(self, **kwargs):
platform = kwargs.pop('platform', configuration['platform'])

self.cflags += ['-Wno-unused-result',
Expand Down Expand Up @@ -443,9 +447,7 @@ def __lookup_cmds__(self):

class ArmCompiler(GNUCompiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __init_finalize__(self, **kwargs):
platform = kwargs.pop('platform', configuration['platform'])

# Graviton flag
Expand All @@ -455,8 +457,7 @@ def __init__(self, *args, **kwargs):

class ClangCompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init_finalize__(self, **kwargs):

self.cflags += ['-Wno-unused-result', '-Wno-unused-variable']
if not configuration['safe-math']:
Expand Down Expand Up @@ -522,8 +523,7 @@ class AOMPCompiler(Compiler):

"""AMD's fork of Clang for OpenMP offloading on both AMD and NVidia cards."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init_finalize__(self, **kwargs):

language = kwargs.pop('language', configuration['language'])
platform = kwargs.pop('platform', configuration['platform'])
Expand Down Expand Up @@ -556,8 +556,7 @@ def __lookup_cmds__(self):

class DPCPPCompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init_finalize__(self, **kwargs):

self.cflags += ['-qopenmp', '-fopenmp-targets=spir64']

Expand All @@ -572,8 +571,7 @@ def __lookup_cmds__(self):

class PGICompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, cpp=True, **kwargs)
def __init_finalize__(self, **kwargs):

self.cflags.remove('-std=c99')
self.cflags.remove('-O3')
Expand Down Expand Up @@ -618,8 +616,9 @@ def __lookup_cmds__(self):

class CudaCompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, cpp=True, **kwargs)
_cpp = True

def __init_finalize__(self, **kwargs):

self.cflags.remove('-std=c99')
self.cflags.remove('-Wall')
Expand Down Expand Up @@ -683,8 +682,9 @@ def __lookup_cmds__(self):

class HipCompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, cpp=True, **kwargs)
_cpp = True

def __init_finalize__(self, **kwargs):

self.cflags.remove('-std=c99')
self.cflags.remove('-Wall')
Expand Down Expand Up @@ -712,8 +712,7 @@ def __lookup_cmds__(self):

class IntelCompiler(Compiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init_finalize__(self, **kwargs):

platform = kwargs.pop('platform', configuration['platform'])
language = kwargs.pop('language', configuration['language'])
Expand Down Expand Up @@ -771,8 +770,7 @@ def __lookup_cmds__(self):

class IntelKNLCompiler(IntelCompiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init_finalize__(self, **kwargs):

language = kwargs.pop('language', configuration['language'])

Expand All @@ -784,8 +782,7 @@ def __init__(self, *args, **kwargs):

class OneapiCompiler(IntelCompiler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init_finalize__(self, **kwargs):

platform = kwargs.pop('platform', configuration['platform'])
language = kwargs.pop('language', configuration['language'])
Expand Down Expand Up @@ -841,38 +838,37 @@ def __new__(cls, *args, **kwargs):
platform = kwargs.pop('platform', configuration['platform'])
language = kwargs.pop('language', configuration['language'])

if any(i in environ for i in ['CC', 'CXX', 'CFLAGS', 'LDFLAGS']):
obj = super().__new__(cls)
obj.__init__(*args, **kwargs)
return obj
elif platform is M1:
return ClangCompiler(*args, **kwargs)
if platform is M1:
_base = ClangCompiler
elif platform is INTELGPUX:
return OneapiCompiler(*args, **kwargs)
_base = OneapiCompiler
elif platform is NVIDIAX:
if language == 'cuda':
return CudaCompiler(*args, **kwargs)
_base = CudaCompiler
else:
return NvidiaCompiler(*args, **kwargs)
_base = NvidiaCompiler
elif platform is AMDGPUX:
if language == 'hip':
return HipCompiler(*args, **kwargs)
_base = HipCompiler
else:
return AOMPCompiler(*args, **kwargs)
_base = AOMPCompiler
else:
return GNUCompiler(*args, **kwargs)

def __init__(self, *args, **kwargs):
super(CustomCompiler, self).__init__(*args, **kwargs)

default = '-O3 -g -march=native -fPIC -Wall -std=c99'
self.cflags = environ.get('CFLAGS', default).split(' ')
self.ldflags = environ.get('LDFLAGS', '-shared').split(' ')

language = kwargs.pop('language', configuration['language'])

if language == 'openmp':
self.ldflags += environ.get('OMP_LDFLAGS', '-fopenmp').split(' ')
_base = GNUCompiler

obj = super().__new__(cls)
# Keep base to initialize accordingly
obj._base = _base

return obj

def __init_finalize__(self, **kwargs):
self._base.__init_finalize__(self, **kwargs)
# Update cflags
extrac = environ.get('CFLAGS', '').split(' ')
self.cflags = filter_ordered(self.cflags + extrac)
# Update ldflags
extrald = environ.get('LDFLAGS', '').split(' ')
self.ldflags = filter_ordered(self.ldflags + extrald)

def __lookup_cmds__(self):
self.CC = environ.get('CC', 'gcc')
Expand Down

0 comments on commit 31a8728

Please sign in to comment.