From b6f1a087cb0b2b71b0f0f4f2f8a3080519132220 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 3 Jun 2024 08:43:00 +0000 Subject: [PATCH] tests: Add decoupler marker --- conftest.py | 93 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/conftest.py b/conftest.py index c38829f6e8..c88248e2b4 100644 --- a/conftest.py +++ b/conftest.py @@ -122,6 +122,29 @@ def EVAL(exprs, *args): return processed[0] if isinstance(exprs, str) else processed +def get_testname(item): + if item.cls is not None: + return "%s::%s::%s" % (item.fspath, item.cls.__name__, item.name) + else: + return "%s::%s" % (item.fspath, item.name) + + +def set_run_reset(env_vars, call): + old_env_vars = {k: os.environ.get(k, None) for k in env_vars} + + os.environ.update(env_vars) + os.environ['DEVITO_PYTEST_FLAG'] = '1' + + try: + check_call(call) + return True + except: + return False + finally: + os.environ['DEVITO_PYTEST_FLAG'] = '0' + os.environ.update({k: v for k, v in old_env_vars.items() if v is not None}) + + def parallel(item, m): """ Run a test in parallel. Readapted from: @@ -141,14 +164,12 @@ def parallel(item, m): else: raise ValueError("Can't run test: unexpected mode `%s`" % m) + env_vars = {'DEVITO_MPI': scheme} + pyversion = sys.executable + testname = get_testname(item) # Only spew tracebacks on rank 0. # Run xfailing tests to ensure that errors are reported to calling process - if item.cls is not None: - testname = "%s::%s::%s" % (item.fspath, item.cls.__name__, item.name) - else: - testname = "%s::%s" % (item.fspath, item.name) - args = ["-n", "1", pyversion, "-m", "pytest", "--no-summary", "-s", "--runxfail", "-qq", testname] if nprocs > 1: @@ -161,16 +182,24 @@ def parallel(item, m): else: call = [mpi_exec] + args - # Tell the MPI ranks that they are running a parallel test - os.environ['DEVITO_MPI'] = scheme - try: - check_call(call) - res = True - except: - res = False - finally: - os.environ['DEVITO_MPI'] = '0' - return res + return set_run_reset(env_vars, call) + + +def decoupler(item, m): + """ + Run a test in decoupled mode. + """ + mpi_exec = 'mpiexec' + assert sniff_mpi_distro(mpi_exec) != 'unknown', "Decoupled tests require MPI" + + env_vars = {'DEVITO_DECOUPLER': '1'} + if isinstance(m, int): + env_vars['DEVITO_DECOUPLER_WORKERS'] = str(m) + + testname = get_testname(item) + call = ["pytest", "--no-summary", "-s", "--runxfail", testname] + + return set_run_reset(env_vars, call) def pytest_configure(config): @@ -179,6 +208,10 @@ def pytest_configure(config): "markers", "parallel(mode): mark test to run in parallel" ) + config.addinivalue_line( + "markers", + "decoupler(mode): mark test to run in decoupled mode", + ) def pytest_generate_tests(metafunc): @@ -187,26 +220,37 @@ def pytest_generate_tests(metafunc): if 'mode' in metafunc.fixturenames: markers = metafunc.definition.iter_markers() for marker in markers: - if marker.name == 'parallel': + if marker.name in ('parallel', 'decoupler'): mode = list(as_tuple(marker.kwargs.get('mode', 2))) metafunc.parametrize("mode", mode) @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_runtest_call(item): - partest = os.environ.get('DEVITO_MPI', 0) + inside_pytest_marker = os.environ.get('DEVITO_PYTEST_FLAG', 0) try: - partest = int(partest) + inside_pytest_marker = int(inside_pytest_marker) except ValueError: pass - if item.get_closest_marker("parallel") and not partest: + if inside_pytest_marker: + outcome = yield + + elif item.get_closest_marker("parallel"): # Spawn parallel processes to run test outcome = parallel(item, item.funcargs['mode']) if outcome: pytest.skip(f"{item} success in parallel") else: pytest.fail(f"{item} failed in parallel") + + elif item.get_closest_marker("decoupler"): + outcome = decoupler(item, item.funcargs.get('mode')) + if outcome: + pytest.skip(f"{item} success in decoupled mode") + else: + pytest.fail(f"{item} failed in decoupled mode") + else: outcome = yield @@ -215,12 +259,17 @@ def pytest_runtest_call(item): def pytest_runtest_makereport(item, call): outcome = yield result = outcome.get_result() - partest = os.environ.get('DEVITO_MPI', 0) + + inside_pytest_marker = os.environ.get('DEVITO_PYTEST_FLAG', 0) try: - partest = int(partest) + inside_pytest_marker = int(inside_pytest_marker) except ValueError: pass - if item.get_closest_marker("parallel") and not partest: + if inside_pytest_marker: + return + + if item.get_closest_marker("parallel") or \ + item.get_closest_marker("decoupler"): if call.when == 'call' and result.outcome == 'skipped': result.outcome = 'passed'