Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add cancel cp option to camelyon benchmark #218

Merged
merged 4 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmark/camelyon/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def fed_avg(params: dict, train_folder: Path, test_folder: Path):
model=model,
mode=exp_params["mode"],
cp_name=exp_params["cp_name"],
cancel_cp=exp_params["cancel_cp"],
torch_gpu=exp_params["torch_gpu"],
)

Expand Down
2 changes: 2 additions & 0 deletions benchmark/camelyon/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def parse_params() -> dict:
default=Path(__file__).resolve().parents[1] / "data",
help="Path to the data",
)
parser.add_argument("--cancel-cp", type=bool, default=False, help="Remote only: cancel the CP after registration")
parser.add_argument("--torch-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support")
parser.add_argument(
"--cp-name",
Expand All @@ -95,6 +96,7 @@ def parse_params() -> dict:
params["nb_train_data_samples"] = args.nb_train_data_samples
params["nb_test_data_samples"] = args.nb_test_data_samples
params["data_path"] = args.data_path
params["cancel_cp"] = args.cancel_cp
params["torch_gpu"] = args.torch_gpu
params["cp_name"] = args.cp_name

Expand Down
6 changes: 6 additions & 0 deletions benchmark/camelyon/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import numpy as np
import substra
import torch
import weldon_fedavg
from common import benchmark_metrics
Expand Down Expand Up @@ -42,6 +43,7 @@ def substrafl_fed_avg(
credentials_path: Path,
asset_keys_path: Path,
cp_name: Optional[str],
cancel_cp: bool = False,
torch_gpu: bool = False,
) -> benchmark_metrics.BenchmarkResults:
"""Execute Weldon algorithm for a fed avg strategy with substrafl API.
Expand All @@ -65,6 +67,7 @@ def substrafl_fed_avg(
asset_keys_path (Path): Remote only: path to asset key file. If un existent, it will be created.
Otherwise, all present keys in this fill will be reused per Substra in remote mode.
cp_name ben): (Optional[str]): Compute Plan name to display
cancel_cp (bool): if set to True, the CP will be canceled as soon as it's registered. Only work for remote mode.
torch_gpu (bool): Use GPU default index for pytorch
Returns:
dict: Results of the experiment.
Expand Down Expand Up @@ -140,6 +143,9 @@ def accuracy(data_from_opener, predictions):
name=cp_name,
)

if cancel_cp and clients[0].backend_mode == substra.BackendType.REMOTE:
clients[0].cancel_compute_plan(key=compute_plan.key)

clients[0].wait_compute_plan(key=compute_plan.key, raise_on_failure=True)

performances = clients[1].get_performances(key=compute_plan.key)
Expand Down
Loading