diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 63926f2eaa51..4eb76111b266 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -20,7 +20,8 @@ import numpy as np -from flwr.common import NDArray, NDArrays +from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays +from flwr.server.client_proxy import ClientProxy def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: @@ -41,6 +42,31 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: return weights_prime +def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: + """Compute in-place weighted average.""" + # Count total examples + num_examples_total = sum([fit_res.num_examples for _, fit_res in results]) + + # Compute scaling factors for each result + scaling_factors = [ + fit_res.num_examples / num_examples_total for _, fit_res in results + ] + + # Let's do in-place aggregation + # Get first result, then add up each other + params = [ + scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters) + ] + for i, (_, fit_res) in enumerate(results[1:]): + res = ( + scaling_factors[i + 1] * x + for x in parameters_to_ndarrays(fit_res.parameters) + ) + params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)] + + return params + + def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute median.""" # Create a list of weights and ignore the number of examples diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index c93c8cb8b83e..e4b126823fb6 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -37,7 +37,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy -from .aggregate import aggregate, weighted_loss_avg +from .aggregate import aggregate, aggregate_inplace, weighted_loss_avg from .strategy import Strategy WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ @@ -107,6 +107,7 @@ def __init__( initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + inplace: bool = True, ) -> None: super().__init__() @@ -128,6 +129,7 @@ def __init__( self.initial_parameters = initial_parameters self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn + self.inplace = inplace def __repr__(self) -> str: """Compute a string representation of the strategy.""" @@ -226,12 +228,18 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - for _, fit_res in results - ] - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + if self.inplace: + # Does in-place weighted average of results + aggregated_ndarrays = aggregate_inplace(results) + else: + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + aggregated_ndarrays = aggregate(weights_results) + + parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays) # Aggregate custom metrics if aggregation fn was provided metrics_aggregated = {} diff --git a/src/py/flwr/server/strategy/fedavg_test.py b/src/py/flwr/server/strategy/fedavg_test.py index 947736f4a571..e62eaa5c5832 100644 --- a/src/py/flwr/server/strategy/fedavg_test.py +++ b/src/py/flwr/server/strategy/fedavg_test.py @@ -15,6 +15,16 @@ """FedAvg tests.""" +from typing import List, Tuple, Union +from unittest.mock import MagicMock + +import numpy as np +from numpy.testing import assert_allclose + +from flwr.common import Code, FitRes, Status, parameters_to_ndarrays +from flwr.common.parameter import ndarrays_to_parameters +from flwr.server.client_proxy import ClientProxy + from .fedavg import FedAvg @@ -120,3 +130,51 @@ def test_fedavg_num_evaluation_clients_minimum() -> None: # Assert assert expected == actual + + +def test_inplace_aggregate_fit_equivalence() -> None: + """Test aggregate_fit equivalence between FedAvg and its inplace version.""" + # Prepare + weights0_0 = np.random.randn(100, 64) + weights0_1 = np.random.randn(314, 628, 3) + weights1_0 = np.random.randn(100, 64) + weights1_1 = np.random.randn(314, 628, 3) + + results: List[Tuple[ClientProxy, FitRes]] = [ + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights0_0, weights0_1]), + num_examples=1, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights1_0, weights1_1]), + num_examples=5, + metrics={}, + ), + ), + ] + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + + fedavg_reference = FedAvg(inplace=False) + fedavg_inplace = FedAvg() + + # Execute + reference, _ = fedavg_reference.aggregate_fit(1, results, failures) + assert reference + inplace, _ = fedavg_inplace.aggregate_fit(1, results, failures) + assert inplace + + # Convert to NumPy to check similarity + reference_np = parameters_to_ndarrays(reference) + inplace_np = parameters_to_ndarrays(inplace) + + # Assert + for ref, inp in zip(reference_np, inplace_np): + assert_allclose(ref, inp)