From 03b2fa2cb198df42256f9025a922350c96e9f776 Mon Sep 17 00:00:00 2001 From: Thibault Camalon <135698225+thbcmlowk@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:00:44 +0200 Subject: [PATCH] fix(benchmark/camelyon): actually rely on gpu setup at the Dependency level Signed-off-by: Thibault Camalon <135698225+thbcmlowk@users.noreply.github.com> --- benchmark/camelyon/benchmarks.py | 2 +- benchmark/camelyon/common/utils.py | 2 +- benchmark/camelyon/workflows.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/benchmark/camelyon/benchmarks.py b/benchmark/camelyon/benchmarks.py index 76ea8176..228d29c0 100644 --- a/benchmark/camelyon/benchmarks.py +++ b/benchmark/camelyon/benchmarks.py @@ -64,7 +64,7 @@ def fed_avg(params: dict, train_folder: Path, test_folder: Path): mode=exp_params["mode"], cp_name=exp_params["cp_name"], cancel_cp=exp_params["cancel_cp"], - torch_gpu=exp_params["torch_gpu"], + use_gpu=exp_params["use_gpu"], ) if exp_params["skip_pure_torch"]: diff --git a/benchmark/camelyon/common/utils.py b/benchmark/camelyon/common/utils.py index 1ed69494..e2e9e9a6 100644 --- a/benchmark/camelyon/common/utils.py +++ b/benchmark/camelyon/common/utils.py @@ -81,7 +81,7 @@ def parse_params() -> dict: default=False, help="Remote only: cancel the CP after registration", ) - parser.add_argument("--torch-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support") + parser.add_argument("--use-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support") parser.add_argument( "--skip-pure-torch", action="store_true", diff --git a/benchmark/camelyon/workflows.py b/benchmark/camelyon/workflows.py index 23f805be..43a6a2ef 100644 --- a/benchmark/camelyon/workflows.py +++ b/benchmark/camelyon/workflows.py @@ -44,7 +44,7 @@ def substrafl_fed_avg( asset_keys_path: Path, cp_name: Optional[str], cancel_cp: bool = False, - torch_gpu: bool = False, + use_gpu: bool = False, ) -> benchmark_metrics.BenchmarkResults: """Execute Weldon algorithm for a fed avg strategy with substrafl API. @@ -68,7 +68,7 @@ def substrafl_fed_avg( Otherwise, all present keys in this fill will be reused per Substra in remote mode. cp_name ben): (Optional[str]): Compute Plan name to display cancel_cp (bool): if set to True, the CP will be canceled as soon as it's registered. Only work for remote mode. - torch_gpu (bool): Use GPU default index for pytorch + use_gpu (bool): Use GPU for Dependency object Returns: dict: Results of the experiment. """ @@ -97,7 +97,7 @@ def substrafl_fed_avg( "torch==2.3.0", "scikit-learn==1.5.1", ] - if not torch_gpu: + if not use_gpu: pypi_dependencies += ["--extra-index-url https://download.pytorch.org/whl/cpu"] # Dependencies @@ -108,6 +108,7 @@ def substrafl_fed_avg( # Keeping editable_mode=True to ensure nightly test benchmarks are ran against main substrafl git ref editable_mode=True, compile=True, + use_gpu=use_gpu, ) # Metrics