Skip to content

Commit

Permalink
Also test CustomWorkerConfig.from_str and from_dict
Browse files Browse the repository at this point in the history
in addition to from_path
  • Loading branch information
kiendang committed Nov 27, 2023
1 parent b41b083 commit 4b64f5f
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions packages/syft/tests/syft/custom_worker/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,31 @@
import yaml

# syft absolute
from syft.custom_worker.config import CustomBuildConfig
from syft.custom_worker.config import CustomWorkerConfig

# must follow the default values set in CustomBuildConfig class definition

# in Pydantic v2 this would just be model.model_dump(mode='json')
def to_json_like_dict(model: BaseModel) -> Dict[str, Any]:
return json.loads(model.json())


DEFAULT_BUILD_CONFIG = {
"gpu": False,
"python_packages": [],
"system_packages": [],
"custom_cmds": [],
}
# must follow the default values set in CustomBuildConfig class definition
assert DEFAULT_BUILD_CONFIG == to_json_like_dict(CustomBuildConfig())


# must be set to the default value of CustomWorkerConfig.version
DEFAULT_WORKER_CONFIG_VERSION = "1"
# must be set to the default value of CustomWorkerConfig.version
assert (
DEFAULT_WORKER_CONFIG_VERSION
== CustomWorkerConfig(build=CustomBuildConfig()).version
)


CUSTOM_BUILD_CONFIG = {
Expand Down Expand Up @@ -96,11 +108,6 @@ def get_full_build_config(build_config: Dict[str, Any]) -> Dict[str, Any]:
return {**DEFAULT_BUILD_CONFIG, **build_config}


# in Pydantic v2 this would just be model.model_dump(mode='json')
def to_json_like_dict(model: BaseModel) -> dict:
return json.loads(model.json())


@pytest.fixture
def worker_config(
build_config: Dict[str, Any], worker_config_version: Optional[str]
Expand All @@ -119,14 +126,30 @@ def worker_config_yaml(tmp_path: Path, worker_config: Dict[str, Any]) -> Path:
file_path.unlink()


METHODS = ["from_dict", "from_str", "from_path"]


@pytest.mark.parametrize("build_config", CUSTOM_BUILD_CONFIG_TEST_CASES)
@pytest.mark.parametrize("worker_config_version", ["2", None])
def test_load_custom_worker_config_file(
@pytest.mark.parametrize("method", METHODS)
def test_load_custom_worker_config(
build_config: Dict[str, Any],
worker_config_version: Optional[str],
worker_config_yaml: Path,
method: str,
) -> None:
parsed_worker_config_obj = CustomWorkerConfig.from_path(worker_config_yaml)
if method == "from_path":
parsed_worker_config_obj = CustomWorkerConfig.from_path(worker_config_yaml)
elif method == "from_str":
parsed_worker_config_obj = CustomWorkerConfig.from_str(
worker_config_yaml.read_text()
)
elif method == "from_dict":
with open(worker_config_yaml) as f:
config = yaml.safe_load(f)
parsed_worker_config_obj = CustomWorkerConfig.from_dict(config)
else:
raise ValueError(f"method must be one of {METHODS}")

worker_config_version = (
DEFAULT_WORKER_CONFIG_VERSION
Expand Down

0 comments on commit 4b64f5f

Please sign in to comment.