From e77e190043240334ce770753db73b9a39ee29b6b Mon Sep 17 00:00:00 2001 From: SdgJlbl Date: Mon, 19 Feb 2024 10:05:54 +0100 Subject: [PATCH] chore: rename compute task status (#188) Signed-off-by: SdgJlbl --- substrafl/model_loading.py | 4 ++-- tests/test_model_loading.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/substrafl/model_loading.py b/substrafl/model_loading.py index f38ad42b..ce5d88ef 100644 --- a/substrafl/model_loading.py +++ b/substrafl/model_loading.py @@ -10,7 +10,7 @@ import substra import substratools -from substra.sdk.models import Status +from substra.sdk.models import ComputeTaskStatus import substrafl from substrafl import exceptions @@ -356,7 +356,7 @@ def _download_task_output_files( rank_idx=rank_idx, tag=task_type, ) - if task.status is not Status.done: + if task.status is not ComputeTaskStatus.done: raise exceptions.UnfinishedTaskError( f"Can't download algo files form task {task.key} as it is " f"in status {task.status}" ) diff --git a/tests/test_model_loading.py b/tests/test_model_loading.py index 838e54a3..de93cbb9 100644 --- a/tests/test_model_loading.py +++ b/tests/test_model_loading.py @@ -93,7 +93,7 @@ def fake_local_train_task(trunk_model): permissions=substra.models.Permissions(process={"public": True, "authorized_ids": []}), value=trunk_model ), } - local_train_task.status = substra.models.Status.done + local_train_task.status = substra.models.ComputeTaskStatus.done return local_train_task @@ -116,7 +116,7 @@ def fake_aggregate_task(trunk_model): permissions=substra.models.Permissions(process={"public": True, "authorized_ids": []}), value=model ), } - aggregate_task.status = substra.models.Status.done + aggregate_task.status = substra.models.ComputeTaskStatus.done return aggregate_task @@ -447,7 +447,9 @@ def test_load_model_dependency(algo_files_with_local_dependency, is_dependency_u assert res == "hello world" -@pytest.mark.parametrize("status", [e.value for e in substra.models.Status if e.value != substra.models.Status.done]) +@pytest.mark.parametrize( + "status", [e.value for e in substra.models.ComputeTaskStatus if e.value != substra.models.ComputeTaskStatus.done] +) def test_unfinished_task_error( fake_client, fake_compute_plan,