From 7b8e88f24c987cb2eb60035c47df0927908573d4 Mon Sep 17 00:00:00 2001
From: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com>
Date: Thu, 20 Jun 2024 10:39:48 +0200
Subject: [PATCH] chore: allow skipping pure torch computation on benchmark camelyon (#219)

Signed-off-by: ThibaultFy
---
 benchmark/camelyon/benchmarks.py   | 4 ++++
 benchmark/camelyon/common/utils.py | 6 ++++++
 2 files changed, 10 insertions(+)

diff --git a/benchmark/camelyon/benchmarks.py b/benchmark/camelyon/benchmarks.py
index 888419f0..6a55a115 100644
--- a/benchmark/camelyon/benchmarks.py
+++ b/benchmark/camelyon/benchmarks.py
@@ -66,6 +66,10 @@ def fed_avg(params: dict, train_folder: Path, test_folder: Path):
         torch_gpu=exp_params["torch_gpu"],
     )
 
+    if exp_params["skip_pure_torch"]:
+        print("Skipping pure torch FedAvg computation and comparison to SubstraFL FedAvg.")
+        return
+
     torch_metrics = torch_fed_avg(
         train_folder=train_folder,
         test_folder=test_folder,
diff --git a/benchmark/camelyon/common/utils.py b/benchmark/camelyon/common/utils.py
index 5e0a6607..8540080e 100644
--- a/benchmark/camelyon/common/utils.py
+++ b/benchmark/camelyon/common/utils.py
@@ -76,6 +76,11 @@ def parse_params() -> dict:
         help="Path to the data",
     )
     parser.add_argument("--torch-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support")
+    parser.add_argument(
+        "--skip-pure-torch",
+        action="store_true",
+        help="Skip the pure torch computation part to test only the SubstraFL implementation",
+    )
     parser.add_argument(
         "--cp-name",
         type=str,
@@ -96,6 +101,7 @@ def parse_params() -> dict:
     params["nb_test_data_samples"] = args.nb_test_data_samples
     params["data_path"] = args.data_path
     params["torch_gpu"] = args.torch_gpu
+    params["skip_pure_torch"] = args.skip_pure_torch
     params["cp_name"] = args.cp_name
 
     return params
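
A minimal usage sketch of the new flag. It assumes the benchmark is launched directly through benchmark/camelyon/benchmarks.py; any other arguments handled by parse_params (e.g. --torch-gpu, the data path, sample counts) are omitted here and should be set to match your setup:

    # Run only the SubstraFL FedAvg benchmark: with --skip-pure-torch set,
    # fed_avg prints a notice and returns before the pure torch computation
    # and the comparison step.
    python benchmark/camelyon/benchmarks.py --skip-pure-torch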