Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SecAggPlus middleware #2871

Merged
merged 16 commits into from
Feb 1, 2024
33 changes: 23 additions & 10 deletions examples/secaggplus-mt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import flwr as fl
from flwr.common import Status, FitIns, FitRes, Code
from flwr.common.parameter import ndarrays_to_parameters
from flwr.client.secure_aggregation import SecAggPlusHandler
from flwr.client.middleware import secaggplus_middleware


# Define Flower client with the SecAgg+ protocol
class FlowerClient(fl.client.Client, SecAggPlusHandler):
class FlowerClient(fl.client.Client):
def fit(self, fit_ins: FitIns) -> FitRes:
ret_vec = [np.ones(3)]
ret = FitRes(
Expand All @@ -19,17 +19,30 @@ def fit(self, fit_ins: FitIns) -> FitRes:
metrics={},
)
# Force a significant delay for testing purposes
if self._shared_state.sid == 0:
print(f"Client {self._shared_state.sid} dropped for testing purposes.")
if fit_ins.config["drop"]:
print(f"Client dropped for testing purposes.")
time.sleep(4)
return ret
print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...")
print(f"Client uploading {ret_vec[0]}...")
return ret


# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:9092",
client=FlowerClient(),
transport="grpc-rere",
def client_fn(cid: str):
"""."""
return FlowerClient().to_client()


# To run this: `flower-client --callable client:flower`
flower = fl.flower.Flower(
client_fn=client_fn,
layers=[secaggplus_middleware],
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:9092",
client=FlowerClient(),
transport="grpc-rere",
)
5 changes: 1 addition & 4 deletions examples/secaggplus-mt/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task:
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
run_id=run_id,
run_id=run_id,
task=merge(
task,
task_pb2.Task(
Expand Down Expand Up @@ -193,9 +192,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
break

# Collect correct results
node_messages = task_res_list_to_task_dict(
[res for res in all_task_res if res.task.HasField("sa")]
)
node_messages = task_res_list_to_task_dict(all_task_res)
workflow.close()

# Slow down the start of the next round
Expand Down
17 changes: 8 additions & 9 deletions examples/secaggplus-mt/run.sh
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#!/bin/bash
# Kill any currently running client.py processes
pkill -f 'python client.py'
pkill -f 'flower-client'

# Kill any currently running flower-server processes with --grpc-rere option
pkill -f 'flower-server --grpc-rere'
# Kill any currently running flower-server processes
pkill -f 'flower-server'

# Start the flower server
echo "Starting flower server in background..."
flower-server --grpc-rere > /dev/null 2>&1 &
flower-server --insecure > /dev/null 2>&1 &
sleep 2

# Number of client processes to start
Expand All @@ -18,8 +18,7 @@ echo "Starting $N clients in background..."
# Start N client processes
for i in $(seq 1 $N)
do
python client.py > /dev/null 2>&1 &
# python client.py &
flower-client --insecure client:flower > /dev/null 2>&1 &
sleep 0.1
done

Expand All @@ -29,7 +28,7 @@ python driver.py
echo "Clearing background processes..."

# Kill any currently running client.py processes
pkill -f 'python client.py'
pkill -f 'flower-client'

# Kill any currently running flower-server processes with --grpc-rere option
pkill -f 'flower-server --grpc-rere'
# Kill any currently running flower-server processes
pkill -f 'flower-server'
41 changes: 29 additions & 12 deletions examples/secaggplus-mt/workflows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from logging import WARNING
from typing import Callable, Dict, Generator, List
from typing import Callable, Dict, Generator, List, Optional

import numpy as np

Expand Down Expand Up @@ -36,7 +36,6 @@
KEY_DESTINATION_LIST,
KEY_MASKED_PARAMETERS,
KEY_MOD_RANGE,
KEY_PARAMETERS,
KEY_PUBLIC_KEY_1,
KEY_PUBLIC_KEY_2,
KEY_SAMPLE_NUMBER,
Expand All @@ -52,11 +51,16 @@
STAGE_SETUP,
STAGE_SHARE_KEYS,
STAGE_UNMASK,
RECORD_KEY_CONFIGS,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.common.serde import named_values_from_proto, named_values_to_proto
from flwr.common.typing import Value
from flwr.proto.task_pb2 import SecureAggregation, Task
from flwr.common.typing import ConfigsRecordValues, FitIns, ServerMessage
from flwr.proto.task_pb2 import Task
from flwr.common import serde
from flwr.common.constant import TASK_TYPE_FIT
from flwr.common.recordset import RecordSet
from flwr.common import recordset_compat as compat
from flwr.common.configsrecord import ConfigsRecord


LOG_EXPLAIN = True
Expand All @@ -68,12 +72,23 @@ def get_workflow_factory() -> (
return _wrap_workflow_with_sec_agg


def _wrap_in_task(named_values: Dict[str, Value]) -> Task:
return Task(sa=SecureAggregation(named_values=named_values_to_proto(named_values)))
def _wrap_in_task(
named_values: Dict[str, ConfigsRecordValues], fit_ins: Optional[FitIns] = None
) -> Task:
if fit_ins is not None:
recordset = compat.fitins_to_recordset(fit_ins, keep_input=True)
else:
recordset = RecordSet()
recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values))
return Task(
task_type=TASK_TYPE_FIT,
recordset=serde.recordset_to_proto(recordset),
)


def _get_from_task(task: Task) -> Dict[str, Value]:
return named_values_from_proto(task.sa.named_values)
def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]:
recordset = serde.recordset_from_proto(task.recordset)
return recordset.get_configs(RECORD_KEY_CONFIGS).data


_secure_aggregation_configuration = {
Expand Down Expand Up @@ -233,22 +248,24 @@ def workflow_with_sec_agg(
if LOG_EXPLAIN:
print(f"\nForwarding encrypted key shares and requesting masked input...")
# Send encrypted secret key shares to clients (plus model parameters)
weights = parameters_to_ndarrays(parameters)
yield {
node_id: _wrap_in_task(
named_values={
KEY_STAGE: STAGE_COLLECT_MASKED_INPUT,
KEY_CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]],
KEY_SOURCE_LIST: fwd_srcs[nid2sid[node_id]],
KEY_PARAMETERS: [ndarray_to_bytes(arr) for arr in weights],
}
},
fit_ins=FitIns(
parameters=parameters, config={"drop": nid2sid[node_id] == 0}
),
)
for node_id in surviving_node_ids
}
# Collect masked input from clients
node_messages = yield
surviving_node_ids = [node_id for node_id in node_messages]
# Get shape of vector sent by first client
weights = parameters_to_ndarrays(parameters)
masked_vector = [np.array([0], dtype=int)] + get_zero_parameters(
[w.shape for w in weights]
)
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/client/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"""Middleware layers."""


from .secure_aggregation.secaggplus_middleware import secaggplus_middleware
from .utils import make_ffn

__all__ = [
"make_ffn",
"secaggplus_middleware",
]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,12 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Secure Aggregation handlers."""


from .handler import SecureAggregationHandler
from .secaggplus_handler import SecAggPlusHandler
from .secaggplus_middleware import secaggplus_middleware

__all__ = [
"SecAggPlusHandler",
"SecureAggregationHandler",
"secaggplus_middleware",
]
Loading
Loading