diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 0ce969a38ea..9d9dd60e0d6 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -1,5 +1,6 @@ # stdlib import json +import os from pathlib import Path from unittest import mock @@ -46,6 +47,14 @@ def remove_file(filepath: Path): filepath.unlink(missing_ok=True) +# Pytest hook to set the number of workers for xdist +def pytest_xdist_auto_num_workers(config): + num = config.option.numprocesses + if num == "auto" or num == "logical": + return os.cpu_count() + return None + + @pytest.fixture(autouse=True) def protocol_file(): random_name = sy.UID().to_string()