Add SecAggPlus layer (#2871)
panh99 authored Feb 1, 2024
1 parent ef9c412 commit 55e9004
Showing 10 changed files with 337 additions and 274 deletions.
33 changes: 23 additions & 10 deletions examples/secaggplus-mt/client.py
@@ -5,11 +5,11 @@
import flwr as fl
from flwr.common import Status, FitIns, FitRes, Code
from flwr.common.parameter import ndarrays_to_parameters
from flwr.client.secure_aggregation import SecAggPlusHandler
from flwr.client.middleware import secaggplus_middleware


# Define Flower client with the SecAgg+ protocol
class FlowerClient(fl.client.Client, SecAggPlusHandler):
class FlowerClient(fl.client.Client):
def fit(self, fit_ins: FitIns) -> FitRes:
ret_vec = [np.ones(3)]
ret = FitRes(
@@ -19,17 +19,30 @@ def fit(self, fit_ins: FitIns) -> FitRes:
metrics={},
)
# Force a significant delay for testing purposes
if self._shared_state.sid == 0:
print(f"Client {self._shared_state.sid} dropped for testing purposes.")
if fit_ins.config["drop"]:
print(f"Client dropped for testing purposes.")
time.sleep(4)
return ret
print(f"Client {self._shared_state.sid} uploading {ret_vec[0]}...")
print(f"Client uploading {ret_vec[0]}...")
return ret


# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:9092",
client=FlowerClient(),
transport="grpc-rere",
def client_fn(cid: str):
"""."""
return FlowerClient().to_client()


# To run this: `flower-client --callable client:flower`
flower = fl.flower.Flower(
client_fn=client_fn,
layers=[secaggplus_middleware],
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:9092",
client=FlowerClient(),
transport="grpc-rere",
)
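
Taken together, the client.py changes replace the SecAggPlusHandler mixin with a middleware layer: the client no longer knows anything about secure aggregation, and the protocol is attached when the Flower callable is built. Below is a minimal sketch of the reworked client module, using only the APIs that appear in this diff (fl.flower.Flower, layers=[secaggplus_middleware], FlowerClient().to_client()); the Status and num_examples fields in FitRes are assumed from flwr's standard FitRes signature rather than copied from the collapsed lines above.

import numpy as np
import flwr as fl
from flwr.common import Code, FitIns, FitRes, Status
from flwr.common.parameter import ndarrays_to_parameters
from flwr.client.middleware import secaggplus_middleware


class FlowerClient(fl.client.Client):
    def fit(self, fit_ins: FitIns) -> FitRes:
        # Return a constant vector; the SecAgg+ layer masks it before upload.
        return FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters([np.ones(3)]),
            num_examples=1,
            metrics={},
        )


def client_fn(cid: str):
    """Return a new FlowerClient wrapped as a low-level Client."""
    return FlowerClient().to_client()


# The SecAgg+ protocol lives entirely in the middleware layer, so the client
# logic above stays protocol-agnostic.
flower = fl.flower.Flower(
    client_fn=client_fn,
    layers=[secaggplus_middleware],
)

As run.sh below shows, such a module is launched with `flower-client --insecure client:flower` instead of `python client.py`.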
5 changes: 1 addition & 4 deletions examples/secaggplus-mt/driver.py
@@ -24,7 +24,6 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task:
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
run_id=run_id,
run_id=run_id,
task=merge(
task,
task_pb2.Task(
@@ -193,9 +192,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
break

# Collect correct results
node_messages = task_res_list_to_task_dict(
[res for res in all_task_res if res.task.HasField("sa")]
)
node_messages = task_res_list_to_task_dict(all_task_res)
workflow.close()

# Slow down the start of the next round
17 changes: 8 additions & 9 deletions examples/secaggplus-mt/run.sh
@@ -1,13 +1,13 @@
#!/bin/bash
# Kill any currently running flower-client processes
pkill -f 'python client.py'
pkill -f 'flower-client'

# Kill any currently running flower-server processes with --grpc-rere option
pkill -f 'flower-server --grpc-rere'
# Kill any currently running flower-server processes
pkill -f 'flower-server'

# Start the flower server
echo "Starting flower server in background..."
flower-server --grpc-rere > /dev/null 2>&1 &
flower-server --insecure > /dev/null 2>&1 &
sleep 2

# Number of client processes to start
@@ -18,8 +18,7 @@ echo "Starting $N clients in background..."
# Start N client processes
for i in $(seq 1 $N)
do
python client.py > /dev/null 2>&1 &
# python client.py &
flower-client --insecure client:flower > /dev/null 2>&1 &
sleep 0.1
done

@@ -29,7 +28,7 @@ python driver.py
echo "Clearing background processes..."

# Kill any currently running flower-client processes
pkill -f 'python client.py'
pkill -f 'flower-client'

# Kill any currently running flower-server processes with --grpc-rere option
pkill -f 'flower-server --grpc-rere'
# Kill any currently running flower-server processes
pkill -f 'flower-server'
41 changes: 29 additions & 12 deletions examples/secaggplus-mt/workflows.py
@@ -1,6 +1,6 @@
import random
from logging import WARNING
from typing import Callable, Dict, Generator, List
from typing import Callable, Dict, Generator, List, Optional

import numpy as np

@@ -36,7 +36,6 @@
KEY_DESTINATION_LIST,
KEY_MASKED_PARAMETERS,
KEY_MOD_RANGE,
KEY_PARAMETERS,
KEY_PUBLIC_KEY_1,
KEY_PUBLIC_KEY_2,
KEY_SAMPLE_NUMBER,
@@ -52,11 +51,16 @@
STAGE_SETUP,
STAGE_SHARE_KEYS,
STAGE_UNMASK,
RECORD_KEY_CONFIGS,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.common.serde import named_values_from_proto, named_values_to_proto
from flwr.common.typing import Value
from flwr.proto.task_pb2 import SecureAggregation, Task
from flwr.common.typing import ConfigsRecordValues, FitIns, ServerMessage
from flwr.proto.task_pb2 import Task
from flwr.common import serde
from flwr.common.constant import TASK_TYPE_FIT
from flwr.common.recordset import RecordSet
from flwr.common import recordset_compat as compat
from flwr.common.configsrecord import ConfigsRecord


LOG_EXPLAIN = True
@@ -68,12 +72,23 @@ def get_workflow_factory() -> (
return _wrap_workflow_with_sec_agg


def _wrap_in_task(named_values: Dict[str, Value]) -> Task:
return Task(sa=SecureAggregation(named_values=named_values_to_proto(named_values)))
def _wrap_in_task(
named_values: Dict[str, ConfigsRecordValues], fit_ins: Optional[FitIns] = None
) -> Task:
if fit_ins is not None:
recordset = compat.fitins_to_recordset(fit_ins, keep_input=True)
else:
recordset = RecordSet()
recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values))
return Task(
task_type=TASK_TYPE_FIT,
recordset=serde.recordset_to_proto(recordset),
)


def _get_from_task(task: Task) -> Dict[str, Value]:
return named_values_from_proto(task.sa.named_values)
def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]:
recordset = serde.recordset_from_proto(task.recordset)
return recordset.get_configs(RECORD_KEY_CONFIGS).data


_secure_aggregation_configuration = {
@@ -233,22 +248,24 @@ def workflow_with_sec_agg(
if LOG_EXPLAIN:
print(f"\nForwarding encrypted key shares and requesting masked input...")
# Send encrypted secret key shares to clients (plus model parameters)
weights = parameters_to_ndarrays(parameters)
yield {
node_id: _wrap_in_task(
named_values={
KEY_STAGE: STAGE_COLLECT_MASKED_INPUT,
KEY_CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]],
KEY_SOURCE_LIST: fwd_srcs[nid2sid[node_id]],
KEY_PARAMETERS: [ndarray_to_bytes(arr) for arr in weights],
}
},
fit_ins=FitIns(
parameters=parameters, config={"drop": nid2sid[node_id] == 0}
),
)
for node_id in surviving_node_ids
}
# Collect masked input from clients
node_messages = yield
surviving_node_ids = [node_id for node_id in node_messages]
# Get shape of vector sent by first client
weights = parameters_to_ndarrays(parameters)
masked_vector = [np.array([0], dtype=int)] + get_zero_parameters(
[w.shape for w in weights]
)
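
The heart of the workflows.py change is that SecAgg+ messages no longer travel in a dedicated SecureAggregation proto field: `_wrap_in_task` now packs the named values into a ConfigsRecord inside a RecordSet (optionally alongside a FitIns) and `_get_from_task` reads them back. Below is a minimal round-trip sketch using only the helpers visible in this diff; the config keys, their values, and the placeholder value of RECORD_KEY_CONFIGS are illustrative assumptions, not taken from the example.

from flwr.common import serde
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import TASK_TYPE_FIT
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import Task

# RECORD_KEY_CONFIGS comes from the SecAgg+ constants import in the diff above;
# its exact import path is not shown there, so use a placeholder for this sketch.
RECORD_KEY_CONFIGS = "secaggplus"

named_values = {"stage": "setup", "sample_num": 5}  # illustrative payload

# Pack: configs dict -> RecordSet -> protobuf Task (as in _wrap_in_task)
recordset = RecordSet()
recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values))
task = Task(
    task_type=TASK_TYPE_FIT,
    recordset=serde.recordset_to_proto(recordset),
)

# Unpack: protobuf Task -> RecordSet -> configs dict (as in _get_from_task)
recovered = serde.recordset_from_proto(task.recordset)
configs = recovered.get_configs(RECORD_KEY_CONFIGS).data
assert configs["sample_num"] == 5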
2 changes: 2 additions & 0 deletions src/py/flwr/client/middleware/__init__.py
@@ -15,8 +15,10 @@
"""Middleware layers."""


from .secure_aggregation.secaggplus_middleware import secaggplus_middleware
from .utils import make_ffn

__all__ = [
"make_ffn",
"secaggplus_middleware",
]
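
With this addition, the middleware package exposes the SecAgg+ layer alongside make_ffn, so example code needs only a single import (a one-line usage sketch based on the __all__ list above):

from flwr.client.middleware import make_ffn, secaggplus_middleware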
@@ -1,4 +1,4 @@
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Secure Aggregation handlers."""


from .handler import SecureAggregationHandler
from .secaggplus_handler import SecAggPlusHandler
from .secaggplus_middleware import secaggplus_middleware

__all__ = [
"SecAggPlusHandler",
"SecureAggregationHandler",
"secaggplus_middleware",
]