-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(secure aggregation): aggregation function WIP
- Loading branch information
1 parent
43a2c93
commit 6f81222
Showing
1 changed file
with
124 additions
and
0 deletions.
There are no files selected for viewing
124 changes: 124 additions & 0 deletions
124
openfl/interface/aggregation_functions/secure_aggregation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright 2020-2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
"""Secure aggregation module.""" | ||
|
||
from typing import Iterator, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from openfl.interface.aggregation_functions.core import AggregationFunction | ||
from openfl.utilities import LocalTensor | ||
|
||
|
||
class SecureAggregation(AggregationFunction): | ||
""" | ||
SecureAggregation class for performing secure aggregation of local tensors. | ||
""" | ||
def call( | ||
self, | ||
local_tensors: list[LocalTensor], | ||
db_iterator: Iterator[pd.Series], | ||
tensor_name: str, | ||
fl_round: int, | ||
tags: Tuple[str], | ||
) -> np.ndarray: | ||
""" | ||
Perform secure aggregation by calling the appropriate aggregation | ||
methods based on the tags. | ||
Args: | ||
local_tensors (list[LocalTensor]): List of local tensors to be | ||
aggregated. | ||
db_iterator (Iterator[pd.Series]): Iterator over the database | ||
series. | ||
tensor_name (str): Name of the tensor. | ||
fl_round (int): Federated learning round number. | ||
tags (Tuple[str]): Tags indicating the type of aggregation to | ||
perform. | ||
Returns: | ||
np.ndarray: Aggregated tensor. | ||
""" | ||
self._aggregate_public_keys(local_tensors, tags) | ||
self._aggregate_ciphertexts(local_tensors, tags) | ||
self._aggregate_deciphertexts(local_tensors, tags) | ||
|
||
def _aggregate_public_keys( | ||
self, | ||
local_tensors: list[LocalTensor], | ||
tags: Tuple[str], | ||
): | ||
""" | ||
Aggregate public keys from the local tensors. | ||
Args: | ||
local_tensors (list[LocalTensor]): List of local tensors | ||
containing public keys. | ||
tags (Tuple[str]): Tags indicating the type of aggregation to | ||
perform. | ||
Returns: | ||
np.ndarray: Aggregated public keys tensor. | ||
""" | ||
aggregated_tensor = [] | ||
if "public_key" in tags: | ||
# Setting indices for the collaborators. | ||
index = 1 | ||
for tensor in local_tensors: | ||
# tensor[0] is public_key_1 | ||
# tensor[1] is public_key_2 | ||
aggregated_tensor.append( | ||
[index, tensor.tensor[0], tensor.tensor[1]] | ||
) | ||
index += 1 | ||
|
||
return np.array(aggregated_tensor) | ||
|
||
def _aggregate_ciphertexts( | ||
self, | ||
local_tensors: list[LocalTensor], | ||
tags: Tuple[str], | ||
): | ||
""" | ||
Aggregate ciphertexts from the local tensors. | ||
Args: | ||
local_tensors (list[LocalTensor]): List of local tensors | ||
containing ciphertexts. | ||
tags (Tuple[str]): Tags indicating the type of aggregation to | ||
perform. | ||
Returns: | ||
np.ndarray: Aggregated ciphertexts tensor. | ||
""" | ||
aggregated_tensor = [] | ||
if "ciphertext" in tags: | ||
aggregated_tensor = [tensor.tensor for tensor in local_tensors] | ||
|
||
return np.array(aggregated_tensor) | ||
|
||
def _aggregate_deciphertexts( | ||
self, | ||
local_tensors: list[LocalTensor], | ||
tags: Tuple[str], | ||
) -> np.ndarray: | ||
""" | ||
Aggregate deciphertexts from the local tensors. | ||
Args: | ||
local_tensors (list[LocalTensor]): List of local tensors | ||
containing deciphertexts. | ||
tags (Tuple[str]): Tags indicating the type of aggregation to | ||
perform. | ||
Returns: | ||
np.ndarray: Aggregated deciphertexts tensor. | ||
""" | ||
aggregated_tensor = [] | ||
if "deciphertext" in tags: | ||
aggregated_tensor = [tensor.tensor for tensor in local_tensors] | ||
|
||
return np.array(aggregated_tensor) |