Move the history saving logic into the server
edogab33 committed Dec 6, 2023
1 parent 6a3b5f2 commit 3daa7d6
Showing 6 changed files with 68 additions and 61 deletions.
1 change: 0 additions & 1 deletion baselines/flanders/flanders/attacks.py
@@ -135,7 +135,6 @@ def fang_attack(
if old_lambda > threshold and malicious_selected == False:
l = old_lambda * 0.5


# Compute sign vector s
magnitude = []
for i in range(len(w_re)):
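For context, the sign vector `s` that the comment above refers to is, in Fang-style attacks, the estimated changing direction of each global parameter. A minimal sketch of one way to compute it, assuming `w_re` holds the benign clients' weight vectors and `w_global` the current global weights; this helper is illustrative, not the repository's implementation:

```python
import numpy as np

def estimate_sign_vector(w_re, w_global):
    """Illustrative sketch: s[j] = +1 if the mean benign update moves
    parameter j upward, -1 if downward (0 if unchanged)."""
    mean_update = np.mean(np.asarray(w_re), axis=0) - np.asarray(w_global)
    return np.sign(mean_update)
```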
1 change: 1 addition & 0 deletions baselines/flanders/flanders/conf/base.yaml
@@ -31,6 +31,7 @@ server:
num_malicious: 0
warmup_rounds: 2
sampling: 500
history_dir: clients_params

client:
# client config
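The new `history_dir` key is what `main.py` and `server.py` below read as `cfg.server.history_dir`. A minimal sketch of resolving it outside the Hydra entry point, assuming OmegaConf and the repository layout shown in the file header:

```python
from omegaconf import OmegaConf

# Load the baseline config directly and read the newly added key.
cfg = OmegaConf.load("baselines/flanders/flanders/conf/base.yaml")
print(cfg.server.history_dir)  # -> clients_params
```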
43 changes: 5 additions & 38 deletions baselines/flanders/flanders/main.py
@@ -22,15 +22,7 @@
import pandas as pd

from typing import Dict, Callable, Optional, Tuple, List
from .models import (
MnistNet,
ToyNN,
roc_auc_multiclass,
test_toy,
train_mnist,
test_mnist,
train_toy
)

from .client import (
CifarClient,
HouseClient,
@@ -42,7 +34,7 @@
set_sklearn_model_params,
get_sklearn_model_params
)
from .utils import save_results, l2_norm
from .utils import l2_norm, mnist_evaluate
from .server import EnhancedServer

from torch.utils.data import DataLoader
@@ -86,11 +78,8 @@ def main(cfg: DictConfig) -> None:
print(cfg.server.pool_size)

# Delete old client_params and clients_predicted_params
# TODO: parametrize this
if os.path.exists("clients_params"):
shutil.rmtree("clients_params")
if os.path.exists("clients_predicted_params"):
shutil.rmtree("clients_predicted_params")
if os.path.exists(cfg.server.history_dir):
shutil.rmtree(cfg.server.history_dir)

evaluate_fn = mnist_evaluate

@@ -135,7 +124,6 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam
distance_function=l2_norm
)


# 5. Start Simulation
# history = fl.simulation.start_simulation(<arguments for simulation>)
history = fl.simulation.start_simulation(
@@ -149,6 +137,7 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam
client_manager=SimpleClientManager(),
strategy=strategy,
sampling=cfg.server.sampling,
history_dir=cfg.server.history_dir
),
config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds),
strategy=strategy
@@ -172,27 +161,5 @@ def fit_config(server_round: int) -> Dict[str, Scalar]:
}
return config

def mnist_evaluate(
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
):
# determine device
device = torch.device("cpu")

model = MnistNet()
set_params(model, parameters)
model.to(device)

testset = MNIST("", train=False, download=True, transform=transforms.ToTensor())
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=1)
loss, accuracy, auc = test_mnist(model, testloader, device=device)

#config["id"] = args.exp_num
config["round"] = server_round
config["auc"] = auc
save_results(loss, accuracy, config=config)
print(f"Round {server_round} accuracy: {accuracy} loss: {loss} auc: {auc}")

return loss, {"accuracy": accuracy, "auc": auc}

if __name__ == "__main__":
main()
25 changes: 18 additions & 7 deletions baselines/flanders/flanders/server.py
@@ -61,6 +61,7 @@ def __init__(
warmup_rounds: int,
attack_fn:Optional[Callable],
sampling: int = 0,
history_dir: str = "clients_params",
*args: Any,
**kwargs: Any
) -> None:
Expand All @@ -74,6 +75,7 @@ def __init__(
self.sampling = sampling
self.aggregated_parameters = []
self.params_indexes = []
self.history_dir = history_dir


def fit_round(
@@ -153,7 +155,7 @@ def fit_round(
params = params[self.params_indexes]

print(f"fit_round 1 - Saving parameters of client {fitres.metrics['cid']} with shape {params.shape}")
save_params(params, fitres.metrics['cid'])
save_params(params, fitres.metrics['cid'], dir=self.history_dir)

# Re-arrange results in the same order as clients' cids impose
print("fit_round - Re-arranging results in the same order as clients' cids impose")
@@ -186,7 +188,7 @@ def fit_round(
else:
params = flatten_params(parameters_to_ndarrays(fitres.parameters))
print(f"fit_round 2 - Saving parameters of client {fitres.metrics['cid']} with shape {params.shape}")
save_params(params, fitres.metrics['cid'], remove_last=True)
save_params(params, fitres.metrics['cid'], dir=self.history_dir, remove_last=True)
else:
results = ordered_results
others = {}
@@ -196,10 +198,19 @@ def fit_round(

# Aggregate training results
print("fit_round - Aggregating training results")
aggregated_result: Tuple[
Optional[Parameters],
Dict[str, Scalar],
] = self.strategy.aggregate_fit(server_round, results, failures, clients_state)
aggregated_result = self.strategy.aggregate_fit(server_round, results, failures, clients_state)

parameters_aggregated, metrics_aggregated, malicious_clients_idx = aggregated_result
print(f"fit_round - Malicious clients: {malicious_clients_idx}")
# For clients detected as malicious, set their parameters to be the averaged ones in their files
# otherwise the forecasting in next round won't be reliable
if self.warmup_rounds > server_round:
print(f"fit_round - Saving parameters of clients")
for idx in malicious_clients_idx:
if self.sampling > 0:
new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))[self.params_indexes]
else:
new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))
save_params(new_params, idx, dir=self.history_dir, remove_last=True, rrl=True)

parameters_aggregated, metrics_aggregated = aggregated_result
return parameters_aggregated, metrics_aggregated, (results, failures)
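`save_params` itself is not part of this diff, but the calls above fix its interface: a flattened parameter vector, the client's cid, the new `dir` keyword, and the `remove_last`/`rrl` flags used when overwriting a client's most recent snapshot. A minimal sketch consistent with those calls; the one-file-per-client layout is an assumption, and `rrl`'s exact semantics are not recoverable from this diff, so the sketch ignores it:

```python
import os
import numpy as np

def save_params(params, cid, dir="clients_params", remove_last=False, rrl=False):
    """Append a client's flattened parameters to its history file (sketch).

    Assumed layout: {dir}/{cid}_params.npy holds a (rounds, n_params)
    matrix. remove_last drops the newest row so the incoming row
    replaces it. rrl: semantics not recoverable from this diff; ignored.
    """
    os.makedirs(dir, exist_ok=True)
    path = os.path.join(dir, f"{cid}_params.npy")
    new_row = np.asarray(params).reshape(1, -1)
    if os.path.exists(path):
        history = np.load(path)
        if remove_last:
            history = history[:-1]  # overwrite the latest snapshot
        history = np.vstack([history, new_row])
    else:
        history = new_row
    np.save(path, history)
```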
19 changes: 5 additions & 14 deletions baselines/flanders/flanders/strategy.py
@@ -128,15 +128,17 @@ def aggregate_fit(
Aggregated parameters.
metrics_aggregated: Dict[str, Scalar]
Aggregated metrics.
malicious_clients_idx: List[int]
List of malicious clients' cids (indexes).
"""
malicious_clients_idx = []
if server_round > 1:
win = self.window
if server_round < self.window:
win = server_round
M = load_all_time_series(dir="clients_params", window=win)
M = np.transpose(M, (0, 2, 1)) # (clients, params, time)


M_hat = M[:,:,-1].copy()
pred_step = 1
print(f"aggregate_fit - Computing MAR on M {M.shape}")
@@ -163,22 +165,11 @@
# Apply FedAvg for the remaining clients
print("aggregate_fit - Applying FedAvg for the remaining clients")
parameters_aggregated, metrics_aggregated = super().aggregate_fit(server_round, results, failures)

# For clients detected as malicious, set their parameters to be the averaged ones in their files
# otherwise the forecasting in next round won't be reliable
if self.warmup_rounds > server_round:
for idx in malicious_clients_idx:
if self.sampling > 0:
new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))[self.params_indexes]
else:
new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))
print(f"aggregate_fit - Saving parameters of client {idx} with shape {new_params.shape}")
save_params(new_params, idx, dir="clients_params", remove_last=True, rrl=True)
else:
# Apply FedAvg on the first round
# Apply FedAvg on every clients' params during the first round
parameters_aggregated, metrics_aggregated = super().aggregate_fit(server_round, results, failures)

return parameters_aggregated, metrics_aggregated
return parameters_aggregated, metrics_aggregated, malicious_clients_idx
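`load_all_time_series` (imported from `utils.py`, unchanged in this commit) is the consumer of those saved histories; from the call above it must return an array that transposes to (clients, params, time), i.e. shape (clients, time, params). A minimal sketch under that assumption, with the per-client file layout matching the `save_params` sketch earlier:

```python
import os
import numpy as np
from natsort import natsorted

def load_all_time_series(dir="clients_params", window=0):
    """Stack every client's saved history into one 3-D array (sketch).

    Each file holds a (rounds, n_params) matrix; keep only the last
    `window` rounds and return (clients, time, params), cids in natural
    order. Assumes all clients have the same number of saved rounds.
    """
    series = []
    for fname in natsorted(os.listdir(dir)):
        history = np.load(os.path.join(dir, fname))
        series.append(history[-window:] if window > 0 else history)
    return np.stack(series)
```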

def mar(X, pred_step, alpha=1, beta=1, maxiter=100, window=0):
'''
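`mar` is the forecasting half of the scheme: the matrix of client parameters is modeled as a matrix autoregression, roughly M_t ≈ A M_{t-1} Bᵀ, and the one-step forecast M_hat is compared against what clients actually send. A minimal sketch of the prediction step alone, assuming coefficient matrices A and B have already been fitted (the `alpha`/`beta`/`maxiter` arguments above suggest an alternating least-squares estimation, omitted here):

```python
import numpy as np

def mar_predict(M, A, B, pred_step=1):
    """Forecast under M_t ≈ A @ M_{t-1} @ B.T (sketch).

    M has shape (clients, params, time) as in aggregate_fit above;
    A is (clients, clients) and B is (params, params).
    """
    M_t = M[:, :, -1]
    for _ in range(pred_step):
        M_t = A @ M_t @ B.T
    return M_t
```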
40 changes: 39 additions & 1 deletion baselines/flanders/flanders/utils.py
@@ -5,15 +5,31 @@
import numpy as np
import os
import json
import torch
import pandas as pd
from natsort import natsorted
from typing import Dict, Optional, Tuple, List
from flwr.common import (
Parameters,
Scalar,
parameters_to_ndarrays,
NDArrays,
)
from threading import Lock
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from .models import (
MnistNet,
ToyNN,
roc_auc_multiclass,
test_toy,
train_mnist,
test_mnist,
train_toy
)
from .client import set_params

lock = Lock() # if the script is run on multiple processors we need a lock to save the results

@@ -153,4 +169,26 @@ def evaluate_aggregated(
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics
return loss, metrics

def mnist_evaluate(
server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
):
# determine device
device = torch.device("cpu")

model = MnistNet()
set_params(model, parameters)
model.to(device)

testset = MNIST("", train=False, download=True, transform=transforms.ToTensor())
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=1)
loss, accuracy, auc = test_mnist(model, testloader, device=device)

#config["id"] = args.exp_num
config["round"] = server_round
config["auc"] = auc
save_results(loss, accuracy, config=config)
print(f"Round {server_round} accuracy: {accuracy} loss: {loss} auc: {auc}")

return loss, {"accuracy": accuracy, "auc": auc}

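With `mnist_evaluate` now living in `utils.py`, the server-side evaluation contract it satisfies is `(server_round, parameters, config) -> (loss, metrics)`. A quick smoke test of the relocated function with freshly initialised weights; the `flanders.*` module paths are assumptions based on the file tree:

```python
from flanders.models import MnistNet
from flanders.utils import mnist_evaluate

# Build the parameter list the same way Flower clients export weights.
model = MnistNet()
params = [t.cpu().numpy() for t in model.state_dict().values()]

loss, metrics = mnist_evaluate(server_round=0, parameters=params, config={})
print(loss, metrics["accuracy"], metrics["auc"])
```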