Skip to content

Commit

Permalink
chore: add cancel cp option to camelyon benchmark (#218)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy authored Jun 20, 2024
1 parent 380127b commit 40f1eaf
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions benchmark/camelyon/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def fed_avg(params: dict, train_folder: Path, test_folder: Path):
model=model,
mode=exp_params["mode"],
cp_name=exp_params["cp_name"],
cancel_cp=exp_params["cancel_cp"],
torch_gpu=exp_params["torch_gpu"],
)

Expand Down
2 changes: 2 additions & 0 deletions benchmark/camelyon/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def parse_params() -> dict:
default=Path(__file__).resolve().parents[1] / "data",
help="Path to the data",
)
parser.add_argument("--cancel-cp", type=bool, default=False, help="Remote only: cancel the CP after registration")
parser.add_argument("--torch-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support")
parser.add_argument(
"--skip-pure-torch",
Expand All @@ -100,6 +101,7 @@ def parse_params() -> dict:
params["nb_train_data_samples"] = args.nb_train_data_samples
params["nb_test_data_samples"] = args.nb_test_data_samples
params["data_path"] = args.data_path
params["cancel_cp"] = args.cancel_cp
params["torch_gpu"] = args.torch_gpu
params["skip_pure_torch"] = args.skip_pure_torch
params["cp_name"] = args.cp_name
Expand Down
6 changes: 6 additions & 0 deletions benchmark/camelyon/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import numpy as np
import substra
import torch
import weldon_fedavg
from common import benchmark_metrics
Expand Down Expand Up @@ -42,6 +43,7 @@ def substrafl_fed_avg(
credentials_path: Path,
asset_keys_path: Path,
cp_name: Optional[str],
cancel_cp: bool = False,
torch_gpu: bool = False,
) -> benchmark_metrics.BenchmarkResults:
"""Execute Weldon algorithm for a fed avg strategy with substrafl API.
Expand All @@ -65,6 +67,7 @@ def substrafl_fed_avg(
asset_keys_path (Path): Remote only: path to asset key file. If un existent, it will be created.
Otherwise, all present keys in this fill will be reused per Substra in remote mode.
cp_name ben): (Optional[str]): Compute Plan name to display
cancel_cp (bool): if set to True, the CP will be canceled as soon as it's registered. Only work for remote mode.
torch_gpu (bool): Use GPU default index for pytorch
Returns:
dict: Results of the experiment.
Expand Down Expand Up @@ -140,6 +143,9 @@ def accuracy(data_from_opener, predictions):
name=cp_name,
)

if cancel_cp and clients[0].backend_mode == substra.BackendType.REMOTE:
clients[0].cancel_compute_plan(key=compute_plan.key)

clients[0].wait_compute_plan(key=compute_plan.key, raise_on_failure=True)

performances = clients[1].get_performances(key=compute_plan.key)
Expand Down

0 comments on commit 40f1eaf

Please sign in to comment.