forked from adap/flower
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implementing FedHT as a custom aggregation strategy; required new fed…
…ht strategy and aggregation_hardthreshold method
- Loading branch information
1 parent
c036832
commit d4be452
Showing
2 changed files
with
332 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
# Copyright 2020 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Federated Hardthresholding (FedHT) | ||
""" | ||
|
||
|
||
from logging import WARNING | ||
from typing import Callable, Dict, List, Optional, Tuple, Union | ||
import numpy as np | ||
|
||
from flwr.common import ( | ||
EvaluateIns, | ||
EvaluateRes, | ||
FitIns, | ||
FitRes, | ||
MetricsAggregationFn, | ||
NDArrays, | ||
Parameters, | ||
Scalar, | ||
ndarrays_to_parameters, | ||
parameters_to_ndarrays, | ||
) | ||
from flwr.common.logger import log | ||
from flwr.server.client_manager import ClientManager | ||
from flwr.server.client_proxy import ClientProxy | ||
|
||
# from flwr.server.strategy.aggregate import aggregate, aggregate_inplace, weighted_loss_avg | ||
from flwr.server.strategy.aggregate import aggregate_inplace, weighted_loss_avg, aggregate_hardthreshold | ||
from flwr.server.strategy.strategy import Strategy | ||
|
||
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ | ||
Setting `min_available_clients` lower than `min_fit_clients` or | ||
`min_evaluate_clients` can cause the server to fail when there are too few clients | ||
connected to the server. `min_available_clients` must be set to a value larger | ||
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`. | ||
""" | ||
|
||
|
||
# pylint: disable=line-too-long | ||
class FedHT(Strategy): | ||
"""Federated Hardthreshold strategy. | ||
Implementation based on https://arxiv.org/abs/1602.05629 | ||
Parameters | ||
---------- | ||
fraction_fit : float, optional | ||
Fraction of clients used during training. In case `min_fit_clients` | ||
is larger than `fraction_fit * available_clients`, `min_fit_clients` | ||
will still be sampled. Defaults to 1.0. | ||
fraction_evaluate : float, optional | ||
Fraction of clients used during validation. In case `min_evaluate_clients` | ||
is larger than `fraction_evaluate * available_clients`, | ||
`min_evaluate_clients` will still be sampled. Defaults to 1.0. | ||
min_fit_clients : int, optional | ||
Minimum number of clients used during training. Defaults to 2. | ||
min_evaluate_clients : int, optional | ||
Minimum number of clients used during validation. Defaults to 2. | ||
min_available_clients : int, optional | ||
Minimum number of total clients in the system. Defaults to 2. | ||
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]] | ||
Optional function used for validation. Defaults to None. | ||
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional | ||
Function used to configure training. Defaults to None. | ||
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional | ||
Function used to configure validation. Defaults to None. | ||
accept_failures : bool, optional | ||
Whether or not accept rounds containing failures. Defaults to True. | ||
initial_parameters : Parameters, optional | ||
Initial global model parameters. | ||
fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] | ||
Metrics aggregation function, optional. | ||
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] | ||
Metrics aggregation function, optional. | ||
inplace : bool (default: True) | ||
Enable (True) or disable (False) in-place aggregation of model updates. | ||
""" | ||
|
||
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long | ||
def __init__( | ||
self, | ||
*, | ||
fraction_fit: float = 1.0, | ||
fraction_evaluate: float = 1.0, | ||
num_keep: int = 5, | ||
min_fit_clients: int = 2, | ||
min_evaluate_clients: int = 2, | ||
min_available_clients: int = 2, | ||
evaluate_fn: Optional[ | ||
Callable[ | ||
[int, NDArrays, Dict[str, Scalar]], | ||
Optional[Tuple[float, Dict[str, Scalar]]], | ||
] | ||
] = None, | ||
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, | ||
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, | ||
accept_failures: bool = True, | ||
initial_parameters: Optional[Parameters] = None, | ||
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, | ||
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, | ||
inplace: bool = True, | ||
) -> None: | ||
super().__init__() | ||
|
||
if ( | ||
min_fit_clients > min_available_clients | ||
or min_evaluate_clients > min_available_clients | ||
): | ||
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW) | ||
|
||
self.fraction_fit = fraction_fit | ||
self.fraction_evaluate = fraction_evaluate | ||
self.num_keep = num_keep | ||
self.min_fit_clients = min_fit_clients | ||
self.min_evaluate_clients = min_evaluate_clients | ||
self.min_available_clients = min_available_clients | ||
self.evaluate_fn = evaluate_fn | ||
self.on_fit_config_fn = on_fit_config_fn | ||
self.on_evaluate_config_fn = on_evaluate_config_fn | ||
self.accept_failures = accept_failures | ||
self.initial_parameters = initial_parameters | ||
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn | ||
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn | ||
self.inplace = inplace | ||
|
||
def __repr__(self) -> str: | ||
"""Compute a string representation of the strategy.""" | ||
rep = f"FedAvg(accept_failures={self.accept_failures})" | ||
return rep | ||
|
||
def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: | ||
"""Return the sample size and the required number of available clients.""" | ||
num_clients = int(num_available_clients * self.fraction_fit) | ||
return max(num_clients, self.min_fit_clients), self.min_available_clients | ||
|
||
def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: | ||
"""Use a fraction of available clients for evaluation.""" | ||
num_clients = int(num_available_clients * self.fraction_evaluate) | ||
return max(num_clients, self.min_evaluate_clients), self.min_available_clients | ||
|
||
def initialize_parameters( | ||
self, client_manager: ClientManager | ||
) -> Optional[Parameters]: | ||
"""Initialize global model parameters.""" | ||
initial_parameters = self.initial_parameters | ||
self.initial_parameters = None # Don't keep initial parameters in memory | ||
return initial_parameters | ||
|
||
def evaluate( | ||
self, server_round: int, parameters: Parameters | ||
) -> Optional[Tuple[float, Dict[str, Scalar]]]: | ||
"""Evaluate model parameters using an evaluation function.""" | ||
if self.evaluate_fn is None: | ||
# No evaluation function provided | ||
return None | ||
parameters_ndarrays = parameters_to_ndarrays(parameters) | ||
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {}) | ||
if eval_res is None: | ||
return None | ||
loss, metrics = eval_res | ||
return loss, metrics | ||
|
||
def configure_fit( | ||
self, server_round: int, parameters: Parameters, client_manager: ClientManager | ||
) -> List[Tuple[ClientProxy, FitIns]]: | ||
"""Configure the next round of training.""" | ||
config = {} | ||
if self.on_fit_config_fn is not None: | ||
# Custom fit config function provided | ||
config = self.on_fit_config_fn(server_round) | ||
fit_ins = FitIns(parameters, config) | ||
|
||
# Sample clients | ||
sample_size, min_num_clients = self.num_fit_clients( | ||
client_manager.num_available() | ||
) | ||
clients = client_manager.sample( | ||
num_clients=sample_size, min_num_clients=min_num_clients | ||
) | ||
|
||
# Return client/config pairs | ||
return [(client, fit_ins) for client in clients] | ||
|
||
def configure_evaluate( | ||
self, server_round: int, parameters: Parameters, client_manager: ClientManager | ||
) -> List[Tuple[ClientProxy, EvaluateIns]]: | ||
"""Configure the next round of evaluation.""" | ||
# Do not configure federated evaluation if fraction eval is 0. | ||
if self.fraction_evaluate == 0.0: | ||
return [] | ||
|
||
# Parameters and config | ||
config = {} | ||
if self.on_evaluate_config_fn is not None: | ||
# Custom evaluation config function provided | ||
config = self.on_evaluate_config_fn(server_round) | ||
evaluate_ins = EvaluateIns(parameters, config) | ||
|
||
# Sample clients | ||
sample_size, min_num_clients = self.num_evaluation_clients( | ||
client_manager.num_available() | ||
) | ||
clients = client_manager.sample( | ||
num_clients=sample_size, min_num_clients=min_num_clients | ||
) | ||
|
||
# Return client/config pairs | ||
return [(client, evaluate_ins) for client in clients] | ||
|
||
def aggregate_fit( | ||
self, | ||
server_round: int, | ||
results: List[Tuple[ClientProxy, FitRes]], | ||
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], | ||
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: | ||
"""Aggregate fit results using weighted average.""" | ||
if not results: | ||
return None, {} | ||
# Do not aggregate if there are failures and failures are not accepted | ||
if not self.accept_failures and failures: | ||
return None, {} | ||
|
||
if self.inplace: | ||
# Does in-place weighted average of results | ||
aggregated_ndarrays = aggregate_inplace(results) | ||
else: | ||
# Convert results | ||
weights_results = [ | ||
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) | ||
for _, fit_res in results | ||
] | ||
aggregated_ndarrays = aggregate_hardthreshold(weights_results, self.num_keep) | ||
|
||
parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays) | ||
|
||
# Aggregate custom metrics if aggregation fn was provided | ||
metrics_aggregated = {} | ||
if self.fit_metrics_aggregation_fn: | ||
fit_metrics = [(res.num_examples, res.metrics) for _, res in results] | ||
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) | ||
elif server_round == 1: # Only log this warning once | ||
log(WARNING, "No fit_metrics_aggregation_fn provided") | ||
|
||
return parameters_aggregated, metrics_aggregated | ||
|
||
def aggregate_evaluate( | ||
self, | ||
server_round: int, | ||
results: List[Tuple[ClientProxy, EvaluateRes]], | ||
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], | ||
) -> Tuple[Optional[float], Dict[str, Scalar]]: | ||
"""Aggregate evaluation losses using weighted average.""" | ||
if not results: | ||
return None, {} | ||
# Do not aggregate if there are failures and failures are not accepted | ||
if not self.accept_failures and failures: | ||
return None, {} | ||
|
||
# Aggregate loss | ||
loss_aggregated = weighted_loss_avg( | ||
[ | ||
(evaluate_res.num_examples, evaluate_res.loss) | ||
for _, evaluate_res in results | ||
] | ||
) | ||
|
||
# Aggregate custom metrics if aggregation fn was provided | ||
metrics_aggregated = {} | ||
if self.evaluate_metrics_aggregation_fn: | ||
eval_metrics = [(res.num_examples, res.metrics) for _, res in results] | ||
metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics) | ||
elif server_round == 1: # Only log this warning once | ||
log(WARNING, "No evaluate_metrics_aggregation_fn provided") | ||
|
||
return loss_aggregated, metrics_aggregated |