Skip to content

Commit

Permalink
arch: make sure Device get thread config from host
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 11, 2024
1 parent 7d9ddf6 commit 84765f4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
16 changes: 9 additions & 7 deletions devito/arch/archinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,10 @@ def numa_domains(self):
def threads_per_core(self):
return self.cores_logical // self.cores_physical

@property
def cores_physical_per_numa_domain(self):
return self.cores_physical // self.numa_domains

@property
def memtotal(self):
"""Physical memory size in bytes, or None if unknown."""
Expand Down Expand Up @@ -734,10 +738,6 @@ def numa_domains(self):
warning("NUMA domain count autodetection failed")
return 1

@property
def cores_physical_per_numa_domain(self):
return self.cores_physical // self.numa_domains

@cached_property
def memtotal(self):
return psutil.virtual_memory().total
Expand Down Expand Up @@ -804,13 +804,15 @@ def _detect_isa(self):

class Device(Platform):

def __init__(self, name, cores_logical=1, cores_physical=1, isa='cpp',
def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
max_threads_per_block=1024, max_threads_dimx=1024,
max_threads_dimy=1024, max_threads_dimz=64):
super().__init__(name)

self.cores_logical = cores_logical
self.cores_physical = cores_physical
cpu_info = get_cpu_info()

self.cores_logical = cores_logical or cpu_info['logical']
self.cores_physical = cores_physical or cpu_info['physical']
self.isa = isa

self.max_threads_per_block = max_threads_per_block
Expand Down
11 changes: 10 additions & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
SubDimension, SubDomain, SubDomainSet, TimeFunction,
Operator, configuration, switchconfig, TensorTimeFunction,
Buffer)
from devito.arch import get_gpu_info
from devito.arch import get_gpu_info, get_cpu_info, Device, Cpu64
from devito.exceptions import InvalidArgument
from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols,
retrieve_iteration_tree)
Expand Down Expand Up @@ -47,6 +47,15 @@ def custom_compiler(self):
op.apply(time_M=10)
assert np.all(u.data[1] == 11)

def test_host_threads():
plat = configuration['platform']

assert isinstance(plat, Device)

nth = plat.cores_physical
assert nth == get_cpu_info()['physical']
assert nth == Cpu64().cores_physical


class TestCodeGeneration:

Expand Down

0 comments on commit 84765f4

Please sign in to comment.