From 6a3b5f265d9a9e2d28e854c2e07b9c3133cc20a2 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Mon, 27 Nov 2023 12:50:34 +0100 Subject: [PATCH] Generalize and parametrize distance functions to compute anomaly scores --- baselines/flanders/flanders/main.py | 13 +++++++------ baselines/flanders/flanders/strategy.py | 12 ++++++------ baselines/flanders/flanders/utils.py | 20 ++++++++++++++++++++ 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/baselines/flanders/flanders/main.py b/baselines/flanders/flanders/main.py index 1dde35fd6bda..07469b71b3b9 100644 --- a/baselines/flanders/flanders/main.py +++ b/baselines/flanders/flanders/main.py @@ -42,7 +42,7 @@ set_sklearn_model_params, get_sklearn_model_params ) -from .utils import save_results +from .utils import save_results, l2_norm from .server import EnhancedServer from torch.utils.data import DataLoader @@ -124,14 +124,15 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam evaluate_fn=evaluate_fn, on_fit_config_fn=fit_config, fraction_fit=1, - fraction_evaluate=0, # no federated evaluation + fraction_evaluate=0, min_fit_clients=2, min_evaluate_clients=0, warmup_rounds=2, - to_keep=2, # Used in Flanders, MultiKrum, TrimmedMean (in Bulyan it is forced to 1) - min_available_clients=2, # All clients should be available - window=2, # Used in Flanders - sampling=1, # Used in Flanders + to_keep=2, + min_available_clients=2, + window=2, + sampling=1, + distance_function=l2_norm ) diff --git a/baselines/flanders/flanders/strategy.py b/baselines/flanders/flanders/strategy.py index c145e0640338..94a07ba3a28e 100644 --- a/baselines/flanders/flanders/strategy.py +++ b/baselines/flanders/flanders/strategy.py @@ -63,7 +63,8 @@ def __init__( maxiter: int = 100, sampling: str = None, alpha: float = 1, - beta: float = 1 + beta: float = 1, + distance_function: Callable = None, ) -> None: """ Parameters @@ -99,6 +100,7 @@ def __init__( self.beta = beta self.params_indexes = None self.malicious_selected = False + self.distance_function = distance_function def aggregate_fit( @@ -109,7 +111,7 @@ def aggregate_fit( clients_state: List[bool], ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: """ - Apply MAR forecasting to exclude malicious clients from the average. + Apply MAR forecasting to exclude malicious clients from FedAvg. Parameters ---------- @@ -140,10 +142,8 @@ def aggregate_fit( print(f"aggregate_fit - Computing MAR on M {M.shape}") Mr = mar(M[:,:,:-1], pred_step, maxiter=self.maxiter, alpha=self.alpha, beta=self.beta) - # TODO: generalize this to user-selected distance functions print("aggregate_fit - Computing anomaly scores") - delta = np.subtract(M_hat, Mr[:,:,0]) - anomaly_scores = np.sum(delta**2,axis=-1)**(1./2) + anomaly_scores = self.distance_function(M_hat, Mr[:,:,0]) print(f"aggregate_fit - Anomaly scores: {anomaly_scores}") print("aggregate_fit - Selecting good clients") @@ -151,7 +151,7 @@ def aggregate_fit( malicious_clients_idx = sorted(np.argsort(anomaly_scores)[self.to_keep:]) results = np.array(results)[good_clients_idx].tolist() print(f"aggregate_fit - Good clients: {good_clients_idx}") - + print(f"aggregate_fit - clients_state: {clients_state}") for idx in good_clients_idx: if clients_state[str(idx)]: diff --git a/baselines/flanders/flanders/utils.py b/baselines/flanders/flanders/utils.py index 6c704a374d68..1eb1347b2f22 100644 --- a/baselines/flanders/flanders/utils.py +++ b/baselines/flanders/flanders/utils.py @@ -17,6 +17,26 @@ lock = Lock() # if the script is run on multiple processors we need a lock to save the results +def l2_norm(true_matrix, predicted_matrix): + """ + Compute the l2 norm between two matrices. + + Parameters + ---------- + true_matrix : ndarray + The true matrix. + predicted_matrix : ndarray + The predicted matrix by MAR. + + Returns + ------- + anomaly_scores : ndarray + 1-d array of anomaly scores. + """ + delta = np.subtract(true_matrix, predicted_matrix) + anomaly_scores = np.sum(delta**2,axis=-1)**(1./2) + return anomaly_scores + def save_params(parameters, cid, dir="clients_params", remove_last=False, rrl=False): """ Args: