diff --git a/src/py/flwr/server/strategy/default.py b/src/py/flwr/server/strategy/default.py index 82d34921fa05..2ecc2d46a6e9 100644 --- a/src/py/flwr/server/strategy/default.py +++ b/src/py/flwr/server/strategy/default.py @@ -43,6 +43,7 @@ def __init__( on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, accept_failures: bool = True, + initial_parameters: Optional[Weights] = None, ) -> None: super().__init__( fraction_fit=fraction_fit, @@ -54,6 +55,7 @@ def __init__( on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, accept_failures=accept_failures, + initial_parameters=initial_parameters, ) warning = """ DEPRECATION WARNING: DefaultStrategy is deprecated, migrate to FedAvg. diff --git a/src/py/flwr/server/strategy/fast_and_slow.py b/src/py/flwr/server/strategy/fast_and_slow.py index 3bfe863a1658..37072038f0e5 100644 --- a/src/py/flwr/server/strategy/fast_and_slow.py +++ b/src/py/flwr/server/strategy/fast_and_slow.py @@ -70,6 +70,7 @@ def __init__( r_slow: int = 1, t_fast: int = 10, t_slow: int = 10, + initial_parameters: Optional[Weights] = None, ) -> None: super().__init__( fraction_fit=fraction_fit, @@ -80,6 +81,7 @@ def __init__( eval_fn=eval_fn, on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, + initial_parameters=initial_parameters, ) self.min_completion_rate_fit = min_completion_rate_fit self.min_completion_rate_evaluate = min_completion_rate_evaluate diff --git a/src/py/flwr/server/strategy/fault_tolerant_fedavg.py b/src/py/flwr/server/strategy/fault_tolerant_fedavg.py index 8865cc209653..17262701bacb 100644 --- a/src/py/flwr/server/strategy/fault_tolerant_fedavg.py +++ b/src/py/flwr/server/strategy/fault_tolerant_fedavg.py @@ -40,6 +40,7 @@ def __init__( on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, min_completion_rate_fit: float = 0.5, min_completion_rate_evaluate: float = 0.5, + initial_parameters: Optional[Weights] = None, ) -> None: super().__init__( fraction_fit=fraction_fit, @@ -51,6 +52,7 @@ def __init__( on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, accept_failures=True, + initial_parameters=initial_parameters, ) self.completion_rate_fit = min_completion_rate_fit self.completion_rate_evaluate = min_completion_rate_evaluate diff --git a/src/py/flwr/server/strategy/fedadagrad.py b/src/py/flwr/server/strategy/fedadagrad.py index 066fa18b12f0..83a451d9a6be 100644 --- a/src/py/flwr/server/strategy/fedadagrad.py +++ b/src/py/flwr/server/strategy/fedadagrad.py @@ -49,7 +49,7 @@ def __init__( on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, accept_failures: bool = True, - current_weights: Weights, + initial_parameters: Weights, eta: float = 1e-1, eta_l: float = 1e-1, tau: float = 1e-9, @@ -77,7 +77,7 @@ def __init__( Function used to configure validation. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - current_weights (Weights): Current set of weights from the server. + initial_parameters (Weights): Initial set of weights from the server. eta (float, optional): Server-side learning rate. Defaults to 1e-1. eta_l (float, optional): Client-side learning rate. Defaults to 1e-1. tau (float, optional): Controls the algorithm's degree of adaptability. @@ -93,7 +93,7 @@ def __init__( on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, accept_failures=accept_failures, - current_weights=current_weights, + initial_parameters=initial_parameters, eta=eta, eta_l=eta_l, tau=tau, diff --git a/src/py/flwr/server/strategy/fedadagrad_test.py b/src/py/flwr/server/strategy/fedadagrad_test.py index d4b5a986dad4..a3d7947c5fec 100644 --- a/src/py/flwr/server/strategy/fedadagrad_test.py +++ b/src/py/flwr/server/strategy/fedadagrad_test.py @@ -32,7 +32,7 @@ def test_aggregate_fit() -> None: # Prepare previous_weights: Weights = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] strategy = FedAdagrad( - eta=0.1, eta_l=0.316, tau=0.5, current_weights=previous_weights + eta=0.1, eta_l=0.316, tau=0.5, initial_parameters=previous_weights ) param_0: Parameters = weights_to_parameters( [array([0.2, 0.2, 0.2, 0.2], dtype=float32)] diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index 7f7285ecbfbb..c915e698a42f 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -77,8 +77,7 @@ def __init__( Function used to configure validation. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initialize_parameters_fn (Callable[[], Optional[Weights]], optional): - Function used to initialize global model parameters. Defaults to None. + initial_parameters (Weights, optional): Initial global model parameters. """ super().__init__() self.min_fit_clients = min_fit_clients diff --git a/src/py/flwr/server/strategy/fedfs_v0.py b/src/py/flwr/server/strategy/fedfs_v0.py index c95f0e97653c..238247d34be8 100644 --- a/src/py/flwr/server/strategy/fedfs_v0.py +++ b/src/py/flwr/server/strategy/fedfs_v0.py @@ -63,6 +63,7 @@ def __init__( r_slow: int = 1, t_fast: int = 10, t_slow: int = 10, + initial_parameters: Optional[Weights] = None, ) -> None: super().__init__( fraction_fit=fraction_fit, @@ -73,6 +74,7 @@ def __init__( eval_fn=eval_fn, on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, + initial_parameters=initial_parameters, ) self.min_completion_rate_fit = min_completion_rate_fit self.min_completion_rate_evaluate = min_completion_rate_evaluate diff --git a/src/py/flwr/server/strategy/fedfs_v1.py b/src/py/flwr/server/strategy/fedfs_v1.py index c2a8db0d4e91..3d623cac7e26 100644 --- a/src/py/flwr/server/strategy/fedfs_v1.py +++ b/src/py/flwr/server/strategy/fedfs_v1.py @@ -69,6 +69,7 @@ def __init__( r_slow: int = 1, t_max: int = 10, use_past_contributions: bool = False, + initial_parameters: Optional[Weights] = None, ) -> None: super().__init__( fraction_fit=fraction_fit, @@ -79,6 +80,7 @@ def __init__( eval_fn=eval_fn, on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, + initial_parameters=initial_parameters, ) self.min_completion_rate_fit = min_completion_rate_fit self.min_completion_rate_evaluate = min_completion_rate_evaluate diff --git a/src/py/flwr/server/strategy/fedopt.py b/src/py/flwr/server/strategy/fedopt.py index 8911c44637e1..2928dffc003f 100644 --- a/src/py/flwr/server/strategy/fedopt.py +++ b/src/py/flwr/server/strategy/fedopt.py @@ -42,7 +42,7 @@ def __init__( on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, accept_failures: bool = True, - current_weights: Weights, + initial_parameters: Weights, eta: float = 1e-1, eta_l: float = 1e-1, tau: float = 1e-9, @@ -70,24 +70,25 @@ def __init__( Function used to configure validation. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - current_weights (Weights): Current set of weights from the server. + initial_parameters (Weights): Initial set of parameters from the server. eta (float, optional): Server-side learning rate. Defaults to 1e-1. eta_l (float, optional): Client-side learning rate. Defaults to 1e-1. tau (float, optional): Controls the algorithm's degree of adaptability. Defaults to 1e-9. """ super().__init__( - fraction_fit, - fraction_eval, - min_fit_clients, - min_eval_clients, - min_available_clients, - eval_fn, - on_fit_config_fn, - on_evaluate_config_fn, - accept_failures, + fraction_fit=fraction_fit, + fraction_eval=fraction_eval, + min_fit_clients=min_fit_clients, + min_eval_clients=min_eval_clients, + min_available_clients=min_available_clients, + eval_fn=eval_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, ) - self.current_weights = current_weights + self.current_weights = initial_parameters self.eta = eta self.eta_l = eta_l self.tau = tau diff --git a/src/py/flwr/server/strategy/qffedavg.py b/src/py/flwr/server/strategy/qffedavg.py index b3c49635dc47..799f88ea6778 100644 --- a/src/py/flwr/server/strategy/qffedavg.py +++ b/src/py/flwr/server/strategy/qffedavg.py @@ -56,8 +56,20 @@ def __init__( on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, accept_failures: bool = True, + initial_parameters: Optional[Weights] = None, ) -> None: - super().__init__() + super().__init__( + fraction_fit=fraction_fit, + fraction_eval=fraction_eval, + min_fit_clients=min_fit_clients, + min_eval_clients=min_eval_clients, + min_available_clients=min_available_clients, + eval_fn=eval_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, + ) self.min_fit_clients = min_fit_clients self.min_eval_clients = min_eval_clients self.fraction_fit = fraction_fit