Skip to content

Commit

Permalink
tests: Add decoupler marker
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Jun 3, 2024
1 parent e39e014 commit b6f1a08
Showing 1 changed file with 71 additions and 22 deletions.
93 changes: 71 additions & 22 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,29 @@ def EVAL(exprs, *args):
return processed[0] if isinstance(exprs, str) else processed


def get_testname(item):
if item.cls is not None:
return "%s::%s::%s" % (item.fspath, item.cls.__name__, item.name)
else:
return "%s::%s" % (item.fspath, item.name)


def set_run_reset(env_vars, call):
old_env_vars = {k: os.environ.get(k, None) for k in env_vars}

os.environ.update(env_vars)
os.environ['DEVITO_PYTEST_FLAG'] = '1'

try:
check_call(call)
return True
except:
return False
finally:
os.environ['DEVITO_PYTEST_FLAG'] = '0'
os.environ.update({k: v for k, v in old_env_vars.items() if v is not None})


def parallel(item, m):
"""
Run a test in parallel. Readapted from:
Expand All @@ -141,14 +164,12 @@ def parallel(item, m):
else:
raise ValueError("Can't run test: unexpected mode `%s`" % m)

env_vars = {'DEVITO_MPI': scheme}

pyversion = sys.executable
testname = get_testname(item)
# Only spew tracebacks on rank 0.
# Run xfailing tests to ensure that errors are reported to calling process
if item.cls is not None:
testname = "%s::%s::%s" % (item.fspath, item.cls.__name__, item.name)
else:
testname = "%s::%s" % (item.fspath, item.name)

args = ["-n", "1", pyversion, "-m", "pytest", "--no-summary", "-s",
"--runxfail", "-qq", testname]
if nprocs > 1:
Expand All @@ -161,16 +182,24 @@ def parallel(item, m):
else:
call = [mpi_exec] + args

# Tell the MPI ranks that they are running a parallel test
os.environ['DEVITO_MPI'] = scheme
try:
check_call(call)
res = True
except:
res = False
finally:
os.environ['DEVITO_MPI'] = '0'
return res
return set_run_reset(env_vars, call)


def decoupler(item, m):
"""
Run a test in decoupled mode.
"""
mpi_exec = 'mpiexec'
assert sniff_mpi_distro(mpi_exec) != 'unknown', "Decoupled tests require MPI"

env_vars = {'DEVITO_DECOUPLER': '1'}
if isinstance(m, int):
env_vars['DEVITO_DECOUPLER_WORKERS'] = str(m)

testname = get_testname(item)
call = ["pytest", "--no-summary", "-s", "--runxfail", testname]

return set_run_reset(env_vars, call)


def pytest_configure(config):
Expand All @@ -179,6 +208,10 @@ def pytest_configure(config):
"markers",
"parallel(mode): mark test to run in parallel"
)
config.addinivalue_line(
"markers",
"decoupler(mode): mark test to run in decoupled mode",
)


def pytest_generate_tests(metafunc):
Expand All @@ -187,26 +220,37 @@ def pytest_generate_tests(metafunc):
if 'mode' in metafunc.fixturenames:
markers = metafunc.definition.iter_markers()
for marker in markers:
if marker.name == 'parallel':
if marker.name in ('parallel', 'decoupler'):
mode = list(as_tuple(marker.kwargs.get('mode', 2)))
metafunc.parametrize("mode", mode)


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_call(item):
partest = os.environ.get('DEVITO_MPI', 0)
inside_pytest_marker = os.environ.get('DEVITO_PYTEST_FLAG', 0)
try:
partest = int(partest)
inside_pytest_marker = int(inside_pytest_marker)
except ValueError:
pass

if item.get_closest_marker("parallel") and not partest:
if inside_pytest_marker:
outcome = yield

elif item.get_closest_marker("parallel"):
# Spawn parallel processes to run test
outcome = parallel(item, item.funcargs['mode'])
if outcome:
pytest.skip(f"{item} success in parallel")
else:
pytest.fail(f"{item} failed in parallel")

elif item.get_closest_marker("decoupler"):
outcome = decoupler(item, item.funcargs.get('mode'))
if outcome:
pytest.skip(f"{item} success in decoupled mode")
else:
pytest.fail(f"{item} failed in decoupled mode")

else:
outcome = yield

Expand All @@ -215,12 +259,17 @@ def pytest_runtest_call(item):
def pytest_runtest_makereport(item, call):
outcome = yield
result = outcome.get_result()
partest = os.environ.get('DEVITO_MPI', 0)

inside_pytest_marker = os.environ.get('DEVITO_PYTEST_FLAG', 0)
try:
partest = int(partest)
inside_pytest_marker = int(inside_pytest_marker)
except ValueError:
pass
if item.get_closest_marker("parallel") and not partest:
if inside_pytest_marker:
return

if item.get_closest_marker("parallel") or \
item.get_closest_marker("decoupler"):
if call.when == 'call' and result.outcome == 'skipped':
result.outcome = 'passed'

Expand Down

0 comments on commit b6f1a08

Please sign in to comment.