diff --git a/examples/secaggplus-mt/client.py b/examples/secaggplus-mt/client.py index f0f1348ee378..a0ed9e7181c7 100644 --- a/examples/secaggplus-mt/client.py +++ b/examples/secaggplus-mt/client.py @@ -5,11 +5,11 @@ import flwr as fl from flwr.common import Status, FitIns, FitRes, Code from flwr.common.parameter import ndarrays_to_parameters -from flwr.client.secure_aggregation import SecAggPlusHandler +from flwr.client.middleware import secaggplus_middleware # Define Flower client with the SecAgg+ protocol -class FlowerClient(fl.client.Client, SecAggPlusHandler): +class FlowerClient(fl.client.Client): def fit(self, fit_ins: FitIns) -> FitRes: ret_vec = [np.ones(3)] ret = FitRes( @@ -19,17 +19,30 @@ def fit(self, fit_ins: FitIns) -> FitRes: metrics={}, ) # Force a significant delay for testing purposes - if self._shared_state.sid == 0: - print(f"Client {self._shared_state.sid} dropped for testing purposes.") + if fit_ins.config["drop"]: + print(f"Client dropped for testing purposes.") time.sleep(4) return ret - print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...") + print(f"Client uploading {ret_vec[0]}...") return ret -# Start Flower client -fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient(), - transport="grpc-rere", +def client_fn(cid: str): + """.""" + return FlowerClient().to_client() + + +# To run this: `flower-client --callable client:flower` +flower = fl.flower.Flower( + client_fn=client_fn, + layers=[secaggplus_middleware], ) + + +if __name__ == "__main__": + # Start Flower client + fl.client.start_client( + server_address="0.0.0.0:9092", + client=FlowerClient(), + transport="grpc-rere", + ) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index f5871f1b44e4..d0d9a75f1b76 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -24,7 +24,6 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_id="", # Do not set, will be created and set by the DriverAPI group_id="", run_id=run_id, - run_id=run_id, task=merge( task, task_pb2.Task( @@ -193,9 +192,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: break # Collect correct results - node_messages = task_res_list_to_task_dict( - [res for res in all_task_res if res.task.HasField("sa")] - ) + node_messages = task_res_list_to_task_dict(all_task_res) workflow.close() # Slow down the start of the next round diff --git a/examples/secaggplus-mt/run.sh b/examples/secaggplus-mt/run.sh index 5cc769f6cbd8..852798c0ab21 100755 --- a/examples/secaggplus-mt/run.sh +++ b/examples/secaggplus-mt/run.sh @@ -1,13 +1,13 @@ #!/bin/bash # Kill any currently running client.py processes -pkill -f 'python client.py' +pkill -f 'flower-client' -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' +# Kill any currently running flower-server processes +pkill -f 'flower-server' # Start the flower server echo "Starting flower server in background..." -flower-server --grpc-rere > /dev/null 2>&1 & +flower-server --insecure > /dev/null 2>&1 & sleep 2 # Number of client processes to start @@ -18,8 +18,7 @@ echo "Starting $N clients in background..." # Start N client processes for i in $(seq 1 $N) do - python client.py > /dev/null 2>&1 & - # python client.py & + flower-client --insecure client:flower > /dev/null 2>&1 & sleep 0.1 done @@ -29,7 +28,7 @@ python driver.py echo "Clearing background processes..." # Kill any currently running client.py processes -pkill -f 'python client.py' +pkill -f 'flower-client' -# Kill any currently running flower-server processes with --grpc-rere option -pkill -f 'flower-server --grpc-rere' +# Kill any currently running flower-server processes +pkill -f 'flower-server' \ No newline at end of file diff --git a/examples/secaggplus-mt/workflows.py b/examples/secaggplus-mt/workflows.py index 3117e308a498..b98de883b8f7 100644 --- a/examples/secaggplus-mt/workflows.py +++ b/examples/secaggplus-mt/workflows.py @@ -1,6 +1,6 @@ import random from logging import WARNING -from typing import Callable, Dict, Generator, List +from typing import Callable, Dict, Generator, List, Optional import numpy as np @@ -36,7 +36,6 @@ KEY_DESTINATION_LIST, KEY_MASKED_PARAMETERS, KEY_MOD_RANGE, - KEY_PARAMETERS, KEY_PUBLIC_KEY_1, KEY_PUBLIC_KEY_2, KEY_SAMPLE_NUMBER, @@ -52,11 +51,16 @@ STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_UNMASK, + RECORD_KEY_CONFIGS, ) from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.common.serde import named_values_from_proto, named_values_to_proto -from flwr.common.typing import Value -from flwr.proto.task_pb2 import SecureAggregation, Task +from flwr.common.typing import ConfigsRecordValues, FitIns, ServerMessage +from flwr.proto.task_pb2 import Task +from flwr.common import serde +from flwr.common.constant import TASK_TYPE_FIT +from flwr.common.recordset import RecordSet +from flwr.common import recordset_compat as compat +from flwr.common.configsrecord import ConfigsRecord LOG_EXPLAIN = True @@ -68,12 +72,23 @@ def get_workflow_factory() -> ( return _wrap_workflow_with_sec_agg -def _wrap_in_task(named_values: Dict[str, Value]) -> Task: - return Task(sa=SecureAggregation(named_values=named_values_to_proto(named_values))) +def _wrap_in_task( + named_values: Dict[str, ConfigsRecordValues], fit_ins: Optional[FitIns] = None +) -> Task: + if fit_ins is not None: + recordset = compat.fitins_to_recordset(fit_ins, keep_input=True) + else: + recordset = RecordSet() + recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values)) + return Task( + task_type=TASK_TYPE_FIT, + recordset=serde.recordset_to_proto(recordset), + ) -def _get_from_task(task: Task) -> Dict[str, Value]: - return named_values_from_proto(task.sa.named_values) +def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]: + recordset = serde.recordset_from_proto(task.recordset) + return recordset.get_configs(RECORD_KEY_CONFIGS).data _secure_aggregation_configuration = { @@ -233,15 +248,16 @@ def workflow_with_sec_agg( if LOG_EXPLAIN: print(f"\nForwarding encrypted key shares and requesting masked input...") # Send encrypted secret key shares to clients (plus model parameters) - weights = parameters_to_ndarrays(parameters) yield { node_id: _wrap_in_task( named_values={ KEY_STAGE: STAGE_COLLECT_MASKED_INPUT, KEY_CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]], KEY_SOURCE_LIST: fwd_srcs[nid2sid[node_id]], - KEY_PARAMETERS: [ndarray_to_bytes(arr) for arr in weights], - } + }, + fit_ins=FitIns( + parameters=parameters, config={"drop": nid2sid[node_id] == 0} + ), ) for node_id in surviving_node_ids } @@ -249,6 +265,7 @@ def workflow_with_sec_agg( node_messages = yield surviving_node_ids = [node_id for node_id in node_messages] # Get shape of vector sent by first client + weights = parameters_to_ndarrays(parameters) masked_vector = [np.array([0], dtype=int)] + get_zero_parameters( [w.shape for w in weights] ) diff --git a/src/py/flwr/client/middleware/__init__.py b/src/py/flwr/client/middleware/__init__.py index 58b31296fbbe..2cab6e61899c 100644 --- a/src/py/flwr/client/middleware/__init__.py +++ b/src/py/flwr/client/middleware/__init__.py @@ -15,8 +15,10 @@ """Middleware layers.""" +from .secure_aggregation.secaggplus_middleware import secaggplus_middleware from .utils import make_ffn __all__ = [ "make_ffn", + "secaggplus_middleware", ] diff --git a/src/py/flwr/client/secure_aggregation/__init__.py b/src/py/flwr/client/middleware/secure_aggregation/__init__.py similarity index 76% rename from src/py/flwr/client/secure_aggregation/__init__.py rename to src/py/flwr/client/middleware/secure_aggregation/__init__.py index 37c816a390de..353828b02517 100644 --- a/src/py/flwr/client/secure_aggregation/__init__.py +++ b/src/py/flwr/client/middleware/secure_aggregation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,8 @@ # limitations under the License. # ============================================================================== """Secure Aggregation handlers.""" - - -from .handler import SecureAggregationHandler -from .secaggplus_handler import SecAggPlusHandler +from .secaggplus_middleware import secaggplus_middleware __all__ = [ - "SecAggPlusHandler", - "SecureAggregationHandler", + "secaggplus_middleware", ] diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py similarity index 70% rename from src/py/flwr/client/secure_aggregation/secaggplus_handler.py rename to src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py index 4b74c1ace3de..885dc4d9cbf5 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,18 +17,18 @@ import os from dataclasses import dataclass, field -from logging import ERROR, INFO, WARNING -from typing import Any, Dict, List, Optional, Tuple, Union, cast - -from flwr.client.client import Client -from flwr.client.numpy_client import NumPyClient -from flwr.common import ( - bytes_to_ndarray, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from logging import INFO, WARNING +from typing import Any, Callable, Dict, List, Tuple, cast + +from flwr.client.typing import FlowerCallable +from flwr.common import ndarray_to_bytes, parameters_to_ndarrays +from flwr.common import recordset_compat as compat +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_FIT +from flwr.common.context import Context from flwr.common.logger import log +from flwr.common.message import Message, Metadata +from flwr.common.recordset import RecordSet from flwr.common.secure_aggregation.crypto.shamir import create_shares from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_private_key, @@ -56,7 +56,6 @@ KEY_DESTINATION_LIST, KEY_MASKED_PARAMETERS, KEY_MOD_RANGE, - KEY_PARAMETERS, KEY_PUBLIC_KEY_1, KEY_PUBLIC_KEY_2, KEY_SAMPLE_NUMBER, @@ -68,6 +67,8 @@ KEY_STAGE, KEY_TARGET_RANGE, KEY_THRESHOLD, + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, STAGE_COLLECT_MASKED_INPUT, STAGE_SETUP, STAGE_SHARE_KEYS, @@ -79,9 +80,7 @@ share_keys_plaintext_concat, share_keys_plaintext_separate, ) -from flwr.common.typing import FitIns, Value - -from .handler import SecureAggregationHandler +from flwr.common.typing import ConfigsRecordValues, FitRes @dataclass @@ -89,6 +88,8 @@ class SecAggPlusState: """State of the SecAgg+ protocol.""" + current_stage: str = STAGE_UNMASK + sid: int = 0 sample_num: int = 0 share_num: int = 0 @@ -112,70 +113,115 @@ class SecAggPlusState: ss2_dict: Dict[int, bytes] = field(default_factory=dict) public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) - client: Optional[Union[Client, NumPyClient]] = None - - -class SecAggPlusHandler(SecureAggregationHandler): - """Message handler for the SecAgg+ protocol.""" - - _shared_state = SecAggPlusState() - _current_stage = STAGE_UNMASK - - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming message and return results, following the SecAgg+ protocol. - - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. + def __init__(self, **kwargs: ConfigsRecordValues) -> None: + for k, v in kwargs.items(): + if k.endswith(":V"): + continue + new_v: Any = v + if k.endswith(":K"): + k = k[:-2] + keys = cast(List[int], v) + values = cast(List[bytes], kwargs[f"{k}:V"]) + if len(values) > len(keys): + updated_values = [ + tuple(values[i : i + 2]) for i in range(0, len(values), 2) + ] + new_v = dict(zip(keys, updated_values)) + else: + new_v = dict(zip(keys, values)) + self.__setattr__(k, new_v) + + def to_dict(self) -> Dict[str, ConfigsRecordValues]: + """Convert the state to a dictionary.""" + ret = vars(self) + for k in list(ret.keys()): + if isinstance(ret[k], dict): + # Replace dict with two lists + v = cast(Dict[str, Any], ret.pop(k)) + ret[f"{k}:K"] = list(v.keys()) + if k == "public_keys_dict": + v_list: List[bytes] = [] + for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()): + v_list.extend(b1_b2) + ret[f"{k}:V"] = v_list + else: + ret[f"{k}:V"] = list(v.values()) + return ret + + +def _get_fit_fn( + msg: Message, ctxt: Context, call_next: FlowerCallable +) -> Callable[[], FitRes]: + """Get the fit function.""" + + def fit() -> FitRes: + out_msg = call_next(msg, ctxt) + return compat.recordset_to_fitres(out_msg.message, keep_input=False) + + return fit + + +def secaggplus_middleware( + msg: Message, + ctxt: Context, + call_next: FlowerCallable, +) -> Message: + """Handle incoming message and return results, following the SecAgg+ protocol.""" + # Ignore non-fit messages + if msg.metadata.task_type != TASK_TYPE_FIT: + return call_next(msg, ctxt) + + # Retrieve local state + if RECORD_KEY_STATE not in ctxt.state.configs: + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({})) + state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data + state = SecAggPlusState(**state_dict) + + # Retrieve incoming configs + configs = msg.message.get_configs(RECORD_KEY_CONFIGS).data - Returns - ------- - Dict[str, Value] - The final/intermediate results of the SecAgg+ protocol. - """ - # Check if self is a client - if not isinstance(self, (Client, NumPyClient)): - raise TypeError( - "The subclass of SecAggPlusHandler must be " - "the subclass of Client or NumPyClient." - ) - - # Check the validity of the next stage - check_stage(self._current_stage, named_values) - - # Update the current stage - self._current_stage = cast(str, named_values.pop(KEY_STAGE)) + # Check the validity of the next stage + check_stage(state.current_stage, configs) + + # Update the current stage + state.current_stage = cast(str, configs.pop(KEY_STAGE)) + + # Check the validity of the configs based on the current stage + check_configs(state.current_stage, configs) + + # Execute + if state.current_stage == STAGE_SETUP: + res = _setup(state, configs) + elif state.current_stage == STAGE_SHARE_KEYS: + res = _share_keys(state, configs) + elif state.current_stage == STAGE_COLLECT_MASKED_INPUT: + fit = _get_fit_fn(msg, ctxt, call_next) + res = _collect_masked_input(state, configs, fit) + elif state.current_stage == STAGE_UNMASK: + res = _unmask(state, configs) + else: + raise ValueError(f"Unknown secagg stage: {state.current_stage}") - # Check the validity of the `named_values` based on the current stage - check_named_values(self._current_stage, named_values) + # Save state + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict())) - # Execute - if self._current_stage == STAGE_SETUP: - self._shared_state = SecAggPlusState(client=self) - return _setup(self._shared_state, named_values) - if self._current_stage == STAGE_SHARE_KEYS: - return _share_keys(self._shared_state, named_values) - if self._current_stage == STAGE_COLLECT_MASKED_INPUT: - return _collect_masked_input(self._shared_state, named_values) - if self._current_stage == STAGE_UNMASK: - return _unmask(self._shared_state, named_values) - raise ValueError(f"Unknown secagg stage: {self._current_stage}") + # Return message + return Message( + metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}), + ) -def check_stage(current_stage: str, named_values: Dict[str, Value]) -> None: +def check_stage(current_stage: str, configs: Dict[str, ConfigsRecordValues]) -> None: """Check the validity of the next stage.""" # Check the existence of KEY_STAGE - if KEY_STAGE not in named_values: + if KEY_STAGE not in configs: raise KeyError( f"The required key '{KEY_STAGE}' is missing from the input `named_values`." ) # Check the value type of the KEY_STAGE - next_stage = named_values[KEY_STAGE] + next_stage = configs[KEY_STAGE] if not isinstance(next_stage, str): raise TypeError( f"The value for the key '{KEY_STAGE}' must be of type {str}, " @@ -198,8 +244,8 @@ def check_stage(current_stage: str, named_values: Dict[str, Value]) -> None: # pylint: disable-next=too-many-branches -def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: - """Check the validity of the input `named_values`.""" +def check_configs(stage: str, configs: Dict[str, ConfigsRecordValues]) -> None: + """Check the validity of the configs.""" # Check `named_values` for the setup stage if stage == STAGE_SETUP: key_type_pairs = [ @@ -212,7 +258,7 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: (KEY_MOD_RANGE, int), ] for key, expected_type in key_type_pairs: - if key not in named_values: + if key not in configs: raise KeyError( f"Stage {STAGE_SETUP}: the required key '{key}' is " "missing from the input `named_values`." @@ -220,14 +266,14 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: # Bool is a subclass of int in Python, # so `isinstance(v, int)` will return True even if v is a boolean. # pylint: disable-next=unidiomatic-typecheck - if type(named_values[key]) is not expected_type: + if type(configs[key]) is not expected_type: raise TypeError( f"Stage {STAGE_SETUP}: The value for the key '{key}' " f"must be of type {expected_type}, " - f"but got {type(named_values[key])} instead." + f"but got {type(configs[key])} instead." ) elif stage == STAGE_SHARE_KEYS: - for key, value in named_values.items(): + for key, value in configs.items(): if ( not isinstance(value, list) or len(value) != 2 @@ -242,18 +288,17 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: key_type_pairs = [ (KEY_CIPHERTEXT_LIST, bytes), (KEY_SOURCE_LIST, int), - (KEY_PARAMETERS, bytes), ] for key, expected_type in key_type_pairs: - if key not in named_values: + if key not in configs: raise KeyError( f"Stage {STAGE_COLLECT_MASKED_INPUT}: " f"the required key '{key}' is " "missing from the input `named_values`." ) - if not isinstance(named_values[key], list) or any( + if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], named_values[key]) + for elm in cast(List[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -268,15 +313,15 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: (KEY_DEAD_SECURE_ID_LIST, int), ] for key, expected_type in key_type_pairs: - if key not in named_values: + if key not in configs: raise KeyError( f"Stage {STAGE_UNMASK}: " f"the required key '{key}' is " "missing from the input `named_values`." ) - if not isinstance(named_values[key], list) or any( + if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], named_values[key]) + for elm in cast(List[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -289,9 +334,11 @@ def check_named_values(stage: str, named_values: Dict[str, Value]) -> None: raise ValueError(f"Unknown secagg stage: {stage}") -def _setup(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: +def _setup( + state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues] +) -> Dict[str, ConfigsRecordValues]: # Assigning parameter values to object fields - sec_agg_param_dict = named_values + sec_agg_param_dict = configs state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER]) state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID]) log(INFO, "Client %d: starting stage 0...", state.sid) @@ -324,9 +371,9 @@ def _setup(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, # pylint: disable-next=too-many-locals def _share_keys( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: - named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], named_values) + state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues] +) -> Dict[str, ConfigsRecordValues]: + named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} log(INFO, "Client %d: starting stage 1...", state.sid) state.public_keys_dict = key_dict @@ -386,12 +433,14 @@ def _share_keys( # pylint: disable-next=too-many-locals def _collect_masked_input( - state: SecAggPlusState, named_values: Dict[str, Value] -) -> Dict[str, Value]: + state: SecAggPlusState, + configs: Dict[str, ConfigsRecordValues], + fit: Callable[[], FitRes], +) -> Dict[str, ConfigsRecordValues]: log(INFO, "Client %d: starting stage 2...", state.sid) available_clients: List[int] = [] - ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) - srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) + ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST]) + srcs = cast(List[int], configs[KEY_SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: raise ValueError("Not enough available neighbour clients.") @@ -417,18 +466,9 @@ def _collect_masked_input( state.sk1_share_dict[src] = sk1_share # Fit client - parameters_bytes = cast(List[bytes], named_values[KEY_PARAMETERS]) - parameters = [bytes_to_ndarray(w) for w in parameters_bytes] - if isinstance(state.client, Client): - fit_res = state.client.fit( - FitIns(parameters=ndarrays_to_parameters(parameters), config={}) - ) - parameters_factor = fit_res.num_examples - parameters = parameters_to_ndarrays(fit_res.parameters) - elif isinstance(state.client, NumPyClient): - parameters, parameters_factor, _ = state.client.fit(parameters, {}) - else: - log(ERROR, "Client %d: fit function is missing.", state.sid) + fit_res = fit() + parameters_factor = fit_res.num_examples + parameters = parameters_to_ndarrays(fit_res.parameters) # Quantize parameter update (vector) quantized_parameters = quantize( @@ -468,11 +508,13 @@ def _collect_masked_input( } -def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, Value]: +def _unmask( + state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues] +) -> Dict[str, ConfigsRecordValues]: log(INFO, "Client %d: starting stage 3...", state.sid) - active_sids = cast(List[int], named_values[KEY_ACTIVE_SECURE_ID_LIST]) - dead_sids = cast(List[int], named_values[KEY_DEAD_SECURE_ID_LIST]) + active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST]) + dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST]) # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py similarity index 57% rename from src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py rename to src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py index 9693a46af989..8ec52d71cbdd 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler_test.py +++ b/src/py/flwr/client/middleware/secure_aggregation/secaggplus_middleware_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,16 +16,20 @@ import unittest from itertools import product -from typing import Any, Dict, List, cast - -from flwr.client import NumPyClient +from typing import Callable, Dict, List + +from flwr.client.middleware import make_ffn +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_FIT +from flwr.common.context import Context +from flwr.common.message import Message, Metadata +from flwr.common.recordset import RecordSet from flwr.common.secure_aggregation.secaggplus_constants import ( KEY_ACTIVE_SECURE_ID_LIST, KEY_CIPHERTEXT_LIST, KEY_CLIPPING_RANGE, KEY_DEAD_SECURE_ID_LIST, KEY_MOD_RANGE, - KEY_PARAMETERS, KEY_SAMPLE_NUMBER, KEY_SECURE_ID, KEY_SHARE_NUMBER, @@ -33,34 +37,68 @@ KEY_STAGE, KEY_TARGET_RANGE, KEY_THRESHOLD, + RECORD_KEY_CONFIGS, + RECORD_KEY_STATE, STAGE_COLLECT_MASKED_INPUT, STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_UNMASK, STAGES, ) -from flwr.common.typing import Value +from flwr.common.typing import ConfigsRecordValues -from .secaggplus_handler import SecAggPlusHandler, check_named_values +from .secaggplus_middleware import SecAggPlusState, check_configs, secaggplus_middleware -class EmptyFlowerNumPyClient(NumPyClient, SecAggPlusHandler): - """Empty NumPyClient.""" +def get_test_handler( + ctxt: Context, +) -> Callable[[Dict[str, ConfigsRecordValues]], Dict[str, ConfigsRecordValues]]: + """.""" + def empty_ffn(_: Message, _2: Context) -> Message: + return Message( + metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + message=RecordSet(), + ) + + app = make_ffn(empty_ffn, [secaggplus_middleware]) + + def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]: + in_msg = Message( + metadata=Metadata(0, "", "", "", TASK_TYPE_FIT), + message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}), + ) + out_msg = app(in_msg, ctxt) + return out_msg.message.get_configs(RECORD_KEY_CONFIGS).data + + return func + + +def _make_ctxt() -> Context: + cfg = ConfigsRecord(SecAggPlusState().to_dict()) + return Context(RecordSet(configs={RECORD_KEY_STATE: cfg})) -class TestSecAggPlusHandler(unittest.TestCase): - """Test the SecAgg+ protocol handler.""" - def test_invalid_handler(self) -> None: - """Test invalid handler.""" - handler = SecAggPlusHandler() +def _make_set_state_fn( + ctxt: Context, +) -> Callable[[str], None]: + def set_stage(stage: str) -> None: + state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data + state = SecAggPlusState(**state_dict) + state.current_stage = stage + ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict())) - with self.assertRaises(TypeError): - handler.handle_secure_aggregation({}) + return set_stage + + +class TestSecAggPlusHandler(unittest.TestCase): + """Test the SecAgg+ protocol handler.""" def test_stage_transition(self) -> None: """Test stage transition.""" - handler = EmptyFlowerNumPyClient() + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) assert STAGES == ( STAGE_SETUP, @@ -88,28 +126,24 @@ def test_stage_transition(self) -> None: # If the next stage is valid, the function should update the current stage # and then raise KeyError or other exceptions when trying to execute SA. for current_stage, next_stage in valid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage + set_stage(current_stage) with self.assertRaises(KeyError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == next_stage + handler({KEY_STAGE: next_stage}) # Test invalid transitions # If the next stage is invalid, the function should raise ValueError for current_stage, next_stage in invalid_transitions: - # pylint: disable-next=protected-access - handler._current_stage = current_stage + set_stage(current_stage) with self.assertRaises(ValueError): - handler.handle_secure_aggregation({KEY_STAGE: next_stage}) - # pylint: disable-next=protected-access - assert handler._current_stage == current_stage + handler({KEY_STAGE: next_stage}) def test_stage_setup_check(self) -> None: """Test content checking for the setup stage.""" - handler = EmptyFlowerNumPyClient() + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) valid_key_type_pairs = [ (KEY_SAMPLE_NUMBER, int), @@ -121,7 +155,7 @@ def test_stage_setup_check(self) -> None: (KEY_MOD_RANGE, int), ] - type_to_test_value: Dict[type, Value] = { + type_to_test_value: Dict[type, ConfigsRecordValues] = { int: 10, bool: True, float: 1.0, @@ -129,47 +163,49 @@ def test_stage_setup_check(self) -> None: bytes: b"test", } - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { key: type_to_test_value[value_type] for key, value_type in valid_key_type_pairs } # Test valid `named_values` try: - check_named_values(STAGE_SETUP, valid_named_values.copy()) + check_configs(STAGE_SETUP, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SETUP + valid_configs[KEY_STAGE] = STAGE_SETUP # Test invalid `named_values` for key, value_type in valid_key_type_pairs: - invalid_named_values = valid_named_values.copy() + invalid_configs = valid_configs.copy() # Test wrong value type for the key for other_type, other_value in type_to_test_value.items(): if other_type == value_type: continue - invalid_named_values[key] = other_value - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK + invalid_configs[key] = other_value + + set_stage(STAGE_UNMASK) with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) + handler(invalid_configs.copy()) # Test missing key - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_UNMASK + invalid_configs.pop(key) + + set_stage(STAGE_UNMASK) with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values.copy()) + handler(invalid_configs.copy()) def test_stage_share_keys_check(self) -> None: """Test content checking for the share keys stage.""" - handler = EmptyFlowerNumPyClient() + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { "1": [b"public key 1", b"public key 2"], "2": [b"public key 1", b"public key 2"], "3": [b"public key 1", b"public key 2"], @@ -177,111 +213,113 @@ def test_stage_share_keys_check(self) -> None: # Test valid `named_values` try: - check_named_values(STAGE_SHARE_KEYS, valid_named_values.copy()) + check_configs(STAGE_SHARE_KEYS, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_SHARE_KEYS + valid_configs[KEY_STAGE] = STAGE_SHARE_KEYS # Test invalid `named_values` - invalid_values: List[Value] = [ + invalid_values: List[ConfigsRecordValues] = [ b"public key 1", [b"public key 1"], [b"public key 1", b"public key 2", b"public key 3"], ] for value in invalid_values: - invalid_named_values = valid_named_values.copy() - invalid_named_values["1"] = value + invalid_configs = valid_configs.copy() + invalid_configs["1"] = value - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SETUP + set_stage(STAGE_SETUP) with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values.copy()) + handler(invalid_configs.copy()) def test_stage_collect_masked_input_check(self) -> None: """Test content checking for the collect masked input stage.""" - handler = EmptyFlowerNumPyClient() + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { KEY_CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], KEY_SOURCE_LIST: [32, 51324, 32324123, -3], - KEY_PARAMETERS: [b"params1", b"params2"], } # Test valid `named_values` try: - check_named_values(STAGE_COLLECT_MASKED_INPUT, valid_named_values.copy()) + check_configs(STAGE_COLLECT_MASKED_INPUT, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_COLLECT_MASKED_INPUT + valid_configs[KEY_STAGE] = STAGE_COLLECT_MASKED_INPUT # Test invalid `named_values` # Test missing keys - for key in list(valid_named_values.keys()): + for key in list(valid_configs.keys()): if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) + + set_stage(STAGE_SHARE_KEYS) with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_configs) # Test wrong value type for the key - for key in valid_named_values: + for key in valid_configs: if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(3.1415926) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_SHARE_KEYS + invalid_configs = valid_configs.copy() + invalid_configs[key] = [3.1415926] + + set_stage(STAGE_SHARE_KEYS) with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_configs) def test_stage_unmask_check(self) -> None: """Test content checking for the unmasking stage.""" - handler = EmptyFlowerNumPyClient() + ctxt = _make_ctxt() + handler = get_test_handler(ctxt) + set_stage = _make_set_state_fn(ctxt) - valid_named_values: Dict[str, Value] = { + valid_configs: Dict[str, ConfigsRecordValues] = { KEY_ACTIVE_SECURE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], KEY_DEAD_SECURE_ID_LIST: [32, 51324, 32324123, -3], } # Test valid `named_values` try: - check_named_values(STAGE_UNMASK, valid_named_values.copy()) + check_configs(STAGE_UNMASK, valid_configs.copy()) # pylint: disable-next=broad-except except Exception as exc: self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") # Set the stage - valid_named_values[KEY_STAGE] = STAGE_UNMASK + valid_configs[KEY_STAGE] = STAGE_UNMASK # Test invalid `named_values` # Test missing keys - for key in list(valid_named_values.keys()): + for key in list(valid_configs.keys()): if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - invalid_named_values.pop(key) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT + invalid_configs = valid_configs.copy() + invalid_configs.pop(key) + + set_stage(STAGE_COLLECT_MASKED_INPUT) with self.assertRaises(KeyError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_configs) # Test wrong value type for the key - for key in valid_named_values: + for key in valid_configs: if key == KEY_STAGE: continue - invalid_named_values = valid_named_values.copy() - cast(List[Any], invalid_named_values[key]).append(True) - # pylint: disable-next=protected-access - handler._current_stage = STAGE_COLLECT_MASKED_INPUT + invalid_configs = valid_configs.copy() + invalid_configs[key] = [True, False, True, False] + + set_stage(STAGE_COLLECT_MASKED_INPUT) with self.assertRaises(TypeError): - handler.handle_secure_aggregation(invalid_named_values) + handler(invalid_configs) diff --git a/src/py/flwr/client/secure_aggregation/handler.py b/src/py/flwr/client/secure_aggregation/handler.py deleted file mode 100644 index 487ed842c93f..000000000000 --- a/src/py/flwr/client/secure_aggregation/handler.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Message Handler for Secure Aggregation (abstract base class).""" - - -from abc import ABC, abstractmethod -from typing import Dict - -from flwr.common.typing import Value - - -class SecureAggregationHandler(ABC): - """Abstract base class for secure aggregation message handlers.""" - - @abstractmethod - def handle_secure_aggregation( - self, named_values: Dict[str, Value] - ) -> Dict[str, Value]: - """Handle incoming Secure Aggregation message and return results. - - Parameters - ---------- - named_values : Dict[str, Value] - The named values retrieved from the SecureAggregation sub-message - of Task message in the server's TaskIns. - - Returns - ------- - Dict[str, Value] - The final/intermediate results of the Secure Aggregation protocol. - """ diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py index 8dd21a6016f1..9a2bf26e98e8 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py @@ -14,6 +14,8 @@ # ============================================================================== """Constants for the SecAgg/SecAgg+ protocol.""" +RECORD_KEY_STATE = "secaggplus_state" +RECORD_KEY_CONFIGS = "secaggplus_configs" # Names of stages STAGE_SETUP = "setup"