diff --git a/baselines/flanders/flanders/attacks.py b/baselines/flanders/flanders/attacks.py
index 19492a8c3828..7cd2df025002 100644
--- a/baselines/flanders/flanders/attacks.py
+++ b/baselines/flanders/flanders/attacks.py
@@ -135,7 +135,6 @@ def fang_attack(
     if old_lambda > threshold and malicious_selected == False:
         l = old_lambda * 0.5
 
-    # Compute sign vector s
 
     magnitude = []
     for i in range(len(w_re)):
diff --git a/baselines/flanders/flanders/conf/base.yaml b/baselines/flanders/flanders/conf/base.yaml
index 0408c52df247..c8a105100204 100644
--- a/baselines/flanders/flanders/conf/base.yaml
+++ b/baselines/flanders/flanders/conf/base.yaml
@@ -31,6 +31,7 @@ server:
   num_malicious: 0
   warmup_rounds: 2
   sampling: 500
+  history_dir: clients_params
 
 client:
   # client config
\ No newline at end of file
diff --git a/baselines/flanders/flanders/main.py b/baselines/flanders/flanders/main.py
index 07469b71b3b9..0c20867d8c33 100644
--- a/baselines/flanders/flanders/main.py
+++ b/baselines/flanders/flanders/main.py
@@ -22,15 +22,7 @@ import pandas as pd
 
 from typing import Dict, Callable, Optional, Tuple, List
 
-from .models import (
-    MnistNet,
-    ToyNN,
-    roc_auc_multiclass,
-    test_toy,
-    train_mnist,
-    test_mnist,
-    train_toy
-)
+
 from .client import (
     CifarClient,
     HouseClient,
@@ -42,7 +34,7 @@
     set_sklearn_model_params,
     get_sklearn_model_params
 )
-from .utils import save_results, l2_norm
+from .utils import l2_norm, mnist_evaluate
 from .server import EnhancedServer
 
 from torch.utils.data import DataLoader
@@ -86,11 +78,8 @@ def main(cfg: DictConfig) -> None:
     print(cfg.server.pool_size)
 
     # Delete old client_params and clients_predicted_params
-    # TODO: parametrize this
-    if os.path.exists("clients_params"):
-        shutil.rmtree("clients_params")
-    if os.path.exists("clients_predicted_params"):
-        shutil.rmtree("clients_predicted_params")
+    if os.path.exists(cfg.server.history_dir):
+        shutil.rmtree(cfg.server.history_dir)
 
     evaluate_fn = mnist_evaluate
 
@@ -135,7 +124,6 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam
         distance_function=l2_norm
     )
 
-    # 5. Start Simulation
 
    # history = fl.simulation.start_simulation()
    history = fl.simulation.start_simulation(
@@ -149,6 +137,7 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam
             client_manager=SimpleClientManager(),
             strategy=strategy,
             sampling=cfg.server.sampling,
+            history_dir=cfg.server.history_dir
         ),
         config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds),
         strategy=strategy
@@ -172,27 +161,5 @@ def fit_config(server_round: int) -> Dict[str, Scalar]:
     }
     return config
 
-def mnist_evaluate(
-    server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
-):
-    # determine device
-    device = torch.device("cpu")
-
-    model = MnistNet()
-    set_params(model, parameters)
-    model.to(device)
-
-    testset = MNIST("", train=False, download=True, transform=transforms.ToTensor())
-    testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=1)
-    loss, accuracy, auc = test_mnist(model, testloader, device=device)
-
-    #config["id"] = args.exp_num
-    config["round"] = server_round
-    config["auc"] = auc
-    save_results(loss, accuracy, config=config)
-    print(f"Round {server_round} accuracy: {accuracy} loss: {loss} auc: {auc}")
-
-    return loss, {"accuracy": accuracy, "auc": auc}
-
 if __name__ == "__main__":
     main()
\ No newline at end of file
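The main.py hunks above replace the two hard-coded cleanup paths with the single `history_dir` entry added to `conf/base.yaml`. A minimal sketch of that flow, assuming only `omegaconf` (which ships with Hydra); the inline config is a stand-in for the real `base.yaml`:

```python
import os
import shutil

from omegaconf import OmegaConf

# Stand-in for the `server:` section of conf/base.yaml shown above.
cfg = OmegaConf.create({"server": {"history_dir": "clients_params"}})

# Same pattern as the patched main.py: wipe any stale per-client parameter
# history before the simulation starts writing a fresh time series.
if os.path.exists(cfg.server.history_dir):
    shutil.rmtree(cfg.server.history_dir)
```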
diff --git a/baselines/flanders/flanders/server.py b/baselines/flanders/flanders/server.py
index 96b19016b63c..791306504273 100644
--- a/baselines/flanders/flanders/server.py
+++ b/baselines/flanders/flanders/server.py
@@ -61,6 +61,7 @@ def __init__(
         warmup_rounds: int,
         attack_fn:Optional[Callable],
         sampling: int = 0,
+        history_dir: str = "clients_params",
         *args: Any,
         **kwargs: Any
     ) -> None:
@@ -74,6 +75,7 @@
         self.sampling = sampling
         self.aggregated_parameters = []
         self.params_indexes = []
+        self.history_dir = history_dir
 
 
     def fit_round(
@@ -153,7 +155,7 @@ def fit_round(
                     params = params[self.params_indexes]
 
                 print(f"fit_round 1 - Saving parameters of client {fitres.metrics['cid']} with shape {params.shape}")
-                save_params(params, fitres.metrics['cid'])
+                save_params(params, fitres.metrics['cid'], dir=self.history_dir)
 
         # Re-arrange results in the same order as clients' cids impose
         print("fit_round - Re-arranging results in the same order as clients' cids impose")
@@ -186,7 +188,7 @@ def fit_round(
             else:
                 params = flatten_params(parameters_to_ndarrays(fitres.parameters))
                 print(f"fit_round 2 - Saving parameters of client {fitres.metrics['cid']} with shape {params.shape}")
-                save_params(params, fitres.metrics['cid'], remove_last=True)
+                save_params(params, fitres.metrics['cid'], dir=self.history_dir, remove_last=True)
         else:
             results = ordered_results
             others = {}
@@ -196,10 +198,19 @@ def fit_round(
 
         # Aggregate training results
         print("fit_round - Aggregating training results")
-        aggregated_result: Tuple[
-            Optional[Parameters],
-            Dict[str, Scalar],
-        ] = self.strategy.aggregate_fit(server_round, results, failures, clients_state)
+        aggregated_result = self.strategy.aggregate_fit(server_round, results, failures, clients_state)
+
+        parameters_aggregated, metrics_aggregated, malicious_clients_idx = aggregated_result
+        print(f"fit_round - Malicious clients: {malicious_clients_idx}")
+        # For clients detected as malicious, set their parameters to be the averaged ones in their files,
+        # otherwise the forecasting in the next round won't be reliable
+        if self.warmup_rounds > server_round:
+            print("fit_round - Saving parameters of malicious clients")
+            for idx in malicious_clients_idx:
+                if self.sampling > 0:
+                    new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))[self.params_indexes]
+                else:
+                    new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))
+                save_params(new_params, idx, dir=self.history_dir, remove_last=True, rrl=True)
 
-        parameters_aggregated, metrics_aggregated = aggregated_result
         return parameters_aggregated, metrics_aggregated, (results, failures)
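With this change, `fit_round` expects `strategy.aggregate_fit` to return a three-tuple ending in `malicious_clients_idx`, and it overwrites each flagged client's last saved history entry with the flattened (and, when `sampling > 0`, subsampled) aggregate. The flatten-then-sample step can be illustrated with stand-in arrays; `flatten_params` below is a plausible stand-in for the baseline's helper of the same name, not its verified implementation:

```python
import numpy as np

def flatten_params(ndarrays):
    # Plausible stand-in: concatenate every layer into one 1-D vector.
    return np.concatenate([arr.ravel() for arr in ndarrays])

layers = [np.zeros((3, 2)), np.ones(4)]  # toy model weights, 10 params total
params_indexes = np.array([0, 3, 7, 9])  # the `sampling > 0` case

new_params = flatten_params(layers)[params_indexes]  # what save_params receives
assert new_params.shape == (4,)
```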
diff --git a/baselines/flanders/flanders/strategy.py b/baselines/flanders/flanders/strategy.py
index 94a07ba3a28e..d5d0b9d82049 100644
--- a/baselines/flanders/flanders/strategy.py
+++ b/baselines/flanders/flanders/strategy.py
@@ -128,7 +128,10 @@ def aggregate_fit(
             Aggregated parameters.
         metrics_aggregated: Dict[str, Scalar]
             Aggregated metrics.
+        malicious_clients_idx: List[int]
+            List of malicious clients' cids (indexes).
         """
+        malicious_clients_idx = []
         if server_round > 1:
             win = self.window
             if server_round < self.window:
@@ -136,7 +139,6 @@ def aggregate_fit(
                 win = server_round
             M = load_all_time_series(dir="clients_params", window=win)
             M = np.transpose(M, (0, 2, 1)) # (clients, params, time)
-            M_hat = M[:,:,-1].copy()
             pred_step = 1
 
             print(f"aggregate_fit - Computing MAR on M {M.shape}")
@@ -163,22 +165,11 @@ def aggregate_fit(
             # Apply FedAvg for the remaining clients
             print("aggregate_fit - Applying FedAvg for the remaining clients")
             parameters_aggregated, metrics_aggregated = super().aggregate_fit(server_round, results, failures)
-
-            # For clients detected as malicious, set their parameters to be the averaged ones in their files
-            # otherwise the forecasting in next round won't be reliable
-            if self.warmup_rounds > server_round:
-                for idx in malicious_clients_idx:
-                    if self.sampling > 0:
-                        new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))[self.params_indexes]
-                    else:
-                        new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))
-                    print(f"aggregate_fit - Saving parameters of client {idx} with shape {new_params.shape}")
-                    save_params(new_params, idx, dir="clients_params", remove_last=True, rrl=True)
         else:
-            # Apply FedAvg on the first round
+            # Apply FedAvg on every client's params during the first round
             parameters_aggregated, metrics_aggregated = super().aggregate_fit(server_round, results, failures)
-        return parameters_aggregated, metrics_aggregated
+        return parameters_aggregated, metrics_aggregated, malicious_clients_idx
 
 
 def mar(X, pred_step, alpha=1, beta=1, maxiter=100, window=0):
     '''
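For reference, the `(0, 2, 1)` transpose above puts time on the last axis so MAR can forecast the next parameter snapshot per client. A toy shape check; the stacking order `(clients, time, params)` returned by `load_all_time_series` is inferred from the inline comment, not verified against its implementation:

```python
import numpy as np

clients, window, n_params = 4, 3, 10
M = np.random.rand(clients, window, n_params)  # assumed (clients, time, params)
M = np.transpose(M, (0, 2, 1))                 # (clients, params, time)

assert M.shape == (clients, n_params, window)
```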
diff --git a/baselines/flanders/flanders/utils.py b/baselines/flanders/flanders/utils.py
index 1eb1347b2f22..eeafc6de96d6 100644
--- a/baselines/flanders/flanders/utils.py
+++ b/baselines/flanders/flanders/utils.py
@@ -5,6 +5,7 @@
 import numpy as np
 import os
 import json
+import torch
 import pandas as pd
 from natsort import natsorted
 from typing import Dict, Optional, Tuple, List
@@ -12,8 +13,23 @@
     Parameters,
     Scalar,
     parameters_to_ndarrays,
+    NDArrays,
 )
 from threading import Lock
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.datasets import MNIST
+
+from .models import (
+    MnistNet,
+    ToyNN,
+    roc_auc_multiclass,
+    test_toy,
+    train_mnist,
+    test_mnist,
+    train_toy
+)
+from .client import set_params
 
 lock = Lock() # if the script is run on multiple processors we need a lock to save the results
 
@@ -153,4 +169,26 @@ def evaluate_aggregated(
     if eval_res is None:
         return None
     loss, metrics = eval_res
-    return loss, metrics
\ No newline at end of file
+    return loss, metrics
+
+def mnist_evaluate(
+    server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
+):
+    # determine device
+    device = torch.device("cpu")
+
+    model = MnistNet()
+    set_params(model, parameters)
+    model.to(device)
+
+    testset = MNIST("", train=False, download=True, transform=transforms.ToTensor())
+    testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=1)
+    loss, accuracy, auc = test_mnist(model, testloader, device=device)
+
+    #config["id"] = args.exp_num
+    config["round"] = server_round
+    config["auc"] = auc
+    save_results(loss, accuracy, config=config)
+    print(f"Round {server_round} accuracy: {accuracy} loss: {loss} auc: {auc}")
+
+    return loss, {"accuracy": accuracy, "auc": auc}
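The relocated `mnist_evaluate` keeps the signature Flower expects of a centralized evaluation hook, which is why main.py can assign `evaluate_fn = mnist_evaluate` directly. A usage sketch, assuming the baseline package is importable as `flanders` and that downloading MNIST is acceptable:

```python
from flanders.models import MnistNet
from flanders.utils import mnist_evaluate

# Exercise the hook with an untrained model's weights; a real run passes the
# aggregated global parameters each round.
model = MnistNet()
parameters = [t.cpu().numpy() for t in model.state_dict().values()]

loss, metrics = mnist_evaluate(server_round=0, parameters=parameters, config={})
print(loss, metrics["accuracy"], metrics["auc"])
```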