Skip to content

Commit

Permalink
chore: allow skip pure torch computation on benchmark camelyon (#219)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy authored Jun 20, 2024
1 parent 2efbf3b commit 7b8e88f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions benchmark/camelyon/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def fed_avg(params: dict, train_folder: Path, test_folder: Path):
torch_gpu=exp_params["torch_gpu"],
)

if exp_params["skip_pure_torch"]:
print("Skipping pure torch FedAvg computation and comparison to SusbtraFL FedAvg.")
return

torch_metrics = torch_fed_avg(
train_folder=train_folder,
test_folder=test_folder,
Expand Down
6 changes: 6 additions & 0 deletions benchmark/camelyon/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def parse_params() -> dict:
help="Path to the data",
)
parser.add_argument("--torch-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support")
parser.add_argument(
"--skip-pure-torch",
action="store_true",
help="Skip the pure torch computation part to only test substrafl implementation",
)
parser.add_argument(
"--cp-name",
type=str,
Expand All @@ -96,6 +101,7 @@ def parse_params() -> dict:
params["nb_test_data_samples"] = args.nb_test_data_samples
params["data_path"] = args.data_path
params["torch_gpu"] = args.torch_gpu
params["skip_pure_torch"] = args.skip_pure_torch
params["cp_name"] = args.cp_name

return params
Expand Down

0 comments on commit 7b8e88f

Please sign in to comment.