Commit 715be72
Add initial_parameters argument to built-in strategies (#663)
Co-authored-by: Pedro Porto Buarque de Gusmão <[email protected]>
danieljanes and pedropgusmao authored Mar 9, 2021
1 parent 3fc784d commit 715be72
Showing 10 changed files with 41 additions and 19 deletions.
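
For context, a minimal usage sketch (not part of this commit's diff): after this change, the built-in strategies accept initial global model weights at construction time. At this commit, initial_parameters is a Weights value, i.e. a list of NumPy ndarrays; the layer shapes and strategy settings below are illustrative placeholders.

# Hedged sketch: seed a built-in strategy with initial global weights.
# Shapes and settings are placeholders, not taken from the repository.
import numpy as np
from flwr.server.strategy import FedAvg

initial_weights = [
    np.zeros((784, 64), dtype=np.float32),  # placeholder layer weights
    np.zeros((64,), dtype=np.float32),      # placeholder layer bias
]

strategy = FedAvg(
    fraction_fit=0.5,
    min_available_clients=2,
    initial_parameters=initial_weights,  # argument introduced by this commit
)
# The strategy is then typically handed to the Flower server,
# e.g. fl.server.start_server(..., strategy=strategy).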
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/default.py
@@ -43,6 +43,7 @@ def __init__(
         on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         accept_failures: bool = True,
+        initial_parameters: Optional[Weights] = None,
     ) -> None:
         super().__init__(
             fraction_fit=fraction_fit,
@@ -54,6 +55,7 @@
             on_fit_config_fn=on_fit_config_fn,
             on_evaluate_config_fn=on_evaluate_config_fn,
             accept_failures=accept_failures,
+            initial_parameters=initial_parameters,
         )
         warning = """
         DEPRECATION WARNING: DefaultStrategy is deprecated, migrate to FedAvg.
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/fast_and_slow.py
@@ -70,6 +70,7 @@ def __init__(
         r_slow: int = 1,
         t_fast: int = 10,
         t_slow: int = 10,
+        initial_parameters: Optional[Weights] = None,
     ) -> None:
         super().__init__(
             fraction_fit=fraction_fit,
@@ -80,6 +81,7 @@
             eval_fn=eval_fn,
             on_fit_config_fn=on_fit_config_fn,
             on_evaluate_config_fn=on_evaluate_config_fn,
+            initial_parameters=initial_parameters,
         )
         self.min_completion_rate_fit = min_completion_rate_fit
         self.min_completion_rate_evaluate = min_completion_rate_evaluate
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/fault_tolerant_fedavg.py
@@ -40,6 +40,7 @@ def __init__(
         on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         min_completion_rate_fit: float = 0.5,
         min_completion_rate_evaluate: float = 0.5,
+        initial_parameters: Optional[Weights] = None,
     ) -> None:
         super().__init__(
             fraction_fit=fraction_fit,
@@ -51,6 +52,7 @@
             on_fit_config_fn=on_fit_config_fn,
             on_evaluate_config_fn=on_evaluate_config_fn,
             accept_failures=True,
+            initial_parameters=initial_parameters,
         )
         self.completion_rate_fit = min_completion_rate_fit
         self.completion_rate_evaluate = min_completion_rate_evaluate
6 changes: 3 additions & 3 deletions src/py/flwr/server/strategy/fedadagrad.py
@@ -49,7 +49,7 @@ def __init__(
         on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         accept_failures: bool = True,
-        current_weights: Weights,
+        initial_parameters: Weights,
         eta: float = 1e-1,
         eta_l: float = 1e-1,
         tau: float = 1e-9,
@@ -77,7 +77,7 @@ def __init__(
                 Function used to configure validation. Defaults to None.
             accept_failures (bool, optional): Whether or not accept rounds
                 containing failures. Defaults to True.
-            current_weights (Weights): Current set of weights from the server.
+            initial_parameters (Weights): Initial set of weights from the server.
             eta (float, optional): Server-side learning rate. Defaults to 1e-1.
             eta_l (float, optional): Client-side learning rate. Defaults to 1e-1.
             tau (float, optional): Controls the algorithm's degree of adaptability.
@@ -93,7 +93,7 @@ def __init__(
             on_fit_config_fn=on_fit_config_fn,
             on_evaluate_config_fn=on_evaluate_config_fn,
             accept_failures=accept_failures,
-            current_weights=current_weights,
+            initial_parameters=initial_parameters,
             eta=eta,
             eta_l=eta_l,
             tau=tau,
2 changes: 1 addition & 1 deletion src/py/flwr/server/strategy/fedadagrad_test.py
@@ -32,7 +32,7 @@ def test_aggregate_fit() -> None:
     # Prepare
     previous_weights: Weights = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)]
     strategy = FedAdagrad(
-        eta=0.1, eta_l=0.316, tau=0.5, current_weights=previous_weights
+        eta=0.1, eta_l=0.316, tau=0.5, initial_parameters=previous_weights
    )
    param_0: Parameters = weights_to_parameters(
        [array([0.2, 0.2, 0.2, 0.2], dtype=float32)]
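
Unlike FedAvg, FedAdagrad (and the FedOpt base class changed below) requires initial_parameters, because the server-side optimizer state is seeded from it. A hedged sketch, assuming FedAdagrad is importable from flwr.server.strategy and that a Keras model supplies the weight list; the architecture is a placeholder:

# Hedged sketch (assumes TensorFlow/Keras is installed).
import tensorflow as tf
from flwr.server.strategy import FedAdagrad

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# model.get_weights() returns a list of NumPy ndarrays, which matches the
# Weights type expected by initial_parameters at this commit.
strategy = FedAdagrad(
    eta=0.1,    # server-side learning rate
    eta_l=0.1,  # client-side learning rate
    tau=1e-9,   # adaptability constant
    initial_parameters=model.get_weights(),
)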
3 changes: 1 addition & 2 deletions src/py/flwr/server/strategy/fedavg.py
@@ -77,8 +77,7 @@ def __init__(
                 Function used to configure validation. Defaults to None.
             accept_failures (bool, optional): Whether or not accept rounds
                 containing failures. Defaults to True.
-            initialize_parameters_fn (Callable[[], Optional[Weights]], optional):
-                Function used to initialize global model parameters. Defaults to None.
+            initial_parameters (Weights, optional): Initial global model parameters.
         """
         super().__init__()
         self.min_fit_clients = min_fit_clients
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/fedfs_v0.py
@@ -63,6 +63,7 @@ def __init__(
         r_slow: int = 1,
         t_fast: int = 10,
         t_slow: int = 10,
+        initial_parameters: Optional[Weights] = None,
     ) -> None:
         super().__init__(
             fraction_fit=fraction_fit,
@@ -73,6 +74,7 @@
             eval_fn=eval_fn,
             on_fit_config_fn=on_fit_config_fn,
             on_evaluate_config_fn=on_evaluate_config_fn,
+            initial_parameters=initial_parameters,
         )
         self.min_completion_rate_fit = min_completion_rate_fit
         self.min_completion_rate_evaluate = min_completion_rate_evaluate
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/fedfs_v1.py
@@ -69,6 +69,7 @@ def __init__(
         r_slow: int = 1,
         t_max: int = 10,
         use_past_contributions: bool = False,
+        initial_parameters: Optional[Weights] = None,
     ) -> None:
         super().__init__(
             fraction_fit=fraction_fit,
@@ -79,6 +80,7 @@
             eval_fn=eval_fn,
             on_fit_config_fn=on_fit_config_fn,
             on_evaluate_config_fn=on_evaluate_config_fn,
+            initial_parameters=initial_parameters,
         )
         self.min_completion_rate_fit = min_completion_rate_fit
         self.min_completion_rate_evaluate = min_completion_rate_evaluate
25 changes: 13 additions & 12 deletions src/py/flwr/server/strategy/fedopt.py
@@ -42,7 +42,7 @@ def __init__(
         on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         accept_failures: bool = True,
-        current_weights: Weights,
+        initial_parameters: Weights,
         eta: float = 1e-1,
         eta_l: float = 1e-1,
         tau: float = 1e-9,
@@ -70,24 +70,25 @@ def __init__(
                 Function used to configure validation. Defaults to None.
             accept_failures (bool, optional): Whether or not accept rounds
                 containing failures. Defaults to True.
-            current_weights (Weights): Current set of weights from the server.
+            initial_parameters (Weights): Initial set of parameters from the server.
             eta (float, optional): Server-side learning rate. Defaults to 1e-1.
             eta_l (float, optional): Client-side learning rate. Defaults to 1e-1.
             tau (float, optional): Controls the algorithm's degree of adaptability.
                 Defaults to 1e-9.
         """
         super().__init__(
-            fraction_fit,
-            fraction_eval,
-            min_fit_clients,
-            min_eval_clients,
-            min_available_clients,
-            eval_fn,
-            on_fit_config_fn,
-            on_evaluate_config_fn,
-            accept_failures,
+            fraction_fit=fraction_fit,
+            fraction_eval=fraction_eval,
+            min_fit_clients=min_fit_clients,
+            min_eval_clients=min_eval_clients,
+            min_available_clients=min_available_clients,
+            eval_fn=eval_fn,
+            on_fit_config_fn=on_fit_config_fn,
+            on_evaluate_config_fn=on_evaluate_config_fn,
+            accept_failures=accept_failures,
+            initial_parameters=initial_parameters,
        )
-        self.current_weights = current_weights
+        self.current_weights = initial_parameters
        self.eta = eta
        self.eta_l = eta_l
        self.tau = tau
14 changes: 13 additions & 1 deletion src/py/flwr/server/strategy/qffedavg.py
@@ -56,8 +56,20 @@ def __init__(
         on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
         accept_failures: bool = True,
+        initial_parameters: Optional[Weights] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(
+            fraction_fit=fraction_fit,
+            fraction_eval=fraction_eval,
+            min_fit_clients=min_fit_clients,
+            min_eval_clients=min_eval_clients,
+            min_available_clients=min_available_clients,
+            eval_fn=eval_fn,
+            on_fit_config_fn=on_fit_config_fn,
+            on_evaluate_config_fn=on_evaluate_config_fn,
+            accept_failures=accept_failures,
+            initial_parameters=initial_parameters,
+        )
         self.min_fit_clients = min_fit_clients
         self.min_eval_clients = min_eval_clients
         self.fraction_fit = fraction_fit
