-
Notifications
You must be signed in to change notification settings - Fork 936
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
370 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Copyright 2020 Adap GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Federated Median (FedMedian) [Yin et al., 2018] strategy. | ||
Paper: https://arxiv.org/pdf/1803.01498v1.pdf | ||
""" | ||
|
||
|
||
from logging import WARNING | ||
from typing import Callable, Dict, List, Optional, Tuple, Union | ||
|
||
from flwr.common import ( | ||
FitRes, | ||
MetricsAggregationFn, | ||
NDArrays, | ||
Parameters, | ||
Scalar, | ||
ndarrays_to_parameters, | ||
parameters_to_ndarrays, | ||
) | ||
from flwr.common.logger import log | ||
from flwr.server.client_proxy import ClientProxy | ||
|
||
from .aggregate import aggregate_median | ||
from .fedavg import FedAvg | ||
|
||
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ | ||
Setting `min_available_clients` lower than `min_fit_clients` or | ||
`min_evaluate_clients` can cause the server to fail when there are too few clients | ||
connected to the server. `min_available_clients` must be set to a value larger | ||
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`. | ||
""" | ||
|
||
# flake8: noqa: E501 | ||
class FedMedian(FedAvg): | ||
"""Configurable FedAvg with Momentum strategy implementation.""" | ||
|
||
# pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long | ||
def __init__( | ||
self, | ||
*, | ||
fraction_fit: float = 1.0, | ||
fraction_evaluate: float = 1.0, | ||
min_fit_clients: int = 2, | ||
min_evaluate_clients: int = 2, | ||
min_available_clients: int = 2, | ||
evaluate_fn: Optional[ | ||
Callable[ | ||
[int, NDArrays, Dict[str, Scalar]], | ||
Optional[Tuple[float, Dict[str, Scalar]]], | ||
] | ||
] = None, | ||
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, | ||
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, | ||
accept_failures: bool = True, | ||
initial_parameters: Optional[Parameters] = None, | ||
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, | ||
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, | ||
) -> None: | ||
"""Configurable FedMedian strategy. | ||
Implementation based on https://arxiv.org/pdf/1803.01498v1.pdf | ||
Parameters | ||
---------- | ||
fraction_fit : float, optional | ||
Fraction of clients used during training. Defaults to 0.1. | ||
fraction_evaluate : float, optional | ||
Fraction of clients used during validation. Defaults to 0.1. | ||
min_fit_clients : int, optional | ||
Minimum number of clients used during training. Defaults to 2. | ||
min_evaluate_clients : int, optional | ||
Minimum number of clients used during validation. Defaults to 2. | ||
min_available_clients : int, optional | ||
Minimum number of total clients in the system. Defaults to 2. | ||
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] | ||
Optional function used for validation. Defaults to None. | ||
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional | ||
Function used to configure training. Defaults to None. | ||
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional | ||
Function used to configure validation. Defaults to None. | ||
accept_failures : bool, optional | ||
Whether or not accept rounds containing failures. Defaults to True. | ||
initial_parameters : Parameters, optional | ||
Initial global model parameters. | ||
""" | ||
|
||
if ( | ||
min_fit_clients > min_available_clients | ||
or min_evaluate_clients > min_available_clients | ||
): | ||
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW) | ||
|
||
super().__init__( | ||
fraction_fit=fraction_fit, | ||
fraction_evaluate=fraction_evaluate, | ||
min_fit_clients=min_fit_clients, | ||
min_evaluate_clients=min_evaluate_clients, | ||
min_available_clients=min_available_clients, | ||
evaluate_fn=evaluate_fn, | ||
on_fit_config_fn=on_fit_config_fn, | ||
on_evaluate_config_fn=on_evaluate_config_fn, | ||
accept_failures=accept_failures, | ||
initial_parameters=initial_parameters, | ||
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, | ||
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, | ||
) | ||
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn | ||
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn | ||
|
||
def __repr__(self) -> str: | ||
rep = f"FedMedian(accept_failures={self.accept_failures})" | ||
return rep | ||
|
||
def aggregate_fit( | ||
self, | ||
server_round: int, | ||
results: List[Tuple[ClientProxy, FitRes]], | ||
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], | ||
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: | ||
"""Aggregate fit results using median.""" | ||
if not results: | ||
return None, {} | ||
# Do not aggregate if there are failures and failures are not accepted | ||
if not self.accept_failures and failures: | ||
return None, {} | ||
|
||
# Convert results | ||
weights_results = [ | ||
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) | ||
for _, fit_res in results | ||
] | ||
parameters_aggregated = ndarrays_to_parameters( | ||
aggregate_median(weights_results) | ||
) | ||
|
||
# Aggregate custom metrics if aggregation fn was provided | ||
metrics_aggregated = {} | ||
if self.fit_metrics_aggregation_fn: | ||
fit_metrics = [(res.num_examples, res.metrics) for _, res in results] | ||
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) | ||
elif server_round == 1: # Only log this warning once | ||
log(WARNING, "No fit_metrics_aggregation_fn provided") | ||
|
||
return parameters_aggregated, metrics_aggregated |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# Copyright 2020 Adap GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""FedMedian tests.""" | ||
|
||
from typing import List, Tuple | ||
from unittest.mock import MagicMock | ||
|
||
from numpy import array, float32 | ||
|
||
from flwr.common import ( | ||
Code, | ||
FitRes, | ||
NDArrays, | ||
Parameters, | ||
Status, | ||
ndarrays_to_parameters, | ||
parameters_to_ndarrays, | ||
) | ||
from flwr.server.client_proxy import ClientProxy | ||
from flwr.server.grpc_server.grpc_client_proxy import GrpcClientProxy | ||
|
||
from .fedmedian import FedMedian | ||
|
||
|
||
def test_fedmedian_num_fit_clients_20_available() -> None: | ||
"""Test num_fit_clients function.""" | ||
# Prepare | ||
strategy = FedMedian() | ||
expected = 20 | ||
|
||
# Execute | ||
actual, _ = strategy.num_fit_clients(num_available_clients=20) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_fit_clients_19_available() -> None: | ||
"""Test num_fit_clients function.""" | ||
# Prepare | ||
strategy = FedMedian() | ||
expected = 19 | ||
|
||
# Execute | ||
actual, _ = strategy.num_fit_clients(num_available_clients=19) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_fit_clients_10_available() -> None: | ||
"""Test num_fit_clients function.""" | ||
# Prepare | ||
strategy = FedMedian() | ||
expected = 10 | ||
|
||
# Execute | ||
actual, _ = strategy.num_fit_clients(num_available_clients=10) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_fit_clients_minimum() -> None: | ||
"""Test num_fit_clients function.""" | ||
# Prepare | ||
strategy = FedMedian() | ||
expected = 9 | ||
|
||
# Execute | ||
actual, _ = strategy.num_fit_clients(num_available_clients=9) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_evaluation_clients_40_available() -> None: | ||
"""Test num_evaluation_clients function.""" | ||
# Prepare | ||
strategy = FedMedian(fraction_evaluate=0.05) | ||
expected = 2 | ||
|
||
# Execute | ||
actual, _ = strategy.num_evaluation_clients(num_available_clients=40) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_evaluation_clients_39_available() -> None: | ||
"""Test num_evaluation_clients function.""" | ||
# Prepare | ||
strategy = FedMedian(fraction_evaluate=0.05) | ||
expected = 2 | ||
|
||
# Execute | ||
actual, _ = strategy.num_evaluation_clients(num_available_clients=39) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_evaluation_clients_20_available() -> None: | ||
"""Test num_evaluation_clients function.""" | ||
# Prepare | ||
strategy = FedMedian(fraction_evaluate=0.05) | ||
expected = 2 | ||
|
||
# Execute | ||
actual, _ = strategy.num_evaluation_clients(num_available_clients=20) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_fedmedian_num_evaluation_clients_minimum() -> None: | ||
"""Test num_evaluation_clients function.""" | ||
# Prepare | ||
strategy = FedMedian(fraction_evaluate=0.05) | ||
expected = 2 | ||
|
||
# Execute | ||
actual, _ = strategy.num_evaluation_clients(num_available_clients=19) | ||
|
||
# Assert | ||
assert expected == actual | ||
|
||
|
||
def test_aggregate_fit() -> None: | ||
"""Tests if FedMedian is aggregating correctly.""" | ||
# Prepare | ||
previous_weights: NDArrays = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] | ||
strategy = FedMedian( | ||
initial_parameters=ndarrays_to_parameters(previous_weights), | ||
) | ||
param_0: Parameters = ndarrays_to_parameters( | ||
[array([0.2, 0.2, 0.2, 0.2], dtype=float32)] | ||
) | ||
param_1: Parameters = ndarrays_to_parameters( | ||
[array([1.0, 1.0, 1.0, 1.0], dtype=float32)] | ||
) | ||
param_2: Parameters = ndarrays_to_parameters( | ||
[array([0.5, 0.5, 0.5, 0.5], dtype=float32)] | ||
) | ||
bridge = MagicMock() | ||
client_0 = GrpcClientProxy(cid="0", bridge=bridge) | ||
client_1 = GrpcClientProxy(cid="1", bridge=bridge) | ||
client_2 = GrpcClientProxy(cid="2", bridge=bridge) | ||
results: List[Tuple[ClientProxy, FitRes]] = [ | ||
( | ||
client_0, | ||
FitRes( | ||
status=Status(code=Code.OK, message="Success"), | ||
parameters=param_0, | ||
num_examples=5, | ||
metrics={}, | ||
), | ||
), | ||
( | ||
client_1, | ||
FitRes( | ||
status=Status(code=Code.OK, message="Success"), | ||
parameters=param_1, | ||
num_examples=5, | ||
metrics={}, | ||
), | ||
), | ||
( | ||
client_2, | ||
FitRes( | ||
status=Status(code=Code.OK, message="Success"), | ||
parameters=param_2, | ||
num_examples=5, | ||
metrics={}, | ||
), | ||
), | ||
] | ||
expected: NDArrays = [array([0.5, 0.5, 0.5, 0.5], dtype=float32)] | ||
|
||
# Execute | ||
actual_aggregated, _ = strategy.aggregate_fit( | ||
server_round=1, results=results, failures=[] | ||
) | ||
if actual_aggregated: | ||
actual_list = parameters_to_ndarrays(actual_aggregated) | ||
actual = actual_list[0] | ||
assert (actual == expected[0]).all() |