Skip to content

Commit

Permalink
Integrate custom server and fix strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Nov 26, 2023
1 parent 2707ca2 commit 9f86262
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 43 deletions.
6 changes: 4 additions & 2 deletions baselines/flanders/flanders/conf/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ strategy:

server:
_target_: flanders.server.EnhancedServer
num_rounds: 2
pool_size: 1
num_rounds: 3
pool_size: 2
num_malicious: 0
warmup_rounds: 2
sampling: 500

client:
# client config
28 changes: 21 additions & 7 deletions baselines/flanders/flanders/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def main(cfg: DictConfig) -> None:
print(cfg.dataset.name)
print(cfg.server.pool_size)

# Delete old client_params and clients_predicted_params
# TODO: parametrize this
if os.path.exists("clients_params"):
shutil.rmtree("clients_params")
if os.path.exists("clients_predicted_params"):
shutil.rmtree("clients_predicted_params")

evaluate_fn = mnist_evaluate

# 2. Prepare your dataset
Expand Down Expand Up @@ -118,12 +125,12 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam
on_fit_config_fn=fit_config,
fraction_fit=1,
fraction_evaluate=0, # no federated evaluation
min_fit_clients=1,
min_fit_clients=2,
min_evaluate_clients=0,
warmup_rounds=1,
to_keep=1, # Used in Flanders, MultiKrum, TrimmedMean (in Bulyan it is forced to 1)
min_available_clients=1, # All clients should be available
window=1, # Used in Flanders
warmup_rounds=2,
to_keep=2, # Used in Flanders, MultiKrum, TrimmedMean (in Bulyan it is forced to 1)
min_available_clients=2, # All clients should be available
window=2, # Used in Flanders
sampling=1, # Used in Flanders
)

Expand All @@ -132,9 +139,16 @@ def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.nam
# history = fl.simulation.start_simulation(<arguments for simulation>)
history = fl.simulation.start_simulation(
client_fn=client_fn,
client_resources={"num_cpus": 1},
client_resources={"num_cpus": 10},
num_clients=cfg.server.pool_size,
server=EnhancedServer(num_malicious=cfg.server.num_malicious, attack_fn=None, client_manager=SimpleClientManager),
server=EnhancedServer(
warmup_rounds=cfg.server.warmup_rounds,
num_malicious=cfg.server.num_malicious,
attack_fn=None,
client_manager=SimpleClientManager(),
strategy=strategy,
sampling=cfg.server.sampling,
),
config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds),
strategy=strategy
)
Expand Down
54 changes: 35 additions & 19 deletions baselines/flanders/flanders/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Optionally, also define a new Server class (please note this is not needed in most
settings).
"""
from flwr.server import Server
from flwr.server.server import fit_clients, Server

import concurrent.futures
import timeit
Expand Down Expand Up @@ -58,13 +58,22 @@ class EnhancedServer(Server):
def __init__(
self,
num_malicious: int,
warmup_rounds: int,
attack_fn:Optional[Callable],
sampling: int = 0,
*args: Any,
**kwargs: Any
) -> None:
"""Initialize."""

# TODO: Move all the parameters saving logic from the strategy to the server
super().__init__(*args, **kwargs)
self.num_malicious = num_malicious
self.warmup_rounds = warmup_rounds
self.attack_fn = attack_fn
self.sampling = sampling
self.aggregated_parameters = []
self.params_indexes = []


def fit_round(
Expand All @@ -83,22 +92,24 @@ def fit_round(
)

# Randomly decide which client is malicious
if server_round > self.warmup_rounds:
self.malicious_selected = np.random.choice(
[proxy.cid for proxy, ins in client_instructions], size=self.num_malicious, replace=False
)
log(
DEBUG,
"fit_round %s: malicious clients selected %s",
server_round,
self.malicious_selected,
)
# Save instruction for malicious clients into FitIns
for proxy, ins in client_instructions:
if proxy.cid in self.malicious_selected:
ins["malicious"] = True
else:
ins["malicious"] = False
size = self.num_malicious
if self.warmup_rounds > server_round:
size = 0
self.malicious_selected = np.random.choice(
[proxy.cid for proxy, _ in client_instructions], size=size, replace=False
)
log(
DEBUG,
"fit_round %s: malicious clients selected %s",
server_round,
self.malicious_selected,
)
# Save instruction for malicious clients into FitIns
for proxy, ins in client_instructions:
if proxy.cid in self.malicious_selected:
ins.config["malicious"] = True
else:
ins.config["malicious"] = False

if not client_instructions:
log(INFO, "fit_round %s: no clients selected, cancel", server_round)
Expand All @@ -112,10 +123,10 @@ def fit_round(
)

# Collect `fit` results from all clients participating in this round
results, failures = super.fit_clients(
results, failures = fit_clients(
client_instructions=client_instructions,
max_workers=self.max_workers,
timeout=timeout,
timeout=timeout
)
log(
DEBUG,
Expand All @@ -141,9 +152,11 @@ def fit_round(

params = params[self.params_indexes]

print(f"fit_round 1 - Saving parameters of client {fitres.metrics['cid']} with shape {params.shape}")
save_params(params, fitres.metrics['cid'])

# Re-arrange results in the same order as clients' cids impose
print("fit_round - Re-arranging results in the same order as clients' cids impose")
ordered_results[int(fitres.metrics['cid'])] = (proxy, fitres)

# Initialize aggregated_parameters if it is the first round
Expand All @@ -155,6 +168,7 @@ def fit_round(

# Apply attack function
if self.attack_fn is not None and server_round > self.warmup_rounds:
print("fit_round - Applying attack function")
results, others = self.attack_fn(
ordered_results, clients_state, magnitude=self.magnitude,
w_re=self.aggregated_parameters, malicious_selected=self.malicious_selected,
Expand All @@ -171,6 +185,7 @@ def fit_round(
params = flatten_params(parameters_to_ndarrays(fitres.parameters))[self.params_indexes]
else:
params = flatten_params(parameters_to_ndarrays(fitres.parameters))
print(f"fit_round 2 - Saving parameters of client {fitres.metrics['cid']} with shape {params.shape}")
save_params(params, fitres.metrics['cid'], remove_last=True)
else:
results = ordered_results
Expand All @@ -180,6 +195,7 @@ def fit_round(
clients_state = {k: clients_state[k] for k in sorted(clients_state)}

# Aggregate training results
print("fit_round - Aggregating training results")
aggregated_result: Tuple[
Optional[Parameters],
Dict[str, Scalar],
Expand Down
32 changes: 26 additions & 6 deletions baselines/flanders/flanders/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ def __init__(
fit_metrics_aggregation_fn = fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
)
self.dataset_name = dataset_name
self.attack_name = attack_name
self.iid = iid
self.warmup_rounds = warmup_rounds
self.to_keep = to_keep
self.attack_fn = attack_fn
self.window = window
self.maxiter = maxiter
self.sampling = sampling
self.alpha = alpha
self.beta = beta
self.params_indexes = None
self.malicious_selected = False


def aggregate_fit(
Expand Down Expand Up @@ -118,31 +131,37 @@ def aggregate_fit(
win = self.window
if server_round < self.window:
win = server_round
M = load_all_time_series(dir="strategy/clients_params", window=win)
M = load_all_time_series(dir="clients_params", window=win)
M = np.transpose(M, (0, 2, 1)) # (clients, params, time)


M_hat = M[:,:,-1].copy()
pred_step = 1

print(f"aggregate_fit - Computing MAR on M {M.shape}")
Mr = mar(M[:,:,:-1], pred_step, maxiter=self.maxiter, alpha=self.alpha, beta=self.beta)

# TODO: generalize this to user-selected distance functions
print("aggregate_fit - Computing anomaly scores")
delta = np.subtract(M_hat, Mr[:,:,0])
anomaly_scores = np.sum(delta**2,axis=-1)**(1./2)
print(f"aggregate_fit - Anomaly scores: {anomaly_scores}")

print("aggregate_fit - Selecting good clients")
good_clients_idx = sorted(np.argsort(anomaly_scores)[:self.to_keep])
malicious_clients_idx = sorted(np.argsort(anomaly_scores)[self.to_keep:])
results = np.array(results)[good_clients_idx].tolist()

print(f"aggregate_fit - Good clients: {good_clients_idx}")

print(f"aggregate_fit - clients_state: {clients_state}")
for idx in good_clients_idx:
if clients_state[idx]:
if clients_state[str(idx)]:
self.malicious_selected = True
break
else:
self.malicious_selected = False

# Apply FedAvg for the remaining clients
print("aggregate_fit - Applying FedAvg for the remaining clients")
parameters_aggregated, metrics_aggregated = super().aggregate_fit(server_round, results, failures)

# For clients detected as malicious, set their parameters to be the averaged ones in their files
Expand All @@ -153,7 +172,8 @@ def aggregate_fit(
new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))[self.params_indexes]
else:
new_params = flatten_params(parameters_to_ndarrays(parameters_aggregated))
save_params(new_params, idx, remove_last=True, rrl=True)
print(f"aggregate_fit - Saving parameters of client {idx} with shape {new_params.shape}")
save_params(new_params, idx, dir="clients_params", remove_last=True, rrl=True)
else:
# Apply FedAvg on the first round
parameters_aggregated, metrics_aggregated = super().aggregate_fit(server_round, results, failures)
Expand All @@ -174,7 +194,7 @@ def mar(X, pred_step, alpha=1, beta=1, maxiter=100, window=0):
B = np.random.randn(n, n)
X_norm = (X-np.min(X))/np.max(X)

for it in range(maxiter):
for _ in range(maxiter):
temp0 = B.T @ B
temp1 = np.zeros((m, m))
temp2 = np.zeros((m, m))
Expand Down
20 changes: 11 additions & 9 deletions baselines/flanders/flanders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

lock = Lock() # if the script is run on multiple processors we need a lock to save the results

def save_params(parameters, cid, remove_last=False, rrl=False):
def save_params(parameters, cid, dir="clients_params", remove_last=False, rrl=False):
"""
Args:
- parameters (ndarray): decoded parameters to append at the end of the file
Expand All @@ -26,21 +26,22 @@ def save_params(parameters, cid, remove_last=False, rrl=False):
- rrl (bool): if True, remove the last saved parameters and replace with the ones saved before this round
"""
new_params = parameters
# Save parameters in client_params/cid_params
path = f"strategy/clients_params/{cid}_params.npy"
if os.path.exists("strategy/clients_params") == False:
os.mkdir("strategy/clients_params")
if os.path.exists(path):
# Save parameters in clients_params/cid_params
path_file = f"{dir}/{cid}_params.npy"
if os.path.exists(dir) == False:
os.mkdir(dir)
if os.path.exists(path_file):
# load old parameters
old_params = np.load(path, allow_pickle=True)
old_params = np.load(path_file, allow_pickle=True)
if remove_last:
old_params = old_params[:-1]
if rrl:
new_params = old_params[-1]
# add new parameters
new_params = np.vstack((old_params, new_params))
print(f"new_params shape of {cid}: {new_params.shape}")
# save parameters
np.save(path, new_params)
np.save(path_file, new_params)


def save_predicted_params(parameters, cid):
Expand Down Expand Up @@ -80,7 +81,7 @@ def save_results(loss, accuracy, config=None):
df.to_csv(csv_path, index=False, header=True)


def load_all_time_series(dir="", window=0):
def load_all_time_series(dir="clients_params", window=0):
"""
Load all time series in order to have a tensor of shape (m,T,n)
where:
Expand All @@ -93,6 +94,7 @@ def load_all_time_series(dir="", window=0):
data = []
for file in files:
data.append(np.load(os.path.join(dir, file), allow_pickle=True))

return np.array(data)[:,-window:,:]


Expand Down

0 comments on commit 9f86262

Please sign in to comment.