Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MeaMed #1816

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
11 changes: 11 additions & 0 deletions doc/source/ref-api-flwr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ server.strategy.Krum
.. automethod:: __init__


.. _flwr-server-strategy-MeaMed-apiref:

server.strategy.MeaMed
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: flwr.server.strategy.MeaMed
:members:

.. automethod:: __init__


.. _flwr-server-strategy-FedXgbNnAvg-apiref:

server.strategy.FedXgbNnAvg
Expand Down
4 changes: 4 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@

Flower received many improvements under the hood, too many to list here.

- **Add new** `MeaMed` **strategy** ([#1816](https://github.com/adap/flower/pull/1816))

The new `MeaMed` strategy implements the Byzantine-tolerant "mean around median" aggregation rule proposed by [Xie et al., 2018](https://arxiv.org/pdf/1802.10116.pdf).

### Incompatible changes

- **Remove support for Python 3.7** ([#2280](https://github.com/adap/flower/pull/2280), [#2299](https://github.com/adap/flower/pull/2299), [#2304](https://github.com/adap/flower/pull/2304), [#2306](https://github.com/adap/flower/pull/2306), [#2355](https://github.com/adap/flower/pull/2355), [#2356](https://github.com/adap/flower/pull/2356))
Expand Down
25 changes: 25 additions & 0 deletions src/py/flwr/server/strategy/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,31 @@ def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays:
return median_w


def aggregate_meamed(results: List[Tuple[NDArrays, int]], to_exclude: int) -> NDArrays:
    """Compute the mean around the median.

    The median model returned by ``aggregate_median`` is used as a reference
    point. The ``to_exclude`` client models that are farthest from it (as
    measured by ``_compute_distances``) are dropped and the remaining ones are
    passed to ``aggregate``.

    Parameters
    ----------
    results : List[Tuple[NDArrays, int]]
        One ``(weights, num_examples)`` pair per client.
    to_exclude : int
        Number of client models to drop before averaging. Must satisfy
        ``0 <= to_exclude < len(results)`` so that at least one model is kept.

    Returns
    -------
    NDArrays
        Aggregate of the ``len(results) - to_exclude`` models closest to the
        median model.

    Raises
    ------
    ValueError
        If ``results`` is empty, or if ``to_exclude`` is negative or would
        leave no model to aggregate.
    """
    if not results:
        raise ValueError("results must not be empty")
    num_models = len(results)
    if not 0 <= to_exclude < num_models:
        raise ValueError(
            f"to_exclude must be in [0, {num_models}), got {to_exclude}"
        )

    # Number of models to aggregate
    to_keep = num_models - to_exclude

    # Compute median vector model
    median_model = aggregate_median(results)

    # Row 0 of the distance matrix holds the distances between the median model
    # and every stacked model. Its first entry is the distance of the median to
    # itself, so it is removed *before* sorting: this keeps the remaining
    # positions aligned with ``results`` and avoids relying on the median
    # always sorting first (it may tie with a client model at distance 0).
    distance_matrix = _compute_distances([median_model] + [w for w, _ in results])
    distances_to_median = np.asarray(distance_matrix[0])[1:]

    # A stable sort makes the selection deterministic when distances tie
    closest_indices = np.argsort(distances_to_median, kind="stable")[:to_keep]
    closest_w = [results[i] for i in closest_indices.tolist()]

    # Compute the average of the to_keep closest parameters
    return aggregate(closest_w)


def aggregate_krum(
results: List[Tuple[NDArrays, int]], num_malicious: int, to_keep: int
) -> NDArrays:
Expand Down
26 changes: 25 additions & 1 deletion src/py/flwr/server/strategy/aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from .aggregate import aggregate, weighted_loss_avg
from .aggregate import aggregate, aggregate_meamed, weighted_loss_avg


def test_aggregate() -> None:
Expand Down Expand Up @@ -64,3 +64,27 @@ def test_weighted_loss_avg_multiple_values() -> None:

# Assert
assert expected == actual


def test_aggregate_meamed() -> None:
    """Check that MeaMed averages only the models closest to the median.

    Five equally weighted client models are aggregated while excluding two,
    and the result is compared against the expected model.
    """
    # Prepare
    client_models = [
        [np.array([1, 6, 11]), np.array([16, 21, 26])],
        [np.array([2, 7, 12]), np.array([17, 22, 27])],
        [np.array([3, 8, 13]), np.array([18, 23, 28])],
        [np.array([4, 9, 14]), np.array([19, 24, 29])],
        [np.array([5, 10, 15]), np.array([20, 25, 30])],
    ]
    # Every client reports a single training example
    results = [(model, 1) for model in client_models]
    expected = [np.array([3.0, 8.0, 13.0]), np.array([18.0, 23.0, 28.0])]

    # Execute
    actual = aggregate_meamed(results, 2)

    # Assert
    np.testing.assert_equal(expected, actual)
146 changes: 146 additions & 0 deletions src/py/flwr/server/strategy/meamed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generalized Byzantine-tolerant SGD.

[Xie et al., 2018].

Paper: https://arxiv.org/pdf/1802.10116.pdf
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate_meamed
from .fedavg import FedAvg


class MeaMed(FedAvg):
    """Configurable MeaMed strategy implementation."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        num_clients_to_exclude: int = 0,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """MeaMed strategy.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        num_clients_to_exclude : int, optional
            Number of clients to exclude before averaging. Must not be
            negative. Defaults to 0, in that case it is equivalent to FedAvg.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters.
        fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
            Function used to aggregate the metrics reported by clients after
            training. Defaults to None.
        evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
            Metrics aggregation function forwarded to ``FedAvg`` for the
            evaluation results. Defaults to None.

        Raises
        ------
        ValueError
            If ``num_clients_to_exclude`` is negative.
        """
        # Fail early: a negative count has no meaning for the aggregation rule
        if num_clients_to_exclude < 0:
            raise ValueError(
                "num_clients_to_exclude must be non-negative, "
                f"got {num_clients_to_exclude}"
            )
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.num_clients_to_exclude = num_clients_to_exclude

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = f"MeaMed(accept_failures={self.accept_failures})"
        return rep

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using MeaMed.

        Parameters
        ----------
        server_round : int
            Current round; only used to emit the missing-metrics warning once.
        results : List[Tuple[ClientProxy, FitRes]]
            Successful fit results, one per client.
        failures : List[Union[Tuple[ClientProxy, FitRes], BaseException]]
            Failures collected during the round.

        Returns
        -------
        Tuple[Optional[Parameters], Dict[str, Scalar]]
            The aggregated parameters and aggregated metrics, or
            ``(None, {})`` when the round cannot be aggregated.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        # Excluding as many clients as (or more than) reported would leave no
        # model to average, so skip the round instead of aggregating nothing
        if len(results) <= self.num_clients_to_exclude:
            log(
                WARNING,
                "MeaMed received %s results but num_clients_to_exclude is %s; "
                "skipping aggregation",
                len(results),
                self.num_clients_to_exclude,
            )
            return None, {}

        # Convert results
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(
            aggregate_meamed(weights_results, self.num_clients_to_exclude)
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated
Loading