implementing FedHT as a custom aggregation strategy; required new fedht strategy and aggregation_hardthreshold method
chancejohnstone authored Aug 10, 2024
1 parent c036832 commit d4be452
Showing 2 changed files with 332 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/py/flwr/server/strategy/aggregate.py
@@ -364,3 +364,48 @@ def _aggregate_n_closest_weights(
)[:beta_closest]
aggregated_weights.append(np.mean(beta_closest_weights, axis=0))
return aggregated_weights

def hardthreshold(weights_layer: NDArray, num_keep: int) -> NDArray:
    """Keep the `num_keep` largest-magnitude entries of a layer; zero the rest."""
    if num_keep > weights_layer.size:
        raise ValueError(
            "The number of parameters kept cannot be greater than the "
            "number of parameters in the layer."
        )

    # Compute the magnitudes
    magnitudes = np.abs(weights_layer)

    # Get the num_keep-th largest magnitude; it serves as the threshold
    threshold = np.partition(magnitudes.ravel(), -num_keep)[-num_keep]

    # Set values whose magnitude falls below the threshold to zero
    return np.where(magnitudes >= threshold, weights_layer, 0)

def aggregate_hardthreshold(
    results: List[Tuple[NDArrays, int]], num_keep: int
) -> NDArrays:
    """Compute a weighted average, then hard-threshold each layer.

    Only the `num_keep` largest-magnitude weights of each aggregated layer
    are kept. Fed-HT (Fed-IterHT) can be found at
    https://arxiv.org/abs/2101.00052.
    """
    if num_keep <= 0:
        raise ValueError("num_keep must be a positive integer.")

    # Calculate the total number of examples used during training
    num_examples_total = sum(num_examples for (_, num_examples) in results)

    # Create a list of weights, each multiplied by the related number of examples
    weighted_weights = [
        [layer * num_examples for layer in weights] for weights, num_examples in results
    ]

    # Average the weighted updates layer by layer
    averaged: NDArrays = [
        reduce(np.add, layer_updates) / num_examples_total
        for layer_updates in zip(*weighted_weights)
    ]

    # Apply hard thresholding to each averaged layer
    return [hardthreshold(layer, num_keep) for layer in averaged]
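
To illustrate the intended behavior, here is a minimal sketch (not part of the commit) of `aggregate_hardthreshold` applied to two hypothetical single-layer client updates; the arrays and example counts are invented for demonstration:

import numpy as np

from flwr.server.strategy.aggregate import aggregate_hardthreshold

# Two hypothetical clients: (weights, num_examples)
client_a = ([np.array([0.5, -2.0, 0.1, 3.0])], 100)
client_b = ([np.array([1.5, -1.2, 0.3, 1.0])], 300)

# Weighted average -> [1.25, -1.4, 0.25, 1.5]; with num_keep=2 only the two
# largest-magnitude entries survive -> [0.0, -1.4, 0.0, 1.5]
aggregated = aggregate_hardthreshold([client_a, client_b], num_keep=2)
print(aggregated)  # [array([ 0. , -1.4,  0. ,  1.5])]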
287 changes: 287 additions & 0 deletions src/py/flwr/server/strategy/fedht.py
@@ -0,0 +1,287 @@
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Hardthresholding (FedHT)
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy.aggregate import (
    aggregate_hardthreshold,
    aggregate_inplace,
    hardthreshold,
    weighted_loss_avg,
)
from flwr.server.strategy.strategy import Strategy

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""


# pylint: disable=line-too-long
class FedHT(Strategy):
"""Federated Hardthreshold strategy.
Implementation based on https://arxiv.org/abs/1602.05629
Parameters
----------
fraction_fit : float, optional
Fraction of clients used during training. In case `min_fit_clients`
is larger than `fraction_fit * available_clients`, `min_fit_clients`
will still be sampled. Defaults to 1.0.
fraction_evaluate : float, optional
Fraction of clients used during validation. In case `min_evaluate_clients`
is larger than `fraction_evaluate * available_clients`,
        `min_evaluate_clients` will still be sampled. Defaults to 1.0.
    num_keep : int, optional
        Number of weights with the largest magnitude to keep in each layer
        during hard thresholding. Defaults to 5.
min_fit_clients : int, optional
Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients : int, optional
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure validation. Defaults to None.
accept_failures : bool, optional
        Whether or not to accept rounds containing failures. Defaults to True.
initial_parameters : Parameters, optional
Initial global model parameters.
fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
Metrics aggregation function, optional.
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
Metrics aggregation function, optional.
inplace : bool (default: True)
Enable (True) or disable (False) in-place aggregation of model updates.
"""

# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
def __init__(
self,
*,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
num_keep: int = 5,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
evaluate_fn: Optional[
Callable[
[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
inplace: bool = True,
) -> None:
super().__init__()

if (
min_fit_clients > min_available_clients
or min_evaluate_clients > min_available_clients
):
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

self.fraction_fit = fraction_fit
self.fraction_evaluate = fraction_evaluate
self.num_keep = num_keep
self.min_fit_clients = min_fit_clients
self.min_evaluate_clients = min_evaluate_clients
self.min_available_clients = min_available_clients
self.evaluate_fn = evaluate_fn
self.on_fit_config_fn = on_fit_config_fn
self.on_evaluate_config_fn = on_evaluate_config_fn
self.accept_failures = accept_failures
self.initial_parameters = initial_parameters
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
self.inplace = inplace

def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
rep = f"FedAvg(accept_failures={self.accept_failures})"
return rep

def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
"""Return the sample size and the required number of available clients."""
num_clients = int(num_available_clients * self.fraction_fit)
return max(num_clients, self.min_fit_clients), self.min_available_clients

def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
"""Use a fraction of available clients for evaluation."""
num_clients = int(num_available_clients * self.fraction_evaluate)
return max(num_clients, self.min_evaluate_clients), self.min_available_clients

def initialize_parameters(
self, client_manager: ClientManager
) -> Optional[Parameters]:
"""Initialize global model parameters."""
initial_parameters = self.initial_parameters
self.initial_parameters = None # Don't keep initial parameters in memory
return initial_parameters

def evaluate(
self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_fn is None:
# No evaluation function provided
return None
parameters_ndarrays = parameters_to_ndarrays(parameters)
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics

def configure_fit(
self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training."""
config = {}
if self.on_fit_config_fn is not None:
# Custom fit config function provided
config = self.on_fit_config_fn(server_round)
fit_ins = FitIns(parameters, config)

# Sample clients
sample_size, min_num_clients = self.num_fit_clients(
client_manager.num_available()
)
clients = client_manager.sample(
num_clients=sample_size, min_num_clients=min_num_clients
)

# Return client/config pairs
return [(client, fit_ins) for client in clients]

def configure_evaluate(
self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, EvaluateIns]]:
"""Configure the next round of evaluation."""
# Do not configure federated evaluation if fraction eval is 0.
if self.fraction_evaluate == 0.0:
return []

# Parameters and config
config = {}
if self.on_evaluate_config_fn is not None:
# Custom evaluation config function provided
config = self.on_evaluate_config_fn(server_round)
evaluate_ins = EvaluateIns(parameters, config)

# Sample clients
sample_size, min_num_clients = self.num_evaluation_clients(
client_manager.num_available()
)
clients = client_manager.sample(
num_clients=sample_size, min_num_clients=min_num_clients
)

# Return client/config pairs
return [(client, evaluate_ins) for client in clients]

def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}

        if self.inplace:
            # Does in-place weighted average of results, then applies hard
            # thresholding to each layer of the aggregate
            aggregated_ndarrays = aggregate_inplace(results)
            aggregated_ndarrays = [
                hardthreshold(layer, self.num_keep) for layer in aggregated_ndarrays
            ]
        else:
            # Convert results and aggregate with hard thresholding
            weights_results = [
                (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
                for _, fit_res in results
            ]
            aggregated_ndarrays = aggregate_hardthreshold(
                weights_results, self.num_keep
            )

parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")

return parameters_aggregated, metrics_aggregated

def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
"""Aggregate evaluation losses using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}

# Aggregate loss
loss_aggregated = weighted_loss_avg(
[
(evaluate_res.num_examples, evaluate_res.loss)
for _, evaluate_res in results
]
)

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.evaluate_metrics_aggregation_fn:
eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No evaluate_metrics_aggregation_fn provided")

return loss_aggregated, metrics_aggregated
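
As a usage illustration (not part of the commit), the new strategy can be plugged into a standard Flower server start-up; the server address, round count, and `num_keep` value below are arbitrary examples, and `inplace=False` routes aggregation through the new `aggregate_hardthreshold` helper:

import flwr as fl

from flwr.server.strategy.fedht import FedHT

# Keep the 10 largest-magnitude weights per layer during aggregation
strategy = FedHT(num_keep=10, inplace=False, min_available_clients=2)

fl.server.start_server(
    server_address="0.0.0.0:8080",  # arbitrary example address
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)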
