Skip to content

Commit

Permalink
Add in-place FedAvg (#2293)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jan 5, 2024
1 parent 2b4297d commit 625ae83
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 8 deletions.
28 changes: 27 additions & 1 deletion src/py/flwr/server/strategy/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

import numpy as np

from flwr.common import NDArray, NDArrays
from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays
from flwr.server.client_proxy import ClientProxy


def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
Expand All @@ -41,6 +42,31 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
return weights_prime


def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
"""Compute in-place weighted average."""
# Count total examples
num_examples_total = sum([fit_res.num_examples for _, fit_res in results])

# Compute scaling factors for each result
scaling_factors = [
fit_res.num_examples / num_examples_total for _, fit_res in results
]

# Let's do in-place aggregation
# Get first result, then add up each other
params = [
scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters)
]
for i, (_, fit_res) in enumerate(results[1:]):
res = (
scaling_factors[i + 1] * x
for x in parameters_to_ndarrays(fit_res.parameters)
)
params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)]

return params


def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays:
"""Compute median."""
# Create a list of weights and ignore the number of examples
Expand Down
22 changes: 15 additions & 7 deletions src/py/flwr/server/strategy/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate, weighted_loss_avg
from .aggregate import aggregate, aggregate_inplace, weighted_loss_avg
from .strategy import Strategy

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
inplace: bool = True,
) -> None:
super().__init__()

Expand All @@ -128,6 +129,7 @@ def __init__(
self.initial_parameters = initial_parameters
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
self.inplace = inplace

def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
Expand Down Expand Up @@ -226,12 +228,18 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Convert results
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
if self.inplace:
# Does in-place weighted average of results
aggregated_ndarrays = aggregate_inplace(results)
else:
# Convert results
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
aggregated_ndarrays = aggregate(weights_results)

parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays)

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
Expand Down
58 changes: 58 additions & 0 deletions src/py/flwr/server/strategy/fedavg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@
"""FedAvg tests."""


from typing import List, Tuple, Union
from unittest.mock import MagicMock

import numpy as np
from numpy.testing import assert_allclose

from flwr.common import Code, FitRes, Status, parameters_to_ndarrays
from flwr.common.parameter import ndarrays_to_parameters
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg


Expand Down Expand Up @@ -120,3 +130,51 @@ def test_fedavg_num_evaluation_clients_minimum() -> None:

# Assert
assert expected == actual


def test_inplace_aggregate_fit_equivalence() -> None:
"""Test aggregate_fit equivalence between FedAvg and its inplace version."""
# Prepare
weights0_0 = np.random.randn(100, 64)
weights0_1 = np.random.randn(314, 628, 3)
weights1_0 = np.random.randn(100, 64)
weights1_1 = np.random.randn(314, 628, 3)

results: List[Tuple[ClientProxy, FitRes]] = [
(
MagicMock(),
FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=ndarrays_to_parameters([weights0_0, weights0_1]),
num_examples=1,
metrics={},
),
),
(
MagicMock(),
FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=ndarrays_to_parameters([weights1_0, weights1_1]),
num_examples=5,
metrics={},
),
),
]
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []

fedavg_reference = FedAvg(inplace=False)
fedavg_inplace = FedAvg()

# Execute
reference, _ = fedavg_reference.aggregate_fit(1, results, failures)
assert reference
inplace, _ = fedavg_inplace.aggregate_fit(1, results, failures)
assert inplace

# Convert to NumPy to check similarity
reference_np = parameters_to_ndarrays(reference)
inplace_np = parameters_to_ndarrays(inplace)

# Assert
for ref, inp in zip(reference_np, inplace_np):
assert_allclose(ref, inp)

0 comments on commit 625ae83

Please sign in to comment.