Skip to content

Commit

Permalink
fix(benchmark/camelyon): actually rely on gpu setup at the Dependency…
Browse files Browse the repository at this point in the history
… level

Signed-off-by: Thibault Camalon <[email protected]>
  • Loading branch information
thbcmlowk committed Sep 3, 2024
1 parent a7b598c commit 03b2fa2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion benchmark/camelyon/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def fed_avg(params: dict, train_folder: Path, test_folder: Path):
mode=exp_params["mode"],
cp_name=exp_params["cp_name"],
cancel_cp=exp_params["cancel_cp"],
torch_gpu=exp_params["torch_gpu"],
use_gpu=exp_params["use_gpu"],
)

if exp_params["skip_pure_torch"]:
Expand Down
2 changes: 1 addition & 1 deletion benchmark/camelyon/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def parse_params() -> dict:
default=False,
help="Remote only: cancel the CP after registration",
)
parser.add_argument("--torch-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support")
parser.add_argument("--use-gpu", action="store_true", help="Use PyTorch with GPU/CUDA support")
parser.add_argument(
"--skip-pure-torch",
action="store_true",
Expand Down
7 changes: 4 additions & 3 deletions benchmark/camelyon/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def substrafl_fed_avg(
asset_keys_path: Path,
cp_name: Optional[str],
cancel_cp: bool = False,
torch_gpu: bool = False,
use_gpu: bool = False,
) -> benchmark_metrics.BenchmarkResults:
"""Execute Weldon algorithm for a fed avg strategy with substrafl API.
Expand All @@ -68,7 +68,7 @@ def substrafl_fed_avg(
Otherwise, all present keys in this fill will be reused per Substra in remote mode.
cp_name ben): (Optional[str]): Compute Plan name to display
cancel_cp (bool): if set to True, the CP will be canceled as soon as it's registered. Only work for remote mode.
torch_gpu (bool): Use GPU default index for pytorch
use_gpu (bool): Use GPU for Dependency object
Returns:
dict: Results of the experiment.
"""
Expand Down Expand Up @@ -97,7 +97,7 @@ def substrafl_fed_avg(
"torch==2.3.0",
"scikit-learn==1.5.1",
]
if not torch_gpu:
if not use_gpu:
pypi_dependencies += ["--extra-index-url https://download.pytorch.org/whl/cpu"]

# Dependencies
Expand All @@ -108,6 +108,7 @@ def substrafl_fed_avg(
# Keeping editable_mode=True to ensure nightly test benchmarks are ran against main substrafl git ref
editable_mode=True,
compile=True,
use_gpu=use_gpu,
)

# Metrics
Expand Down

0 comments on commit 03b2fa2

Please sign in to comment.