Skip to content

Commit

Permalink
chore: rename use_gpu to disable_gpu in TorchAlgo
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy committed Aug 12, 2024
1 parent 860e0c6 commit 256626e
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion substrafl/algorithms/pytorch/torch_base_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
seed: Optional[int] = None,
disable_gpu: bool = True,
disable_gpu: bool = False,
*args,
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions substrafl/algorithms/pytorch/torch_fed_avg_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
with_batch_norm_parameters: bool = False,
seed: Optional[int] = None,
disable_gpu: bool = True,
disable_gpu: bool = False,
*args,
**kwargs,
):
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
with_batch_norm_parameters (bool): Whether to include the batch norm layer parameters in the fed avg
strategy. Defaults to False.
seed (typing.Optional[int]): Seed set at the algo initialization on each organization. Defaults to None.
disable_gpu (bool): Whether to use the GPUs if they are available. Defaults to True.
disable_gpu (bool): Force to disable GPUs usage. Defaults to False.
"""
super().__init__(
*args,
Expand Down
4 changes: 2 additions & 2 deletions substrafl/algorithms/pytorch/torch_fed_pca_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
out_features: int,
batch_size: Optional[int] = None,
seed: int = 1,
disable_gpu: bool = True,
disable_gpu: bool = False,
*args,
**kwargs,
):
Expand All @@ -101,7 +101,7 @@ def __init__(
out_features (int): dimension to keep after PCA
batch_size (Optional[int]): mini-batch size
seed (int): random generator seed. The seed is mandatory. Default to 1.
disable_gpu (bool): whether to use GPU or not. Default to True.
disable_gpu (bool): force to disable GPUs usage. Defaults to False.
"""
self.in_features = in_features
self.out_features = out_features
Expand Down
4 changes: 2 additions & 2 deletions substrafl/algorithms/pytorch/torch_newton_raphson_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
l2_coeff: float = 0,
with_batch_norm_parameters: bool = False,
seed: Optional[int] = None,
disable_gpu: bool = True,
disable_gpu: bool = False,
*args,
**kwargs,
):
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
with_batch_norm_parameters (bool): Whether to include the batch norm layer parameters in the Newton-Raphson
strategy. Defaults to False.
seed (typing.Optional[int]): Seed set at the algo initialization on each organization. Defaults to None.
disable_gpu (bool): Whether to use the GPUs if they are available. Defaults to True.
disable_gpu (bool): Force to disable GPUs usage. Defaults to False.
"""
assert "optimizer" not in kwargs, "Newton Raphson strategy does not uses optimizers"

Expand Down
4 changes: 2 additions & 2 deletions substrafl/algorithms/pytorch/torch_scaffold_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(
with_batch_norm_parameters: bool = False,
c_update_rule: CUpdateRule = CUpdateRule.FAST,
seed: Optional[int] = None,
disable_gpu: bool = True,
disable_gpu: bool = False,
*args,
**kwargs,
):
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
client control variate.
Defaults to CUpdateRule.FAST.
seed (typing.Optional[int]): Seed set at the algo initialization on each organization. Defaults to None.
disable_gpu (bool): Whether to use the GPUs if they are available. Defaults to True.
disable_gpu (bool): Force to disable GPUs usage. Defaults to False.
Raises:
:ref:`~substrafl.exceptions.NumUpdatesValueError`: If `num_updates` is inferior or equal to zero.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
dataset: torch.utils.data.Dataset,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
seed: Optional[int] = None,
disable_gpu: bool = True,
disable_gpu: bool = False,
*args,
**kwargs,
):
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
scheduler (torch.optim.lr_scheduler._LRScheduler, Optional): A torch scheduler that will be called at every
batch. If None, no scheduler will be used. Defaults to None.
seed (typing.Optional[int]): Seed set at the algo initialization on each organization. Defaults to None.
disable_gpu (bool): Whether to use the GPUs if they are available. Defaults to True.
disable_gpu (bool): Force to disable GPUs usage. Defaults to False.
"""
super().__init__(
*args,
Expand Down
4 changes: 2 additions & 2 deletions tests/algorithms/pytorch/test_base_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def __init__(self):
disable_gpu=disable_gpu,
)
if disable_gpu:
assert self._device == torch.device("cuda")
else:
assert self._device == torch.device("cpu")
else:
assert self._device == torch.device("cuda")

@property
def strategies(self):
Expand Down

0 comments on commit 256626e

Please sign in to comment.