Skip to content

Commit

Permalink
Add malicious clients sampling and various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Nov 22, 2023
1 parent d7b400e commit 2707ca2
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 30 deletions.
6 changes: 3 additions & 3 deletions baselines/flanders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ California Housing (linear regression):

| Dataset | # of clients | Clients per round | # of rounds | Batch size | Learning rate | $\lambda_1$ | $\lambda_2$ | Optimizer | Dropout | Alpha | Beta | # of clients to keep | Sampling |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| Income | 100 | 100 | 100 | \ | \ | 1.0 | 0.0 | CCD | \ | 0.0 | 0.0 | 1 | \ |
| MNIST | 100 | 100 | 100 | 32 | $10^{-3}$ | \ | \ | Adam | 0.2 | 0.0 | 0.0 | 1 | \ |
| California Housing | 100 | 100 | 100 | \ | \ | 0.5 | 0.5 | CCD | \ | 0.0 | 0.0 | 1 | \ |
| Income | 100 | 100 | 50 | \ | \ | 1.0 | 0.0 | CCD | \ | 0.0 | 0.0 | 1 | \ |
| MNIST | 100 | 100 | 50 | 32 | $10^{-3}$ | \ | \ | Adam | 0.2 | 0.0 | 0.0 | 1 | \ |
| California Housing | 100 | 100 | 50 | \ | \ | 0.5 | 0.5 | CCD | \ | 0.0 | 0.0 | 1 | \ |

Where $\lambda_1$ and $\lambda_2$ are Lasso and Ridge regularization terms.

Expand Down
2 changes: 1 addition & 1 deletion baselines/flanders/flanders/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Template baseline package."""
"""FLANDERS package."""
22 changes: 15 additions & 7 deletions baselines/flanders/flanders/conf/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# this is the config that will be loaded as default by main.py
# Please follow the provided structure (this will ensuring all baseline follow
# a similar configuration structure and hence be easy to customise)
defaults:
- _self_
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled
#defaults:
# - _self_
# - override hydra/hydra_logging: disabled
# - override hydra/job_logging: disabled

hydra:
output_subdir: null
Expand All @@ -14,13 +14,21 @@ hydra:

dataset:
# dataset config
name: mnist

model:
# model config
_target_: flanders.models.MnistNet

strategy:
_target_: # points to your strategy (either custom or exiting in Flower)
# rest of strategy config
_target_: flanders.strategy.Flanders
attack_fn: null

server:
_target_: flanders.server.EnhancedServer
num_rounds: 2
pool_size: 1
num_malicious: 0

client:
# client config
# client config
2 changes: 1 addition & 1 deletion baselines/flanders/flanders/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __len__(self) -> int:
return len(self.data)


def get_cifar_10(path_to_data="datasets/cifar_nn/data"):
def get_cifar_10(path_to_data="datasets/cifar_10/data"):
"""Downloads CIFAR10 dataset and generates a unified training set (it will
be partitioned later using the LDA partitioning mechanism."""

Expand Down
54 changes: 43 additions & 11 deletions baselines/flanders/flanders/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import flwr as fl
from flwr.common.typing import Scalar
from flwr.server.client_manager import SimpleClientManager
import torch
import os
import random
Expand Down Expand Up @@ -42,6 +43,7 @@
get_sklearn_model_params
)
from .utils import save_results
from .server import EnhancedServer

from torch.utils.data import DataLoader
from torchvision import transforms
Expand Down Expand Up @@ -80,7 +82,8 @@ def main(cfg: DictConfig) -> None:
# 1. Print parsed config
print(OmegaConf.to_yaml(cfg))

print(cfg.dataset)
print(cfg.dataset.name)
print(cfg.server.pool_size)

evaluate_fn = mnist_evaluate

Expand All @@ -92,27 +95,47 @@ def main(cfg: DictConfig) -> None:
# be a location in the file system, a list of dataloader, a list of ids to extract
# from a dataset, it's up to you)

# Managed by clients

# 3. Define your clients
# Define a function that returns another function that will be used during
# simulation to instantiate each individual client
# client_fn = client.<my_function_that_returns_a_function>()
client = MnistClient
def client_fn(cid: int, pool_size: int):
def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.name):
if dataset_name == "mnist":
client = MnistClient
else:
raise NotImplementedError(f"Dataset {dataset_name} not implemented")
return client(cid, pool_size)

# 4. Define your strategy
# pass all relevant argument (including the global dataset used after aggregation,
# if needed by your method.)
# strategy = instantiate(cfg.strategy, <additional arguments if desired>)
strategy = instantiate(cfg.strategy)
strategy = instantiate(
cfg.strategy,
evaluate_fn=evaluate_fn,
on_fit_config_fn=fit_config,
fraction_fit=1,
fraction_evaluate=0, # no federated evaluation
min_fit_clients=1,
min_evaluate_clients=0,
warmup_rounds=1,
to_keep=1, # Used in Flanders, MultiKrum, TrimmedMean (in Bulyan it is forced to 1)
min_available_clients=1, # All clients should be available
window=1, # Used in Flanders
sampling=1, # Used in Flanders
)


# 5. Start Simulation
# history = fl.simulation.start_simulation(<arguments for simulation>)
fl.simulation.start_simulation(
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=cfg.pool_size,
config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
client_resources={"num_cpus": 1},
num_clients=cfg.server.pool_size,
server=EnhancedServer(num_malicious=cfg.server.num_malicious, attack_fn=None, client_manager=SimpleClientManager),
config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds),
strategy=strategy
)

Expand All @@ -126,11 +149,19 @@ def client_fn(cid: int, pool_size: int):
# can retrieve the path to that directory with this:
# save_path = HydraConfig.get().runtime.output_dir

def fit_config(server_round: int) -> Dict[str, Scalar]:
"""Return a configuration with static batch size and (local) epochs."""
config = {
"epochs": 1, # number of local epochs
"batch_size": 32,
}
return config

def mnist_evaluate(
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
):
# determine device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

model = MnistNet()
set_params(model, parameters)
Expand All @@ -140,12 +171,13 @@ def mnist_evaluate(
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=1)
loss, accuracy, auc = test_mnist(model, testloader, device=device)

config["id"] = args.exp_num
#config["id"] = args.exp_num
config["round"] = server_round
config["auc"] = auc
save_results(loss, accuracy, config=config)
print(f"Round {server_round} accuracy: {accuracy} loss: {loss} auc: {auc}")

# return statistics
return loss, {"accuracy": accuracy, "auc": auc}

main()
if __name__ == "__main__":
main()
23 changes: 21 additions & 2 deletions baselines/flanders/flanders/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from flwr.server.history import History
from flwr.server.strategy import FedAvg, Strategy

from utils import (
from .utils import (
save_params,
save_predicted_params,
load_all_time_series,
Expand All @@ -57,6 +57,7 @@ class EnhancedServer(Server):

def __init__(
self,
num_malicious: int,
attack_fn:Optional[Callable],
*args: Any,
**kwargs: Any
Expand All @@ -73,14 +74,32 @@ def fit_round(
) -> Optional[
Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
]:
"""Perform a single round of federated averaging."""
"""Perform a single round of federated learning."""
# Get clients and their respective instructions from strategy
client_instructions = self.strategy.configure_fit(
server_round=server_round,
parameters=self.parameters,
client_manager=self._client_manager,
)

# Randomly decide which client is malicious
if server_round > self.warmup_rounds:
self.malicious_selected = np.random.choice(
[proxy.cid for proxy, ins in client_instructions], size=self.num_malicious, replace=False
)
log(
DEBUG,
"fit_round %s: malicious clients selected %s",
server_round,
self.malicious_selected,
)
# Save instruction for malicious clients into FitIns
for proxy, ins in client_instructions:
if proxy.cid in self.malicious_selected:
ins["malicious"] = True
else:
ins["malicious"] = False

if not client_instructions:
log(INFO, "fit_round %s: no clients selected, cancel", server_round)
return None
Expand Down
5 changes: 0 additions & 5 deletions baselines/flanders/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ ruff = "==0.0.272"
types-requests = "==2.27.7"
virtualenv = "20.21.0"

[[tool.poetry.source]]
name = "natsort"
url = "natsort-8.4.0-py3-none-any.whl"
priority = "explicit"

[tool.isort]
line_length = 88
indent = " "
Expand Down
10 changes: 10 additions & 0 deletions baselines/flanders/results/all_results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
round,auc,accuracy,loss
0,0.46964516899616104,0.0442,721.8714995384216
0,0.491341331757505,0.0895,721.0930316448212
0,0.5039652172645966,0.0972,720.5661296844482
0,0.48865788458612947,0.0811,721.8517775535583
0,0.5018137886910315,0.1161,722.4481763839722
0,0.4728194180869978,0.0515,722.6484620571136
0,0.5070118697549717,0.1024,720.1501111984253
0,0.5053153492209119,0.1089,721.5759069919586
0,0.4998734219733104,0.098,721.5942440032959

0 comments on commit 2707ca2

Please sign in to comment.