Skip to content

Commit

Permalink
sa mw
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 30, 2024
1 parent 43dfa50 commit 2a0b085
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 104 deletions.
2 changes: 2 additions & 0 deletions src/py/flwr/client/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"""Middleware layers."""


from .secure_aggregation.secaggplus_middleware import secaggplus_middleware
from .utils import make_ffn

__all__ = [
"make_ffn",
"secaggplus_middleware",
]
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,40 @@ class SecAggPlusState:
ss2_dict: Dict[int, bytes] = field(default_factory=dict)
public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict)

def __init__(self, **kwargs: ConfigsRecordValues) -> None:
for k, v in kwargs.items():
if k.endswith(":V"):
continue
new_v: Any = v
if k.endswith(":K"):
keys = cast(List[int], v)
values = cast(List[bytes], kwargs[k[:-2] + ":V"])
if len(values) > len(keys):
updated_values = [
tuple(values[i : i + 2]) for i in range(0, len(values), 2)
]
new_v = dict(zip(keys, updated_values))
else:
new_v = dict(zip(keys, values))
self.__setattr__(k, new_v)

def to_dict(self) -> Dict[str, ConfigsRecordValues]:
"""Convert the state to a dictionary."""
ret = vars(self)
for k in ret.keys():
if isinstance(ret[k], dict):
# Replace dict with two lists
v = cast(Dict[str, Any], ret.pop(k))
ret[f"{k}:K"] = list(v.keys())
if k == "public_keys_dict":
v_list: List[bytes] = []
for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()):
v_list.extend(b1_b2)
ret[f"{k}:V"] = v_list
else:
ret[f"{k}:V"] = list(v.values())
return ret


def _get_fit_fn(
msg: Message, call_next: FlowerCallable, ctxt: Context
Expand All @@ -128,9 +162,9 @@ def fit() -> FitRes:

def secaggplus_middleware(
msg: Message,
call_next: FlowerCallable,
ctxt: Context,
) -> Dict[str, ConfigsRecordValues]:
call_next: FlowerCallable,
) -> Message:
"""Handle incoming message and return results, following the SecAgg+ protocol."""
# Ignore non-fit messages
if msg.metadata.task_type != TASK_TYPE_FIT:
Expand All @@ -139,7 +173,8 @@ def secaggplus_middleware(
# Retrieve state
if RECORD_KEY_STATE not in ctxt.state.configs:
ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({}))
state = SecAggPlusState(**ctxt.state.get_configs(RECORD_KEY_STATE).data)
state_dict = ctxt.state.get_configs(RECORD_KEY_STATE).data
state = SecAggPlusState(**state_dict)

# Retrieve configs
configs = msg.message.get_configs(RECORD_KEY_CONFIGS).data
Expand All @@ -150,7 +185,7 @@ def secaggplus_middleware(
# Update the current stage
state.current_stage = cast(str, configs.pop(KEY_STAGE))

# Check the validity of the `named_values` based on the current stage
# Check the validity of the configs based on the current stage
check_configs(state.current_stage, configs)

# Execute
Expand All @@ -171,14 +206,8 @@ def secaggplus_middleware(

# Return message
return Message(
metadata=Metadata(
run_id="",
task_id="",
group_id="",
ttl="",
task_type=TASK_TYPE_FIT,
),
message=RecordSet(configs={RECORD_KEY_CONFIGS, ConfigsRecord(res, False)}),
metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
message=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}),
)


Expand Down Expand Up @@ -341,31 +370,31 @@ def _setup(

# pylint: disable-next=too-many-locals
def _share_keys(
state: SecAggPlusState, named_values: Dict[str, ConfigsRecordValues]
state: SecAggPlusState, configs: Dict[str, ConfigsRecordValues]
) -> Dict[str, ConfigsRecordValues]:
named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], named_values)
named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
log(INFO, "Client %d: starting stage 1...", state.sid)
state.public_keys_dict = key_dict

# Check if the size is larger than threshold
if len(state.public_keys_dict) < state.threshold:
raise Exception("Available neighbours number smaller than threshold")
raise ValueError("Available neighbours number smaller than threshold")

# Check if all public keys are unique
pk_list: List[bytes] = []
for pk1, pk2 in state.public_keys_dict.values():
pk_list.append(pk1)
pk_list.append(pk2)
if len(set(pk_list)) != len(pk_list):
raise Exception("Some public keys are identical")
raise ValueError("Some public keys are identical")

# Check if public keys of this client are correct in the dictionary
if (
state.public_keys_dict[state.sid][0] != state.pk1
or state.public_keys_dict[state.sid][1] != state.pk2
):
raise Exception(
raise ValueError(
"Own public keys are displayed in dict incorrectly, should not happen!"
)

Expand Down Expand Up @@ -412,7 +441,7 @@ def _collect_masked_input(
ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST])
srcs = cast(List[int], configs[KEY_SOURCE_LIST])
if len(ciphertexts) + 1 < state.threshold:
raise Exception("Not enough available neighbour clients.")
raise ValueError("Not enough available neighbour clients.")

# Decrypt ciphertexts, verify their sources, and store shares.
for src, ciphertext in zip(srcs, ciphertexts):
Expand All @@ -428,7 +457,7 @@ def _collect_masked_input(
f"from {actual_src} instead of {src}."
)
if dst != state.sid:
ValueError(
raise ValueError(
f"Client {state.sid}: received an encrypted message"
f"for Client {dst} from Client {src}."
)
Expand Down Expand Up @@ -488,7 +517,7 @@ def _unmask(
# Send private mask seed share for every avaliable client (including itclient)
# Send first private key share for building pairwise mask for every dropped client
if len(active_sids) < state.threshold:
raise Exception("Available neighbours number smaller than threshold")
raise ValueError("Available neighbours number smaller than threshold")

sids, shares = [], []
sids += active_sids
Expand Down
Loading

0 comments on commit 2a0b085

Please sign in to comment.