diff --git a/baselines/flanders/README.md b/baselines/flanders/README.md index 1c5c8479ca26..a0ab3db89a94 100644 --- a/baselines/flanders/README.md +++ b/baselines/flanders/README.md @@ -68,9 +68,9 @@ California Housing (linear regression): | Dataset | # of clients | Clients per round | # of rounds | Batch size | Learning rate | $\lambda_1$ | $\lambda_2$ | Optimizer | Dropout | Alpha | Beta | # of clients to keep | Sampling | | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -| Income | 100 | 100 | 100 | \ | \ | 1.0 | 0.0 | CCD | \ | 0.0 | 0.0 | 1 | \ | -| MNIST | 100 | 100 | 100 | 32 | $10^{-3}$ | \ | \ | Adam | 0.2 | 0.0 | 0.0 | 1 | \ | -| California Housing | 100 | 100 | 100 | \ | \ | 0.5 | 0.5 | CCD | \ | 0.0 | 0.0 | 1 | \ | +| Income | 100 | 100 | 50 | \ | \ | 1.0 | 0.0 | CCD | \ | 0.0 | 0.0 | 1 | \ | +| MNIST | 100 | 100 | 50 | 32 | $10^{-3}$ | \ | \ | Adam | 0.2 | 0.0 | 0.0 | 1 | \ | +| California Housing | 100 | 100 | 50 | \ | \ | 0.5 | 0.5 | CCD | \ | 0.0 | 0.0 | 1 | \ | Where $\lambda_1$ and $\lambda_2$ are Lasso and Ridge regularization terms. diff --git a/baselines/flanders/flanders/__init__.py b/baselines/flanders/flanders/__init__.py index a5e567b59135..eb3edd489459 100644 --- a/baselines/flanders/flanders/__init__.py +++ b/baselines/flanders/flanders/__init__.py @@ -1 +1 @@ -"""Template baseline package.""" +"""FLANDERS package.""" diff --git a/baselines/flanders/flanders/conf/base.yaml b/baselines/flanders/flanders/conf/base.yaml index 237789bfbebd..62e0902e66c2 100644 --- a/baselines/flanders/flanders/conf/base.yaml +++ b/baselines/flanders/flanders/conf/base.yaml @@ -2,10 +2,10 @@ # this is the config that will be loaded as default by main.py # Please follow the provided structure (this will ensuring all baseline follow # a similar configuration structure and hence be easy to customise) -defaults: - - _self_ - - override hydra/hydra_logging: disabled - - override hydra/job_logging: disabled +#defaults: +# - _self_ +# - override hydra/hydra_logging: disabled +# - override hydra/job_logging: disabled hydra: output_subdir: null @@ -14,13 +14,21 @@ hydra: dataset: # dataset config + name: mnist model: # model config + _target_: flanders.models.MnistNet strategy: - _target_: # points to your strategy (either custom or exiting in Flower) - # rest of strategy config + _target_: flanders.strategy.Flanders + attack_fn: null + +server: + _target_: flanders.server.EnhancedServer + num_rounds: 2 + pool_size: 1 + num_malicious: 0 client: - # client config + # client config \ No newline at end of file diff --git a/baselines/flanders/flanders/dataset.py b/baselines/flanders/flanders/dataset.py index 859861af90f2..512ecca0408f 100644 --- a/baselines/flanders/flanders/dataset.py +++ b/baselines/flanders/flanders/dataset.py @@ -169,7 +169,7 @@ def __len__(self) -> int: return len(self.data) -def get_cifar_10(path_to_data="datasets/cifar_nn/data"): +def get_cifar_10(path_to_data="datasets/cifar_10/data"): """Downloads CIFAR10 dataset and generates a unified training set (it will be partitioned later using the LDA partitioning mechanism.""" diff --git a/baselines/flanders/flanders/main.py b/baselines/flanders/flanders/main.py index a1b205b3d72f..9c2860e26f5f 100644 --- a/baselines/flanders/flanders/main.py +++ b/baselines/flanders/flanders/main.py @@ -15,6 +15,7 @@ import numpy as np import flwr as fl from flwr.common.typing import Scalar +from flwr.server.client_manager import SimpleClientManager import torch import os import random @@ -42,6 +43,7 @@ get_sklearn_model_params ) from .utils import save_results +from .server import EnhancedServer from torch.utils.data import DataLoader from torchvision import transforms @@ -80,7 +82,8 @@ def main(cfg: DictConfig) -> None: # 1. Print parsed config print(OmegaConf.to_yaml(cfg)) - print(cfg.dataset) + print(cfg.dataset.name) + print(cfg.server.pool_size) evaluate_fn = mnist_evaluate @@ -92,27 +95,47 @@ def main(cfg: DictConfig) -> None: # be a location in the file system, a list of dataloader, a list of ids to extract # from a dataset, it's up to you) + # Managed by clients + # 3. Define your clients # Define a function that returns another function that will be used during # simulation to instantiate each individual client # client_fn = client.() - client = MnistClient - def client_fn(cid: int, pool_size: int): + def client_fn(cid: int, pool_size: int = 10, dataset_name: str = cfg.dataset.name): + if dataset_name == "mnist": + client = MnistClient + else: + raise NotImplementedError(f"Dataset {dataset_name} not implemented") return client(cid, pool_size) # 4. Define your strategy # pass all relevant argument (including the global dataset used after aggregation, # if needed by your method.) # strategy = instantiate(cfg.strategy, ) - strategy = instantiate(cfg.strategy) + strategy = instantiate( + cfg.strategy, + evaluate_fn=evaluate_fn, + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, # no federated evaluation + min_fit_clients=1, + min_evaluate_clients=0, + warmup_rounds=1, + to_keep=1, # Used in Flanders, MultiKrum, TrimmedMean (in Bulyan it is forced to 1) + min_available_clients=1, # All clients should be available + window=1, # Used in Flanders + sampling=1, # Used in Flanders + ) # 5. Start Simulation # history = fl.simulation.start_simulation() - fl.simulation.start_simulation( + history = fl.simulation.start_simulation( client_fn=client_fn, - num_clients=cfg.pool_size, - config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources={"num_cpus": 1}, + num_clients=cfg.server.pool_size, + server=EnhancedServer(num_malicious=cfg.server.num_malicious, attack_fn=None, client_manager=SimpleClientManager), + config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds), strategy=strategy ) @@ -126,11 +149,19 @@ def client_fn(cid: int, pool_size: int): # can retrieve the path to that directory with this: # save_path = HydraConfig.get().runtime.output_dir +def fit_config(server_round: int) -> Dict[str, Scalar]: + """Return a configuration with static batch size and (local) epochs.""" + config = { + "epochs": 1, # number of local epochs + "batch_size": 32, + } + return config + def mnist_evaluate( server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] ): # determine device - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu") model = MnistNet() set_params(model, parameters) @@ -140,12 +171,13 @@ def mnist_evaluate( testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=1) loss, accuracy, auc = test_mnist(model, testloader, device=device) - config["id"] = args.exp_num + #config["id"] = args.exp_num config["round"] = server_round config["auc"] = auc save_results(loss, accuracy, config=config) + print(f"Round {server_round} accuracy: {accuracy} loss: {loss} auc: {auc}") - # return statistics return loss, {"accuracy": accuracy, "auc": auc} -main() \ No newline at end of file +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/baselines/flanders/flanders/server.py b/baselines/flanders/flanders/server.py index 7b8177bc375e..edc059d16748 100644 --- a/baselines/flanders/flanders/server.py +++ b/baselines/flanders/flanders/server.py @@ -31,7 +31,7 @@ from flwr.server.history import History from flwr.server.strategy import FedAvg, Strategy -from utils import ( +from .utils import ( save_params, save_predicted_params, load_all_time_series, @@ -57,6 +57,7 @@ class EnhancedServer(Server): def __init__( self, + num_malicious: int, attack_fn:Optional[Callable], *args: Any, **kwargs: Any @@ -73,7 +74,7 @@ def fit_round( ) -> Optional[ Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures] ]: - """Perform a single round of federated averaging.""" + """Perform a single round of federated learning.""" # Get clients and their respective instructions from strategy client_instructions = self.strategy.configure_fit( server_round=server_round, @@ -81,6 +82,24 @@ def fit_round( client_manager=self._client_manager, ) + # Randomly decide which client is malicious + if server_round > self.warmup_rounds: + self.malicious_selected = np.random.choice( + [proxy.cid for proxy, ins in client_instructions], size=self.num_malicious, replace=False + ) + log( + DEBUG, + "fit_round %s: malicious clients selected %s", + server_round, + self.malicious_selected, + ) + # Save instruction for malicious clients into FitIns + for proxy, ins in client_instructions: + if proxy.cid in self.malicious_selected: + ins["malicious"] = True + else: + ins["malicious"] = False + if not client_instructions: log(INFO, "fit_round %s: no clients selected, cancel", server_round) return None diff --git a/baselines/flanders/pyproject.toml b/baselines/flanders/pyproject.toml index 9384975c368e..da7189902715 100644 --- a/baselines/flanders/pyproject.toml +++ b/baselines/flanders/pyproject.toml @@ -58,11 +58,6 @@ ruff = "==0.0.272" types-requests = "==2.27.7" virtualenv = "20.21.0" -[[tool.poetry.source]] -name = "natsort" -url = "natsort-8.4.0-py3-none-any.whl" -priority = "explicit" - [tool.isort] line_length = 88 indent = " " diff --git a/baselines/flanders/results/all_results.csv b/baselines/flanders/results/all_results.csv new file mode 100644 index 000000000000..17f97fc5edd0 --- /dev/null +++ b/baselines/flanders/results/all_results.csv @@ -0,0 +1,10 @@ +round,auc,accuracy,loss +0,0.46964516899616104,0.0442,721.8714995384216 +0,0.491341331757505,0.0895,721.0930316448212 +0,0.5039652172645966,0.0972,720.5661296844482 +0,0.48865788458612947,0.0811,721.8517775535583 +0,0.5018137886910315,0.1161,722.4481763839722 +0,0.4728194180869978,0.0515,722.6484620571136 +0,0.5070118697549717,0.1024,720.1501111984253 +0,0.5053153492209119,0.1089,721.5759069919586 +0,0.4998734219733104,0.098,721.5942440032959