From a04455b7dd1abd39a09b2ea37e1bc27df0785659 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 29 Jan 2024 16:53:47 +0000 Subject: [PATCH 1/9] del SA folder in client --- .../secure_aggregation/__init__.py | 10 +- .../secaggplus_middleware.py} | 222 +++++++++--------- .../secaggplus_middleware_test.py} | 138 ++++++----- .../flwr/client/secure_aggregation/handler.py | 43 ---- 4 files changed, 199 insertions(+), 214 deletions(-) rename src/py/flwr/client/{ => middleware}/secure_aggregation/__init__.py (76%) rename src/py/flwr/client/{secure_aggregation/secaggplus_handler.py => middleware/secure_aggregation/secaggplus_middleware.py} (73%) rename src/py/flwr/client/{secure_aggregation/secaggplus_handler_test.py => middleware/secure_aggregation/secaggplus_middleware_test.py} (71%) delete mode 100644 src/py/flwr/client/secure_aggregation/handler.py diff --git a/src/py/flwr/client/secure_aggregation/__init__.py b/src/py/flwr/client/middleware/secure_aggregation/__init__.py similarity index 76% rename from src/py/flwr/client/secure_aggregation/__init__.py rename to src/py/flwr/client/middleware/secure_aggregation/__init__.py index 37c816a390de..b9da30e42b34 100644 --- a/src/py/flwr/client/secure_aggregation/__init__.py +++ b/src/py/flwr/client/middleware/secure_aggregation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,8 @@ # limitations under the License. # ============================================================================== """Secure Aggregation handlers.""" - - -from .handler import SecureAggregationHandler -from .secaggplus_handler import SecAggPlusHandler +from .secaggplus_middleware import secaggplus_middleware __all__ = [ - "SecAggPlusHandler", - "SecureAggregationHandler", + "secaggplus_middleware", ] diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py similarity index 73% rename from src/py/flwr/client/secure_aggregation/secaggplus_handler.py rename to src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index 4b74c1ace3de..2ac3dce90d73 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,18 +17,18 @@ import os from dataclasses import dataclass, field -from logging import ERROR, INFO, WARNING -from typing import Any, Dict, List, Optional, Tuple, Union, cast - -from flwr.client.client import Client -from flwr.client.numpy_client import NumPyClient -from flwr.common import ( - bytes_to_ndarray, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from logging import INFO, WARNING +from typing import Any, Callable, Dict, List, Tuple, cast + +from flwr.client.typing import FlowerCallable +from flwr.common import ndarray_to_bytes, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_FIT +from flwr.common.context import Context from flwr.common.logger import log +from flwr.common.message import Message, Metadata +from flwr.common.recordset import RecordSet from flwr.common.secure_aggregation.crypto.shamir import create_shares from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_private_key, @@ -56,7 +56,6 @@ KEY_DESTINATION_LIST, KEY_MASKED_PARAMETERS, KEY_MOD_RANGE, - KEY_PARAMETERS, KEY_PUBLIC_KEY_1, KEY_PUBLIC_KEY_2, KEY_SAMPLE_NUMBER, @@ -68,6 +67,8 @@ KEY_STAGE, KEY_TARGET_RANGE, KEY_THRESHOLD, + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, STAGE_COLLECT_MASKED_INPUT, STAGE_SETUP, STAGE_SHARE_KEYS, @@ -79,9 +80,7 @@ share_keys_plaintext_concat, share_keys_plaintext_separate, ) -from flwr.common.typing import FitIns, Value - -from .handler import SecureAggregationHandler +from flwr.common.typing import ConfigsRecordValues, FitRes @dataclass @@ -89,6 +88,8 @@ class SecAggPlusState: """State of the SecAgg+ protocol.""" + current_stage: str = STAGE_UNMASK + sid: int = 0 sample_num: int = 0 share_num: int = 0 @@ -112,70 +113,85 @@ class SecAggPlusState: ss2_dict: Dict[int, bytes] = field(default_factory=dict) public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) - client: Optional[Union[Client, NumPyClient]] = None +def _get_fit_fn( + msg: Message, call_next: FlowerCallable, ctxt: Context +) -> Callable[[], FitRes]: + """Get the fit function.""" -class SecAggPlusHandler(SecureAggregationHandler): - """Message handler for the SecAgg+ protocol.""" + def fit() -> FitRes: + out_msg = call_next(msg, ctxt) + return compat.recordset_to_fitres(out_msg.message, keep_input=False) - _shared_state = SecAggPlusState() - _current_stage = STAGE_UNMASK + return fit - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming message and return results, following the SecAgg+ protocol. - - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the SecAgg+ protocol. - """ - # Check if self is a client - if not isinstance(self, (Client, NumPyClient)): - raise TypeError( - "The subclass of SecAggPlusHandler must be " - "the subclass of Client or NumPyClient." - ) - # Check the validity of the next stage - check_stage(self._current_stage, named_values) +def secaggplus_middleware( + msg: Message, + call_next: FlowerCallable, + ctxt: Context, +) -> Dict[str, ConfigsRecordValues]: + """Handle incoming message and return results, following the SecAgg+ protocol.""" + # Ignore non-fit messages + if msg.metadata.task_type != TASK_TYPE_FIT: + return call_next(msg, ctxt) - # Update the current stage - self._current_stage = cast(str, named_values.pop(KEY_STAGE)) + # Retrieve state + if RECORD_KEY_STATE not in ctxt.state.configs: + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({})) + state = SecAggPlusState(**ctxt.state.get_configs(RECORD_KEY_STATE).data) - # Check the validity of the `named_values` based on the current stage - check_named_values(self._current_stage, named_values) + # Retrieve configs + configs = msg.message.get_configs(RECORD_KEY_CONFIGS).data - # Execute - if self._current_stage == STAGE_SETUP: - self._shared_state = SecAggPlusState(client=self) - return _setup(self._shared_state, named_values) - if self._current_stage == STAGE_SHARE_KEYS: - return _share_keys(self._shared_state, named_values) - if self._current_stage == STAGE_COLLECT_MASKED_INPUT: - return _collect_masked_input(self._shared_state, named_values) - if self._current_stage == STAGE_UNMASK: - return _unmask(self._shared_state, named_values) - raise ValueError(f"Unknown secagg stage: {self._current_stage}") + # Check the validity of the next stage + check_stage(state.current_stage, configs) + + # Update the current stage + state.current_stage = cast(str, configs.pop(KEY_STAGE)) + + # Check the validity of the `named_values` based on the current stage + check_configs(state.current_stage, configs) + + # Execute + if state.current_stage == STAGE_SETUP: + res = _setup(state, configs) + elif state.current_stage == STAGE_SHARE_KEYS: + res = _share_keys(state, configs) + elif state.current_stage == STAGE_COLLECT_MASKED_INPUT: + fit = _get_fit_fn(msg, call_next, ctxt) + res = _collect_masked_input(state, configs, fit) + elif state.current_stage == STAGE_UNMASK: + res = _unmask(state, configs) + else: + raise ValueError(f"Unknown secagg stage: {state.current_stage}") + + # Save state + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(vars(state))) + + # Return message + return Message( + metadata=Metadata( + run_id="", + task_id="", + group_id="", + ttl="", + task_type=TASK_TYPE_FIT, + ), + message=RecordSet(configs={RECORD_KEY_CONFIGS, ConfigsRecord(res, False)}), + ) -def check_stage(current_stage: str, named_values: Dict[str, Value]) -> None: +def check_stage(current_stage: str, configs: Dict[str, ConfigsRecordValues]) -> None: """Check the validity of the next stage.""" # Check the existence of KEY_STAGE - if KEY_STAGE not in named_values: + if KEY_STAGE not in configs: raise KeyError( f"The required key '{KEY_STAGE}' is missing from the input `named_values`." ) # Check the value type of the KEY_STAGE - next_stage = named_values[KEY_STAGE] + next_stage = configs[KEY_STAGE] if not isinstance(next_stage, str): raise TypeError( f"The value for the key '{KEY_STAGE}' must be of type {str}, " @@ -198,8 +214,8 @@ def check_stage(current_stage: str, named_values: Dict[str, Value]) -> None: # pylint: disable-next=too-many-branches -def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: - """Check the validity of the input `named_values`.""" +def check_configs(stage: str, configs: Dict[str, ConfigsRecordValues]) -> None: + """Check the validity of the configs.""" # Check `named_values` for the setup stage if stage == STAGE_SETUP: key_type_pairs = [ @@ -212,7 +228,7 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: (KEY_MOD_RANGE, int), ] for key, expected_type in key_type_pairs: - if key not in named_values: + if key not in configs: raise KeyError( f"Stage {STAGE_SETUP}: the required key '{key}' is " "missing from the input `named_values`." @@ -220,14 +236,14 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: # Bool is a subclass of int in Python, # so `isinstance(v, int)` will return True even if v is a boolean. # pylint: disable-next=unidiomatic-typecheck - if type(named_values[key]) is not expected_type: + if type(configs[key]) is not expected_type: raise TypeError( f"Stage {STAGE_SETUP}: The value for the key '{key}' " f"must be of type {expected_type}, " - f"but got {type(named_values[key])} instead." + f"but got {type(configs[key])} instead." ) elif stage == STAGE_SHARE_KEYS: - for key, value in named_values.items(): + for key, value in configs.items(): if ( not isinstance(value, list) or len(value) != 2 @@ -242,18 +258,17 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: key_type_pairs = [ (KEY_CIPHERTEXT_LIST, bytes), (KEY_SOURCE_LIST, int), - (KEY_PARAMETERS, bytes), ] for key, expected_type in key_type_pairs: - if key not in named_values: + if key not in configs: raise KeyError( f"Stage {STAGE_COLLECT_MASKED_INPUT}: " f"the required key '{key}' is " "missing from the input `named_values`." ) - if not isinstance(named_values[key], list) or any( + if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], named_values[key]) + for elm in cast(List[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -268,15 +283,15 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: (KEY_DEAD_SECURE_ID_LIST, int), ] for key, expected_type in key_type_pairs: - if key not in named_values: + if key not in configs: raise KeyError( f"Stage {STAGE_UNMASK}: " f"the required key '{key}' is " "missing from the input `named_values`." ) - if not isinstance(named_values[key], list) or any( + if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], named_values[key]) + for elm in cast(List[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -289,9 +304,11 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: raise ValueError(f"Unknown secagg stage: {stage}") -def _setup(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: +def _setup( + state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues] +) -> Dict[str, ConfigsRecordValues]: # Assigning parameter values to object fields - sec_agg_param_dict = named_values + sec_agg_param_dict = configs state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER]) state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID]) log(INFO, "Client %d: starting stage 0...", state.sid) @@ -324,8 +341,8 @@ def _setup(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, # pylint: disable-next=too-many-locals def _share_keys( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: + state: SecAggPlusState, named_values: Dict[str, ConfigsRecordValues] +) -> Dict[str, ConfigsRecordValues]: named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], named_values) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} log(INFO, "Client %d: starting stage 1...", state.sid) @@ -333,7 +350,7 @@ def _share_keys( # Check if the size is larger than threshold if len(state.public_keys_dict) < state.threshold: - raise ValueError("Available neighbours number smaller than threshold") + raise Exception("Available neighbours number smaller than threshold") # Check if all public keys are unique pk_list: List[bytes] = [] @@ -341,14 +358,14 @@ def _share_keys( pk_list.append(pk1) pk_list.append(pk2) if len(set(pk_list)) != len(pk_list): - raise ValueError("Some public keys are identical") + raise Exception("Some public keys are identical") # Check if public keys of this client are correct in the dictionary if ( state.public_keys_dict[state.sid][0] != state.pk1 or state.public_keys_dict[state.sid][1] != state.pk2 ): - raise ValueError( + raise Exception( "Own public keys are displayed in dict incorrectly, should not happen!" ) @@ -386,14 +403,16 @@ def _share_keys( # pylint: disable-next=too-many-locals def _collect_masked_input( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: + state: SecAggPlusState, + configs: Dict[str, ConfigsRecordValues], + fit: Callable[[], FitRes], +) -> Dict[str, ConfigsRecordValues]: log(INFO, "Client %d: starting stage 2...", state.sid) available_clients: List[int] = [] - ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) - srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) + ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST]) + srcs = cast(List[int], configs[KEY_SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: - raise ValueError("Not enough available neighbour clients.") + raise Exception("Not enough available neighbour clients.") # Decrypt ciphertexts, verify their sources, and store shares. for src, ciphertext in zip(srcs, ciphertexts): @@ -409,7 +428,7 @@ def _collect_masked_input( f"from {actual_src} instead of {src}." ) if dst != state.sid: - raise ValueError( + ValueError( f"Client {state.sid}: received an encrypted message" f"for Client {dst} from Client {src}." ) @@ -417,18 +436,9 @@ def _collect_masked_input( state.sk1_share_dict[src] = sk1_share # Fit client - parameters_bytes = cast(List[bytes], named_values[KEY_PARAMETERS]) - parameters = [bytes_to_ndarray(w) for w in parameters_bytes] - if isinstance(state.client, Client): - fit_res = state.client.fit( - FitIns(parameters=ndarrays_to_parameters(parameters), config={}) - ) - parameters_factor = fit_res.num_examples - parameters = parameters_to_ndarrays(fit_res.parameters) - elif isinstance(state.client, NumPyClient): - parameters, parameters_factor, _ = state.client.fit(parameters, {}) - else: - log(ERROR, "Client %d: fit function is missing.", state.sid) + fit_res = fit() + parameters_factor = fit_res.num_examples + parameters = parameters_to_ndarrays(fit_res.parameters) # Quantize parameter update (vector) quantized_parameters = quantize( @@ -468,15 +478,17 @@ def _collect_masked_input( } -def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: +def _unmask( + state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues] +) -> Dict[str, ConfigsRecordValues]: log(INFO, "Client %d: starting stage 3...", state.sid) - active_sids = cast(List[int], named_values[KEY_ACTIVE_SECURE_ID_LIST]) - dead_sids = cast(List[int], named_values[KEY_DEAD_SECURE_ID_LIST]) + active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST]) + dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST]) # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: - raise ValueError("Available neighbours number smaller than threshold") + raise Exception("Available neighbours number smaller than threshold") sids, shares = [], [] sids += active_sids diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py similarity index 71% rename from src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py rename to src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py index 9693a46af989..6a50c1c6ffe8 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,16 +16,16 @@ import unittest from itertools import product -from typing import Any, Dict, List, cast +from typing import Callable, Dict, List -from flwr.client import NumPyClient +from flwr.client.middleware import make_ffn +from flwr.common import serde from flwr.common.secure_aggregation.secaggplus_constants import ( KEY_ACTIVE_SECURE_ID_LIST, KEY_CIPHERTEXT_LIST, KEY_CLIPPING_RANGE, KEY_DEAD_SECURE_ID_LIST, KEY_MOD_RANGE, - KEY_PARAMETERS, KEY_SAMPLE_NUMBER, KEY_SECURE_ID, KEY_SHARE_NUMBER, @@ -40,27 +40,47 @@ STAGES, ) from flwr.common.typing import Value +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes + +from .secaggplus_middleware import SecAggPlusState, check_configs + + +def get_test_handler( + state: SecAggPlusState, +) -> Callable[[Dict[str, Value]], Dict[str, Value]]: + """.""" + + def empty_ffn(_: Fwd) -> Bwd: + return Bwd(task_res=TaskRes(), state=WorkloadState(state={})) + + app = make_ffn(empty_ffn, [secaggplus_middleware]) + workload_state = WorkloadState(state={KEY_SECAGGPLUS_STATE: state}) # type: ignore + + def func(named_values: Dict[str, Value]) -> Dict[str, Value]: + bwd = app( + Fwd( + task_ins=TaskIns( + task=Task( + sa=SecureAggregation( + named_values=serde.named_values_to_proto(named_values) + ) + ) + ), + state=workload_state, + ) + ) + return serde.named_values_from_proto(bwd.task_res.task.sa.named_values) -from .secaggplus_handler import SecAggPlusHandler, check_named_values - - -class EmptyFlowerNumPyClient(NumPyClient, SecAggPlusHandler): - """Empty NumPyClient.""" + return func class TestSecAggPlusHandler(unittest.TestCase): """Test the SecAgg+ protocol handler.""" - def test_invalid_handler(self) -> None: - """Test invalid handler.""" - handler = SecAggPlusHandler() - - with self.assertRaises(TypeError): - handler.handle_secure_aggregation({}) - def test_stage_transition(self) -> None: """Test stage transition.""" - handler = EmptyFlowerNumPyClient() + state = SecAggPlusState() + handler = get_test_handler(state) assert STAGES == ( STAGE_SETUP, @@ -88,28 +108,27 @@ def test_stage_transition(self) -> None: # If the next stage is valid, the function should update the current stage # and then raise KeyError or other exceptions when trying to execute SA. for current_stage, next_stage in valid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage + state.current_stage = current_stage with self.assertRaises(KeyError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == next_stage + handler({KEY_STAGE: next_stage}) + + assert state.current_stage == next_stage # Test invalid transitions # If the next stage is invalid, the function should raise ValueError for current_stage, next_stage in invalid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage + state.current_stage = current_stage with self.assertRaises(ValueError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == current_stage + handler({KEY_STAGE: next_stage}) + + assert state.current_stage == current_stage def test_stage_setup_check(self) -> None: """Test content checking for the setup stage.""" - handler = EmptyFlowerNumPyClient() + state = SecAggPlusState() + handler = get_test_handler(state) valid_key_type_pairs = [ (KEY_SAMPLE_NUMBER, int), @@ -136,7 +155,7 @@ def test_stage_setup_check(self) -> None: # Test valid `named_values` try: - check_named_values(STAGE_SETUP, valid_named_values.copy()) + check_configs(STAGE_SETUP, valid_named_values.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") @@ -153,21 +172,22 @@ def test_stage_setup_check(self) -> None: if other_type == value_type: continue invalid_named_values[key] = other_value - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK + + state.current_stage = STAGE_UNMASK with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) + handler(invalid_named_values.copy()) # Test missing key invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK + + state.current_stage = STAGE_UNMASK with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values.copy()) + handler(invalid_named_values.copy()) def test_stage_share_keys_check(self) -> None: """Test content checking for the share keys stage.""" - handler = EmptyFlowerNumPyClient() + state = SecAggPlusState() + handler = get_test_handler(state) valid_named_values: Dict[str, Value] = { "1": [b"public key 1", b"public key 2"], @@ -177,7 +197,7 @@ def test_stage_share_keys_check(self) -> None: # Test valid `named_values` try: - check_named_values(STAGE_SHARE_KEYS, valid_named_values.copy()) + check_configs(STAGE_SHARE_KEYS, valid_named_values.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") @@ -196,24 +216,23 @@ def test_stage_share_keys_check(self) -> None: invalid_named_values = valid_named_values.copy() invalid_named_values["1"] = value - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SETUP + state.current_stage = STAGE_SETUP with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) + handler(invalid_named_values.copy()) def test_stage_collect_masked_input_check(self) -> None: """Test content checking for the collect masked input stage.""" - handler = EmptyFlowerNumPyClient() + state = SecAggPlusState() + handler = get_test_handler(state) valid_named_values: Dict[str, Value] = { KEY_CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], KEY_SOURCE_LIST: [32, 51324, 32324123, -3], - KEY_PARAMETERS: [b"params1", b"params2"], } # Test valid `named_values` try: - check_named_values(STAGE_COLLECT_MASKED_INPUT, valid_named_values.copy()) + check_configs(STAGE_COLLECT_MASKED_INPUT, valid_named_values.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") @@ -228,25 +247,26 @@ def test_stage_collect_masked_input_check(self) -> None: continue invalid_named_values = valid_named_values.copy() invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS + + state.current_stage = STAGE_SHARE_KEYS with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_named_values) # Test wrong value type for the key for key in valid_named_values: if key == KEY_STAGE: continue invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(3.1415926) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS + invalid_named_values[key] = [3.1415926] + + state.current_stage = STAGE_SHARE_KEYS with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_named_values) def test_stage_unmask_check(self) -> None: """Test content checking for the unmasking stage.""" - handler = EmptyFlowerNumPyClient() + state = SecAggPlusState() + handler = get_test_handler(state) valid_named_values: Dict[str, Value] = { KEY_ACTIVE_SECURE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], @@ -255,7 +275,7 @@ def test_stage_unmask_check(self) -> None: # Test valid `named_values` try: - check_named_values(STAGE_UNMASK, valid_named_values.copy()) + check_configs(STAGE_UNMASK, valid_named_values.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") @@ -270,18 +290,18 @@ def test_stage_unmask_check(self) -> None: continue invalid_named_values = valid_named_values.copy() invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT + + state.current_stage = STAGE_COLLECT_MASKED_INPUT with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_named_values) # Test wrong value type for the key for key in valid_named_values: if key == KEY_STAGE: continue invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(True) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT + invalid_named_values[key] = [True, False, True, False] + + state.current_stage = STAGE_COLLECT_MASKED_INPUT with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_named_values) diff --git a/src/py/flwr/client/secure_aggregation/handler.py b/src/py/flwr/client/secure_aggregation/handler.py deleted file mode 100644 index 487ed842c93f..000000000000 --- a/src/py/flwr/client/secure_aggregation/handler.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Message Handler for Secure Aggregation (abstract base class).""" - - -from abc import ABC, abstractmethod -from typing import Dict - -from flwr.common.typing import Value - - -class SecureAggregationHandler(ABC): - """Abstract base class for secure aggregation message handlers.""" - - @abstractmethod - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming Secure Aggregation message and return results. - - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the Secure Aggregation protocol. - """ From 71fefb60bf4d7efaaa12b1d15f44d5005f3c02cb Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 29 Jan 2024 16:55:00 +0000 Subject: [PATCH 2/9] update secaggplus consts --- src/py/flwr/common/secure_aggregation/secaggplus_constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py index 8dd21a6016f1..9a2bf26e98e8 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py @@ -14,6 +14,8 @@ # ============================================================================== """Constants for the SecAgg/SecAgg+ protocol.""" +RECORD_KEY_STATE = "secaggplus_state" +RECORD_KEY_CONFIGS = "secaggplus_configs" # Names of stages STAGE_SETUP = "setup" From 2a0b0856deb0abe5a778bf43bb98ee6a66e07258 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 30 Jan 2024 16:49:45 +0000 Subject: [PATCH 3/9] sa mw --- src/py/flwr/client/middleware/__init__.py | 2 + .../secaggplus_middleware.py | 69 +++++-- .../secaggplus_middleware_test.py | 186 ++++++++++-------- 3 files changed, 153 insertions(+), 104 deletions(-) diff --git a/src/py/flwr/client/middleware/__init__.py b/src/py/flwr/client/middleware/__init__.py index 58b31296fbbe..2cab6e61899c 100644 --- a/src/py/flwr/client/middleware/__init__.py +++ b/src/py/flwr/client/middleware/__init__.py @@ -15,8 +15,10 @@ """Middleware layers.""" +from .secure_aggregation.secaggplus_middleware import secaggplus_middleware from .utils import make_ffn __all__ = [ "make_ffn", + "secaggplus_middleware", ] diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index 2ac3dce90d73..e6376abb90d8 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -113,6 +113,40 @@ class SecAggPlusState: ss2_dict: Dict[int, bytes] = field(default_factory=dict) public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) + def __init__(self, **kwargs: ConfigsRecordValues) -> None: + for k, v in kwargs.items(): + if k.endswith(":V"): + continue + new_v: Any = v + if k.endswith(":K"): + keys = cast(List[int], v) + values = cast(List[bytes], kwargs[k[:-2] + ":V"]) + if len(values) > len(keys): + updated_values = [ + tuple(values[i : i + 2]) for i in range(0, len(values), 2) + ] + new_v = dict(zip(keys, updated_values)) + else: + new_v = dict(zip(keys, values)) + self.__setattr__(k, new_v) + + def to_dict(self) -> Dict[str, ConfigsRecordValues]: + """Convert the state to a dictionary.""" + ret = vars(self) + for k in ret.keys(): + if isinstance(ret[k], dict): + # Replace dict with two lists + v = cast(Dict[str, Any], ret.pop(k)) + ret[f"{k}:K"] = list(v.keys()) + if k == "public_keys_dict": + v_list: List[bytes] = [] + for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()): + v_list.extend(b1_b2) + ret[f"{k}:V"] = v_list + else: + ret[f"{k}:V"] = list(v.values()) + return ret + def _get_fit_fn( msg: Message, call_next: FlowerCallable, ctxt: Context @@ -128,9 +162,9 @@ def fit() -> FitRes: def secaggplus_middleware( msg: Message, - call_next: FlowerCallable, ctxt: Context, -) -> Dict[str, ConfigsRecordValues]: + call_next: FlowerCallable, +) -> Message: """Handle incoming message and return results, following the SecAgg+ protocol.""" # Ignore non-fit messages if msg.metadata.task_type != TASK_TYPE_FIT: @@ -139,7 +173,8 @@ def secaggplus_middleware( # Retrieve state if RECORD_KEY_STATE not in ctxt.state.configs: ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({})) - state = SecAggPlusState(**ctxt.state.get_configs(RECORD_KEY_STATE).data) + state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data + state = SecAggPlusState(**state_dict) # Retrieve configs configs = msg.message.get_configs(RECORD_KEY_CONFIGS).data @@ -150,7 +185,7 @@ def secaggplus_middleware( # Update the current stage state.current_stage = cast(str, configs.pop(KEY_STAGE)) - # Check the validity of the `named_values` based on the current stage + # Check the validity of the configs based on the current stage check_configs(state.current_stage, configs) # Execute @@ -171,14 +206,8 @@ def secaggplus_middleware( # Return message return Message( - metadata=Metadata( - run_id="", - task_id="", - group_id="", - ttl="", - task_type=TASK_TYPE_FIT, - ), - message=RecordSet(configs={RECORD_KEY_CONFIGS, ConfigsRecord(res, False)}), + metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}), ) @@ -341,16 +370,16 @@ def _setup( # pylint: disable-next=too-many-locals def _share_keys( - state: SecAggPlusState, named_values: Dict[str, ConfigsRecordValues] + state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues] ) -> Dict[str, ConfigsRecordValues]: - named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], named_values) + named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} log(INFO, "Client %d: starting stage 1...", state.sid) state.public_keys_dict = key_dict # Check if the size is larger than threshold if len(state.public_keys_dict) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique pk_list: List[bytes] = [] @@ -358,14 +387,14 @@ def _share_keys( pk_list.append(pk1) pk_list.append(pk2) if len(set(pk_list)) != len(pk_list): - raise Exception("Some public keys are identical") + raise ValueError("Some public keys are identical") # Check if public keys of this client are correct in the dictionary if ( state.public_keys_dict[state.sid][0] != state.pk1 or state.public_keys_dict[state.sid][1] != state.pk2 ): - raise Exception( + raise ValueError( "Own public keys are displayed in dict incorrectly, should not happen!" ) @@ -412,7 +441,7 @@ def _collect_masked_input( ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST]) srcs = cast(List[int], configs[KEY_SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: - raise Exception("Not enough available neighbour clients.") + raise ValueError("Not enough available neighbour clients.") # Decrypt ciphertexts, verify their sources, and store shares. for src, ciphertext in zip(srcs, ciphertexts): @@ -428,7 +457,7 @@ def _collect_masked_input( f"from {actual_src} instead of {src}." ) if dst != state.sid: - ValueError( + raise ValueError( f"Client {state.sid}: received an encrypted message" f"for Client {dst} from Client {src}." ) @@ -488,7 +517,7 @@ def _unmask( # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") sids, shares = [], [] sids += active_sids diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py index 6a50c1c6ffe8..8ec52d71cbdd 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py @@ -19,7 +19,11 @@ from typing import Callable, Dict, List from flwr.client.middleware import make_ffn -from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_FIT +from flwr.common.context import Context +from flwr.common.message import Message, Metadata +from flwr.common.recordset import RecordSet from flwr.common.secure_aggregation.secaggplus_constants import ( KEY_ACTIVE_SECURE_ID_LIST, KEY_CIPHERTEXT_LIST, @@ -33,54 +37,68 @@ KEY_STAGE, KEY_TARGET_RANGE, KEY_THRESHOLD, + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, STAGE_COLLECT_MASKED_INPUT, STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_UNMASK, STAGES, ) -from flwr.common.typing import Value -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes +from flwr.common.typing import ConfigsRecordValues -from .secaggplus_middleware import SecAggPlusState, check_configs +from .secaggplus_middleware import SecAggPlusState, check_configs, secaggplus_middleware def get_test_handler( - state: SecAggPlusState, -) -> Callable[[Dict[str, Value]], Dict[str, Value]]: + ctxt: Context, +) -> Callable[[Dict[str, ConfigsRecordValues]], Dict[str, ConfigsRecordValues]]: """.""" - def empty_ffn(_: Fwd) -> Bwd: - return Bwd(task_res=TaskRes(), state=WorkloadState(state={})) + def empty_ffn(_: Message, _2: Context) -> Message: + return Message( + metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + message=RecordSet(), + ) app = make_ffn(empty_ffn, [secaggplus_middleware]) - workload_state = WorkloadState(state={KEY_SECAGGPLUS_STATE: state}) # type: ignore - - def func(named_values: Dict[str, Value]) -> Dict[str, Value]: - bwd = app( - Fwd( - task_ins=TaskIns( - task=Task( - sa=SecureAggregation( - named_values=serde.named_values_to_proto(named_values) - ) - ) - ), - state=workload_state, - ) + + def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]: + in_msg = Message( + metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}), ) - return serde.named_values_from_proto(bwd.task_res.task.sa.named_values) + out_msg = app(in_msg, ctxt) + return out_msg.message.get_configs(RECORD_KEY_CONFIGS).data return func +def _make_ctxt() -> Context: + cfg = ConfigsRecord(SecAggPlusState().to_dict()) + return Context(RecordSet(configs={RECORD_KEY_STATE: cfg})) + + +def _make_set_state_fn( + ctxt: Context, +) -> Callable[[str], None]: + def set_stage(stage: str) -> None: + state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data + state = SecAggPlusState(**state_dict) + state.current_stage = stage + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict())) + + return set_stage + + class TestSecAggPlusHandler(unittest.TestCase): """Test the SecAgg+ protocol handler.""" def test_stage_transition(self) -> None: """Test stage transition.""" - state = SecAggPlusState() - handler = get_test_handler(state) + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) assert STAGES == ( STAGE_SETUP, @@ -108,27 +126,24 @@ def test_stage_transition(self) -> None: # If the next stage is valid, the function should update the current stage # and then raise KeyError or other exceptions when trying to execute SA. for current_stage, next_stage in valid_transitions: - state.current_stage = current_stage + set_stage(current_stage) with self.assertRaises(KeyError): handler({KEY_STAGE: next_stage}) - assert state.current_stage == next_stage - # Test invalid transitions # If the next stage is invalid, the function should raise ValueError for current_stage, next_stage in invalid_transitions: - state.current_stage = current_stage + set_stage(current_stage) with self.assertRaises(ValueError): handler({KEY_STAGE: next_stage}) - assert state.current_stage == current_stage - def test_stage_setup_check(self) -> None: """Test content checking for the setup stage.""" - state = SecAggPlusState() - handler = get_test_handler(state) + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) valid_key_type_pairs = [ (KEY_SAMPLE_NUMBER, int), @@ -140,7 +155,7 @@ def test_stage_setup_check(self) -> None: (KEY_MOD_RANGE, int), ] - type_to_test_value: Dict[type, Value] = { + type_to_test_value: Dict[type, ConfigsRecordValues] = { int: 10, bool: True, float: 1.0, @@ -148,48 +163,49 @@ def test_stage_setup_check(self) -> None: bytes: b"test", } - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { key: type_to_test_value[value_type] for key, value_type in valid_key_type_pairs } # Test valid `named_values` try: - check_configs(STAGE_SETUP, valid_named_values.copy()) + check_configs(STAGE_SETUP, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SETUP + valid_configs[KEY_STAGE] = STAGE_SETUP # Test invalid `named_values` for key, value_type in valid_key_type_pairs: - invalid_named_values = valid_named_values.copy() + invalid_configs = valid_configs.copy() # Test wrong value type for the key for other_type, other_value in type_to_test_value.items(): if other_type == value_type: continue - invalid_named_values[key] = other_value + invalid_configs[key] = other_value - state.current_stage = STAGE_UNMASK + set_stage(STAGE_UNMASK) with self.assertRaises(TypeError): - handler(invalid_named_values.copy()) + handler(invalid_configs.copy()) # Test missing key - invalid_named_values.pop(key) + invalid_configs.pop(key) - state.current_stage = STAGE_UNMASK + set_stage(STAGE_UNMASK) with self.assertRaises(KeyError): - handler(invalid_named_values.copy()) + handler(invalid_configs.copy()) def test_stage_share_keys_check(self) -> None: """Test content checking for the share keys stage.""" - state = SecAggPlusState() - handler = get_test_handler(state) + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { "1": [b"public key 1", b"public key 2"], "2": [b"public key 1", b"public key 2"], "3": [b"public key 1", b"public key 2"], @@ -197,111 +213,113 @@ def test_stage_share_keys_check(self) -> None: # Test valid `named_values` try: - check_configs(STAGE_SHARE_KEYS, valid_named_values.copy()) + check_configs(STAGE_SHARE_KEYS, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SHARE_KEYS + valid_configs[KEY_STAGE] = STAGE_SHARE_KEYS # Test invalid `named_values` - invalid_values: List[Value] = [ + invalid_values: List[ConfigsRecordValues] = [ b"public key 1", [b"public key 1"], [b"public key 1", b"public key 2", b"public key 3"], ] for value in invalid_values: - invalid_named_values = valid_named_values.copy() - invalid_named_values["1"] = value + invalid_configs = valid_configs.copy() + invalid_configs["1"] = value - state.current_stage = STAGE_SETUP + set_stage(STAGE_SETUP) with self.assertRaises(TypeError): - handler(invalid_named_values.copy()) + handler(invalid_configs.copy()) def test_stage_collect_masked_input_check(self) -> None: """Test content checking for the collect masked input stage.""" - state = SecAggPlusState() - handler = get_test_handler(state) + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { KEY_CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], KEY_SOURCE_LIST: [32, 51324, 32324123, -3], } # Test valid `named_values` try: - check_configs(STAGE_COLLECT_MASKED_INPUT, valid_named_values.copy()) + check_configs(STAGE_COLLECT_MASKED_INPUT, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_COLLECT_MASKED_INPUT + valid_configs[KEY_STAGE] = STAGE_COLLECT_MASKED_INPUT # Test invalid `named_values` # Test missing keys - for key in list(valid_named_values.keys()): + for key in list(valid_configs.keys()): if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) - state.current_stage = STAGE_SHARE_KEYS + set_stage(STAGE_SHARE_KEYS) with self.assertRaises(KeyError): - handler(invalid_named_values) + handler(invalid_configs) # Test wrong value type for the key - for key in valid_named_values: + for key in valid_configs: if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - invalid_named_values[key] = [3.1415926] + invalid_configs = valid_configs.copy() + invalid_configs[key] = [3.1415926] - state.current_stage = STAGE_SHARE_KEYS + set_stage(STAGE_SHARE_KEYS) with self.assertRaises(TypeError): - handler(invalid_named_values) + handler(invalid_configs) def test_stage_unmask_check(self) -> None: """Test content checking for the unmasking stage.""" - state = SecAggPlusState() - handler = get_test_handler(state) + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { KEY_ACTIVE_SECURE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], KEY_DEAD_SECURE_ID_LIST: [32, 51324, 32324123, -3], } # Test valid `named_values` try: - check_configs(STAGE_UNMASK, valid_named_values.copy()) + check_configs(STAGE_UNMASK, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_UNMASK + valid_configs[KEY_STAGE] = STAGE_UNMASK # Test invalid `named_values` # Test missing keys - for key in list(valid_named_values.keys()): + for key in list(valid_configs.keys()): if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) - state.current_stage = STAGE_COLLECT_MASKED_INPUT + set_stage(STAGE_COLLECT_MASKED_INPUT) with self.assertRaises(KeyError): - handler(invalid_named_values) + handler(invalid_configs) # Test wrong value type for the key - for key in valid_named_values: + for key in valid_configs: if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - invalid_named_values[key] = [True, False, True, False] + invalid_configs = valid_configs.copy() + invalid_configs[key] = [True, False, True, False] - state.current_stage = STAGE_COLLECT_MASKED_INPUT + set_stage(STAGE_COLLECT_MASKED_INPUT) with self.assertRaises(TypeError): - handler(invalid_named_values) + handler(invalid_configs) From af85eb6bd97ef5b1f766b4e8134457c64c0b97fc Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 30 Jan 2024 17:12:39 +0000 Subject: [PATCH 4/9] update example --- examples/secaggplus-mt/client.py | 33 ++++++++++++++++------- examples/secaggplus-mt/workflows.py | 41 ++++++++++++++++++++--------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/examples/secaggplus-mt/client.py b/examples/secaggplus-mt/client.py index f0f1348ee378..a0ed9e7181c7 100644 --- a/examples/secaggplus-mt/client.py +++ b/examples/secaggplus-mt/client.py @@ -5,11 +5,11 @@ import flwr as fl from flwr.common import Status, FitIns, FitRes, Code from flwr.common.parameter import ndarrays_to_parameters -from flwr.client.secure_aggregation import SecAggPlusHandler +from flwr.client.middleware import secaggplus_middleware # Define Flower client with the SecAgg+ protocol -class FlowerClient(fl.client.Client, SecAggPlusHandler): +class FlowerClient(fl.client.Client): def fit(self, fit_ins: FitIns) -> FitRes: ret_vec = [np.ones(3)] ret = FitRes( @@ -19,17 +19,30 @@ def fit(self, fit_ins: FitIns) -> FitRes: metrics={}, ) # Force a significant delay for testing purposes - if self._shared_state.sid == 0: - print(f"Client {self._shared_state.sid} dropped for testing purposes.") + if fit_ins.config["drop"]: + print(f"Client dropped for testing purposes.") time.sleep(4) return ret - print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...") + print(f"Client uploading {ret_vec[0]}...") return ret -# Start Flower client -fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient(), - transport="grpc-rere", +def client_fn(cid: str): + """.""" + return FlowerClient().to_client() + + +# To run this: `flower-client --callable client:flower` +flower = fl.flower.Flower( + client_fn=client_fn, + layers=[secaggplus_middleware], ) + + +if __name__ == "__main__": + # Start Flower client + fl.client.start_client( + server_address="0.0.0.0:9092", + client=FlowerClient(), + transport="grpc-rere", + ) diff --git a/examples/secaggplus-mt/workflows.py b/examples/secaggplus-mt/workflows.py index 3117e308a498..b98de883b8f7 100644 --- a/examples/secaggplus-mt/workflows.py +++ b/examples/secaggplus-mt/workflows.py @@ -1,6 +1,6 @@ import random from logging import WARNING -from typing import Callable, Dict, Generator, List +from typing import Callable, Dict, Generator, List, Optional import numpy as np @@ -36,7 +36,6 @@ KEY_DESTINATION_LIST, KEY_MASKED_PARAMETERS, KEY_MOD_RANGE, - KEY_PARAMETERS, KEY_PUBLIC_KEY_1, KEY_PUBLIC_KEY_2, KEY_SAMPLE_NUMBER, @@ -52,11 +51,16 @@ STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_UNMASK, + RECORD_KEY_CONFIGS, ) from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.common.serde import named_values_from_proto, named_values_to_proto -from flwr.common.typing import Value -from flwr.proto.task_pb2 import SecureAggregation, Task +from flwr.common.typing import ConfigsRecordValues, FitIns, ServerMessage +from flwr.proto.task_pb2 import Task +from flwr.common import serde +from flwr.common.constant import TASK_TYPE_FIT +from flwr.common.recordset import RecordSet +from flwr.common import recordset_compat as compat +from flwr.common.configsrecord import ConfigsRecord LOG_EXPLAIN = True @@ -68,12 +72,23 @@ def get_workflow_factory() -> ( return _wrap_workflow_with_sec_agg -def _wrap_in_task(named_values: Dict[str, Value]) -> Task: - return Task(sa=SecureAggregation(named_values=named_values_to_proto(named_values))) +def _wrap_in_task( + named_values: Dict[str, ConfigsRecordValues], fit_ins: Optional[FitIns] = None +) -> Task: + if fit_ins is not None: + recordset = compat.fitins_to_recordset(fit_ins, keep_input=True) + else: + recordset = RecordSet() + recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values)) + return Task( + task_type=TASK_TYPE_FIT, + recordset=serde.recordset_to_proto(recordset), + ) -def _get_from_task(task: Task) -> Dict[str, Value]: - return named_values_from_proto(task.sa.named_values) +def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]: + recordset = serde.recordset_from_proto(task.recordset) + return recordset.get_configs(RECORD_KEY_CONFIGS).data _secure_aggregation_configuration = { @@ -233,15 +248,16 @@ def workflow_with_sec_agg( if LOG_EXPLAIN: print(f"\nForwarding encrypted key shares and requesting masked input...") # Send encrypted secret key shares to clients (plus model parameters) - weights = parameters_to_ndarrays(parameters) yield { node_id: _wrap_in_task( named_values={ KEY_STAGE: STAGE_COLLECT_MASKED_INPUT, KEY_CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]], KEY_SOURCE_LIST: fwd_srcs[nid2sid[node_id]], - KEY_PARAMETERS: [ndarray_to_bytes(arr) for arr in weights], - } + }, + fit_ins=FitIns( + parameters=parameters, config={"drop": nid2sid[node_id] == 0} + ), ) for node_id in surviving_node_ids } @@ -249,6 +265,7 @@ def workflow_with_sec_agg( node_messages = yield surviving_node_ids = [node_id for node_id in node_messages] # Get shape of vector sent by first client + weights = parameters_to_ndarrays(parameters) masked_vector = [np.array([0], dtype=int)] + get_zero_parameters( [w.shape for w in weights] ) From dd2f38dd0d0942236d6d02174a0421db012c348b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 30 Jan 2024 17:28:20 +0000 Subject: [PATCH 5/9] fix a bug when converting state to dict --- .../middleware/secure_aggregation/secaggplus_middleware.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index e6376abb90d8..ce60bd928e20 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -119,8 +119,9 @@ def __init__(self, **kwargs: ConfigsRecordValues) -> None: continue new_v: Any = v if k.endswith(":K"): + k = k[:-2] keys = cast(List[int], v) - values = cast(List[bytes], kwargs[k[:-2] + ":V"]) + values = cast(List[bytes], kwargs[f"{k}:V"]) if len(values) > len(keys): updated_values = [ tuple(values[i : i + 2]) for i in range(0, len(values), 2) @@ -133,7 +134,7 @@ def __init__(self, **kwargs: ConfigsRecordValues) -> None: def to_dict(self) -> Dict[str, ConfigsRecordValues]: """Convert the state to a dictionary.""" ret = vars(self) - for k in ret.keys(): + for k in list(ret.keys()): if isinstance(ret[k], dict): # Replace dict with two lists v = cast(Dict[str, Any], ret.pop(k)) @@ -202,7 +203,7 @@ def secaggplus_middleware( raise ValueError(f"Unknown secagg stage: {state.current_stage}") # Save state - ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(vars(state))) + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict())) # Return message return Message( From fbda0b91a315a2823d4330f03f6e9cf5d3bd37c1 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 30 Jan 2024 17:28:57 +0000 Subject: [PATCH 6/9] update driver --- examples/secaggplus-mt/driver.py | 5 +---- examples/secaggplus-mt/run.sh | 17 ++++++++--------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index f5871f1b44e4..d0d9a75f1b76 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -24,7 +24,6 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_id="", # Do not set, will be created and set by the DriverAPI group_id="", run_id=run_id, - run_id=run_id, task=merge( task, task_pb2.Task( @@ -193,9 +192,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: break # Collect correct results - node_messages = task_res_list_to_task_dict( - [res for res in all_task_res if res.task.HasField("sa")] - ) + node_messages = task_res_list_to_task_dict(all_task_res) workflow.close() # Slow down the start of the next round diff --git a/examples/secaggplus-mt/run.sh b/examples/secaggplus-mt/run.sh index 5cc769f6cbd8..852798c0ab21 100755 --- a/examples/secaggplus-mt/run.sh +++ b/examples/secaggplus-mt/run.sh @@ -1,13 +1,13 @@ #!/bin/bash # Kill any currently running client.py processes -pkill -f 'python client.py' +pkill -f 'flower-client' -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' +# Kill any currently running flower-server processes +pkill -f 'flower-server' # Start the flower server echo "Starting flower server in background..." -flower-server --grpc-rere > /dev/null 2>&1 & +flower-server --insecure > /dev/null 2>&1 & sleep 2 # Number of client processes to start @@ -18,8 +18,7 @@ echo "Starting $N clients in background..." # Start N client processes for i in $(seq 1 $N) do - python client.py > /dev/null 2>&1 & - # python client.py & + flower-client --insecure client:flower > /dev/null 2>&1 & sleep 0.1 done @@ -29,7 +28,7 @@ python driver.py echo "Clearing background processes..." # Kill any currently running client.py processes -pkill -f 'python client.py' +pkill -f 'flower-client' -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' +# Kill any currently running flower-server processes +pkill -f 'flower-server' \ No newline at end of file From 75d84b8d3db7f0966c399e3af4f0cb75405837b7 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 31 Jan 2024 14:15:42 +0100 Subject: [PATCH 7/9] Update src/py/flwr/client/middleware/secure_aggregation/__init__.py --- src/py/flwr/client/middleware/secure_aggregation/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/middleware/secure_aggregation/__init__.py b/src/py/flwr/client/middleware/secure_aggregation/__init__.py index b9da30e42b34..353828b02517 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/__init__.py +++ b/src/py/flwr/client/middleware/secure_aggregation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 96600ebc78fda7ea73ed0e96c6e92b84396096b6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 31 Jan 2024 13:59:46 +0000 Subject: [PATCH 8/9] reorder args --- .../middleware/secure_aggregation/secaggplus_middleware.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index ce60bd928e20..97fa89426d6b 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -150,7 +150,7 @@ def to_dict(self) -> Dict[str, ConfigsRecordValues]: def _get_fit_fn( - msg: Message, call_next: FlowerCallable, ctxt: Context + msg: Message, ctxt: Context, call_next: FlowerCallable ) -> Callable[[], FitRes]: """Get the fit function.""" @@ -195,7 +195,7 @@ def secaggplus_middleware( elif state.current_stage == STAGE_SHARE_KEYS: res = _share_keys(state, configs) elif state.current_stage == STAGE_COLLECT_MASKED_INPUT: - fit = _get_fit_fn(msg, call_next, ctxt) + fit = _get_fit_fn(msg, ctxt, call_next) res = _collect_masked_input(state, configs, fit) elif state.current_stage == STAGE_UNMASK: res = _unmask(state, configs) From 74f396692ed2f025cf7cffedd99355954953d1a5 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 1 Feb 2024 11:52:19 +0100 Subject: [PATCH 9/9] Apply suggestions from code review --- .../middleware/secure_aggregation/secaggplus_middleware.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index 97fa89426d6b..885dc4d9cbf5 100644 --- a/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -171,13 +171,13 @@ def secaggplus_middleware( if msg.metadata.task_type != TASK_TYPE_FIT: return call_next(msg, ctxt) - # Retrieve state + # Retrieve local state if RECORD_KEY_STATE not in ctxt.state.configs: ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({})) state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data state = SecAggPlusState(**state_dict) - # Retrieve configs + # Retrieve incoming configs configs = msg.message.get_configs(RECORD_KEY_CONFIGS).data # Check the validity of the next stage