diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 0d35fa5aeb3..cd4f9d45ad9 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -25,7 +25,8 @@ 'INTEL64', 'SNB', 'IVB', 'HSW', 'BDW', 'KNL', 'KNL7210', 'SKX', 'KLX', 'CLX', 'CLK', 'SPR', # ARM CPUs - 'AMD', 'ARM', 'AppleArm', 'M1', 'M2', 'M3', 'GRAVITON', + 'AMD', 'ARM', 'AppleArm', 'M1', 'M2', 'M3', + 'Graviton', 'GRAVITON2', 'GRAVITON3', 'GRAVITON4', # Other legacy CPUs 'POWER8', 'POWER9', # Generic GPUs @@ -764,6 +765,20 @@ def march(self): return min(mx, 'm2') +class Graviton(Arm): + + @property + def version(self): + return int(self.name.split('graviton')[-1]) + + @cached_property + def march(self): + if self.version >= 4: + return 'neoverse-n2' + else: + return 'neoverse-n1' + + class Amd(Cpu64): known_isas = ('cpp', 'sse', 'avx', 'avx2') @@ -912,7 +927,9 @@ def march(cls): SPR = IntelGoldenCove('spr') # Sapphire Rapids ARM = Arm('arm') -GRAVITON = Arm('graviton') +GRAVITON2 = Graviton('graviton2') +GRAVITON3 = Graviton('graviton3') +GRAVITON4 = Graviton('graviton4') M1 = AppleArm('m1') M2 = AppleArm('m2') M3 = AppleArm('m3') diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 9cd94ed597d..aaec25963ba 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -13,7 +13,7 @@ from codepy.toolchain import (GCCToolchain, call_capture_output as _call_capture_output) -from devito.arch import (AMDGPUX, Cpu64, AppleArm, NVIDIAX, POWER8, POWER9, GRAVITON, +from devito.arch import (AMDGPUX, Cpu64, AppleArm, NVIDIAX, POWER8, POWER9, Graviton, IntelDevice, get_nvidia_cc, check_cuda_runtime, get_m1_llvm_path) from devito.exceptions import CompilationError @@ -434,6 +434,10 @@ def __init_finalize__(self, **kwargs): if platform in [POWER8, POWER9]: # -march isn't supported on power architectures, is -mtune needed? self.cflags = ['-mcpu=native'] + self.cflags + elif platform is Graviton: + # Graviton flag + mx = platform.march + self.cflags = ['-mcpu=%s' % mx] + self.cflags else: self.cflags = ['-march=native'] + self.cflags @@ -462,8 +466,9 @@ def __init_finalize__(self, **kwargs): platform = kwargs.pop('platform', configuration['platform']) # Graviton flag - if platform is GRAVITON: - self.cflags += ['-mcpu=neoverse-n1'] + if platform is Graviton: + mx = platform.march + self.cflags += ['-mcpu=%s' % mx] class ClangCompiler(Compiler): @@ -962,37 +967,45 @@ def __new_with__(self, **kwargs): return super().__new_with__(base=self._base, **kwargs) -compiler_registry = { - 'custom': CustomCompiler, - 'gnu': GNUCompiler, - 'gcc': GNUCompiler, - 'arm': ArmCompiler, - 'clang': ClangCompiler, - 'cray': CrayCompiler, - 'aomp': AOMPCompiler, - 'amdclang': AOMPCompiler, - 'hip': HipCompiler, - 'pgcc': PGICompiler, - 'pgi': PGICompiler, - 'nvc': NvidiaCompiler, - 'nvc++': NvidiaCompiler, - 'nvidia': NvidiaCompiler, - 'cuda': CudaCompiler, - 'osx': ClangCompiler, - 'intel': OneapiCompiler, - 'icx': OneapiCompiler, - 'icpx': OneapiCompiler, - 'sycl': SyclCompiler, - 'icc': IntelCompiler, - 'icpc': IntelCompiler, - 'intel-knl': IntelKNLCompiler, - 'knl': IntelKNLCompiler, - 'dpcpp': DPCPPCompiler, -} -""" -Registry dict for deriving Compiler classes according to the environment variable -DEVITO_ARCH. Developers should add new compiler classes here. -""" -compiler_registry.update({'gcc-%s' % i: partial(GNUCompiler, suffix=i) - for i in ['4.9', '5', '6', '7', '8', '9', '10', - '11', '12', '13']}) +class CompilerRegistry(dict): + """ + Registry dict for deriving Compiler classes according to the environment variable + DEVITO_ARCH. Developers should add new compiler classes here. + """ + + _compiler_registry = { + 'custom': CustomCompiler, + 'gnu': GNUCompiler, + 'gcc': GNUCompiler, + 'arm': ArmCompiler, + 'clang': ClangCompiler, + 'cray': CrayCompiler, + 'aomp': AOMPCompiler, + 'amdclang': AOMPCompiler, + 'hip': HipCompiler, + 'pgcc': PGICompiler, + 'pgi': PGICompiler, + 'nvc': NvidiaCompiler, + 'nvc++': NvidiaCompiler, + 'nvidia': NvidiaCompiler, + 'cuda': CudaCompiler, + 'osx': ClangCompiler, + 'intel': OneapiCompiler, + 'icx': OneapiCompiler, + 'icpx': OneapiCompiler, + 'sycl': SyclCompiler, + 'icc': IntelCompiler, + 'icpc': IntelCompiler, + 'intel-knl': IntelKNLCompiler, + 'knl': IntelKNLCompiler, + 'dpcpp': DPCPPCompiler, + } + + def __getitem__(self, key): + if key.startswith('gcc-'): + i = key.split('-')[1] + return partial(GNUCompiler, suffix=i) + return self._compiler_registry[key] + + +compiler_registry = CompilerRegistry()