From abd75357b05e62688880973afe35789dd4e4dcf6 Mon Sep 17 00:00:00 2001 From: helin1 Date: Tue, 9 Jan 2024 15:33:55 +0800 Subject: [PATCH 01/12] refactor: use alexnet to replace efficientnet --- examples/advanced-pytorch/alexnet.py | 39 ++++++++++++++++++++++++++++ examples/advanced-pytorch/client.py | 6 ++--- examples/advanced-pytorch/server.py | 3 ++- examples/advanced-pytorch/utils.py | 5 +++- 4 files changed, 48 insertions(+), 5 deletions(-) create mode 100644 examples/advanced-pytorch/alexnet.py diff --git a/examples/advanced-pytorch/alexnet.py b/examples/advanced-pytorch/alexnet.py new file mode 100644 index 000000000000..06a958e5df48 --- /dev/null +++ b/examples/advanced-pytorch/alexnet.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +class Net(nn.Module): #lr = 0.01 + def __init__(self, class_num=10): + super(Net, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + self.classifier = nn.Sequential( + nn.Dropout(0.75), + nn.Linear(256 * 4 * 4, 4096), + nn.ReLU(inplace=True), + nn.Dropout(0.75), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, class_num), + ) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 256 * 4 * 4) + x = self.classifier(x) + return x \ No newline at end of file diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index f9ffb6181fd8..5f5e8f828bb1 100644 
--- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -24,8 +24,8 @@ def __init__( self.validation_split = validation_split def set_parameters(self, parameters): - """Loads a efficientnet model and replaces it parameters with the ones given.""" - model = utils.load_efficientnet(classes=10) + """Loads a alexnet model and replaces it parameters with the ones given.""" + model = utils.load_alexnet(classes=10) params_dict = zip(model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) model.load_state_dict(state_dict, strict=True) @@ -76,7 +76,7 @@ def evaluate(self, parameters, config): def client_dry_run(device: str = "cpu"): """Weak tests to check whether all client methods are working as expected.""" - model = utils.load_efficientnet(classes=10) + model = utils.load_alexnet(classes=10) trainset, testset = utils.load_partition(0) trainset = torch.utils.data.Subset(trainset, range(10)) testset = torch.utils.data.Subset(testset, range(10)) diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 8343e62da69f..6575d060ca18 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -88,7 +88,8 @@ def main(): args = parser.parse_args() - model = utils.load_efficientnet(classes=10) + # model = utils.load_efficientnet(classes=10) + model = utils.load_alexnet(classes=10) model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()] diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 8788ead90dee..4157fafd23c6 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,7 +1,7 @@ import torch import torchvision.transforms as transforms from torchvision.datasets import CIFAR10 - +from alexnet import Net import warnings warnings.filterwarnings("ignore") @@ -128,3 +128,6 @@ def load_efficientnet(entrypoint: str = 
"nvidia_efficientnet_b0", classes: int = def get_model_params(model): """Returns a model's parameters.""" return [val.cpu().numpy() for _, val in model.state_dict().items()] + +def load_alexnet(): + return Net \ No newline at end of file From 4172ca650b4408fa78aa9fe004622addcbde8a50 Mon Sep 17 00:00:00 2001 From: helin1 Date: Tue, 9 Jan 2024 15:37:53 +0800 Subject: [PATCH 02/12] refactor: use alexnet to replace efficientnet v2 --- examples/advanced-pytorch/alexnet.py | 4 ++-- examples/advanced-pytorch/utils.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/advanced-pytorch/alexnet.py b/examples/advanced-pytorch/alexnet.py index 06a958e5df48..0ff988565b03 100644 --- a/examples/advanced-pytorch/alexnet.py +++ b/examples/advanced-pytorch/alexnet.py @@ -3,9 +3,9 @@ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -class Net(nn.Module): #lr = 0.01 +class AlexNet(nn.Module): #lr = 0.01 def __init__(self, class_num=10): - super(Net, self).__init__() + super(AlexNet, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 4157fafd23c6..571829c32eba 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,7 +1,6 @@ import torch import torchvision.transforms as transforms from torchvision.datasets import CIFAR10 -from alexnet import Net import warnings warnings.filterwarnings("ignore") @@ -129,5 +128,6 @@ def get_model_params(model): """Returns a model's parameters.""" return [val.cpu().numpy() for _, val in model.state_dict().items()] -def load_alexnet(): - return Net \ No newline at end of file +def load_alexnet(class_num): + from alexnet import AlexNet + return AlexNet(class_num) \ No newline at end of file From 1b5c3fbfd2bd65ff61745d93fda161ef708bb9ce Mon Sep 17 00:00:00 2001 From: helin1 Date: Thu, 18 Jan 2024 10:36:52 
+0800 Subject: [PATCH 03/12] refactor: v3 --- examples/advanced-pytorch/README.md | 8 ++++++++ examples/advanced-pytorch/alexnet.py | 7 ++++--- examples/advanced-pytorch/client.py | 27 +++++++++++++++++++-------- examples/advanced-pytorch/server.py | 25 ++++++++++++++++++------- examples/advanced-pytorch/utils.py | 6 ++++-- 5 files changed, 53 insertions(+), 20 deletions(-) diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index db0245e41453..280639a2e85a 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -68,3 +68,11 @@ poetry run ./run.sh The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network). + + +## About Differential Privacy +If you want to achieve differential privacy, please use `alexnet` model, because efficientnet is particularly affected by noise. + +If you add a little noise when using `efficientnet`, the loss will be Nan. 
This is issue:https://github.com/adap/flower/issues/2342 + +Therefore, if you want to achieve differential privacy, please use `alexnet` diff --git a/examples/advanced-pytorch/alexnet.py b/examples/advanced-pytorch/alexnet.py index 0ff988565b03..4601eac7ae80 100644 --- a/examples/advanced-pytorch/alexnet.py +++ b/examples/advanced-pytorch/alexnet.py @@ -3,7 +3,8 @@ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -class AlexNet(nn.Module): #lr = 0.01 + +class AlexNet(nn.Module): # lr = 0.01 def __init__(self, class_num=10): super(AlexNet, self).__init__() self.features = nn.Sequential( @@ -32,8 +33,8 @@ def __init__(self, class_num=10): nn.Linear(4096, class_num), ) - def forward(self, x): + def forward(self, x): x = self.features(x) x = x.view(x.size(0), 256 * 4 * 4) x = self.classifier(x) - return x \ No newline at end of file + return x diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index 5f5e8f828bb1..84e942be051d 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -12,20 +12,22 @@ class CifarClient(fl.client.NumPyClient): def __init__( - self, - trainset: torchvision.datasets, - testset: torchvision.datasets, - device: str, - validation_split: int = 0.1, + self, + trainset: torchvision.datasets, + testset: torchvision.datasets, + device: str, + validation_split: int = 0.1, ): self.device = device self.trainset = trainset self.testset = testset self.validation_split = validation_split - def set_parameters(self, parameters): + def set_parameters(self, parameters, use_model): """Loads a alexnet model and replaces it parameters with the ones given.""" - model = utils.load_alexnet(classes=10) + model = utils.load_efficientnet(classes=10) + if use_model == "alexnet": + model = utils.load_alexnet(classes=10) params_dict = zip(model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 
model.load_state_dict(state_dict, strict=True) @@ -35,7 +37,7 @@ def fit(self, parameters, config): """Train parameters on the locally held training set.""" # Update local model parameters - model = self.set_parameters(parameters) + model = self.set_parameters(parameters, config['use_model']) # Get hyperparameters for this round batch_size: int = config["batch_size"] @@ -126,6 +128,15 @@ def main() -> None: help="Set to true to use GPU. Default: False", ) + parser.add_argument( + "--dp", + type=str, + default="efficientnet", + required=False, + help="Use either Efficientnet or Alexnet models. \ + If you want to achieve differential privacy, please use the Alexnet model", + ) + args = parser.parse_args() device = torch.device( diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 6575d060ca18..4e4babdac1a7 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -13,7 +13,7 @@ warnings.filterwarnings("ignore") -def fit_config(server_round: int): +def fit_config(server_round: int, use_model: str): """Return training configuration dict for each round. Keep batch size fixed at 32, perform two rounds of training with one local epoch, @@ -22,6 +22,7 @@ def fit_config(server_round: int): config = { "batch_size": 16, "local_epochs": 1 if server_round < 2 else 2, + "model": use_model if use_model == "alexnet" else "efficientnet", } return config @@ -54,9 +55,9 @@ def get_evaluate_fn(model: torch.nn.Module, toy: bool): # The `evaluate` function will be called after every round def evaluate( - server_round: int, - parameters: fl.common.NDArrays, - config: Dict[str, fl.common.Scalar], + server_round: int, + parameters: fl.common.NDArrays, + config: Dict[str, fl.common.Scalar], ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: # Update model with the latest parameters params_dict = zip(model.state_dict().keys(), parameters) @@ -86,10 +87,20 @@ def main(): Useful for testing purposes. 
Default: False", ) + parser.add_argument( + "--use_model", + type=str, + default="efficientnet", + required=False, + help="Use either Efficientnet or Alexnet models. \ + If you want to achieve differential privacy, please use the Alexnet model", + ) + args = parser.parse_args() - # model = utils.load_efficientnet(classes=10) - model = utils.load_alexnet(classes=10) + model = utils.load_efficientnet(classes=10) + if args.use_model == "alexnet": + model = utils.load_alexnet(classes=10) model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()] @@ -101,7 +112,7 @@ def main(): min_evaluate_clients=2, min_available_clients=10, evaluate_fn=get_evaluate_fn(model, args.toy), - on_fit_config_fn=fit_config, + on_fit_config_fn=fit_config(2, args.use_model), on_evaluate_config_fn=evaluate_config, initial_parameters=fl.common.ndarrays_to_parameters(model_parameters), ) diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 571829c32eba..f959a416fc0f 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -5,6 +5,7 @@ warnings.filterwarnings("ignore") + # DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -128,6 +129,7 @@ def get_model_params(model): """Returns a model's parameters.""" return [val.cpu().numpy() for _, val in model.state_dict().items()] -def load_alexnet(class_num): + +def load_alexnet(classes): from alexnet import AlexNet - return AlexNet(class_num) \ No newline at end of file + return AlexNet(classes) From 2ad9044b2862b2b4f063e64d5d16b4b3c923343d Mon Sep 17 00:00:00 2001 From: helin1 Date: Thu, 18 Jan 2024 10:58:02 +0800 Subject: [PATCH 04/12] refactor: v4 --- examples/advanced-pytorch/client.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index 84e942be051d..197e1f7a5075 100644 --- a/examples/advanced-pytorch/client.py +++ 
b/examples/advanced-pytorch/client.py @@ -63,7 +63,7 @@ def fit(self, parameters, config): def evaluate(self, parameters, config): """Evaluate parameters on the locally held test set.""" # Update local model parameters - model = self.set_parameters(parameters) + model = self.set_parameters(parameters, config['use_model']) # Get config values steps: int = config["val_steps"] @@ -78,7 +78,7 @@ def evaluate(self, parameters, config): def client_dry_run(device: str = "cpu"): """Weak tests to check whether all client methods are working as expected.""" - model = utils.load_alexnet(classes=10) + model = utils.load_efficientnet(classes=10) trainset, testset = utils.load_partition(0) trainset = torch.utils.data.Subset(trainset, range(10)) testset = torch.utils.data.Subset(testset, range(10)) @@ -128,15 +128,6 @@ def main() -> None: help="Set to true to use GPU. Default: False", ) - parser.add_argument( - "--dp", - type=str, - default="efficientnet", - required=False, - help="Use either Efficientnet or Alexnet models. 
\ - If you want to achieve differential privacy, please use the Alexnet model", - ) - args = parser.parse_args() device = torch.device( From d0411777cafe1971d3b45318ce8e27ce948cc20c Mon Sep 17 00:00:00 2001 From: helin1 Date: Thu, 18 Jan 2024 11:02:01 +0800 Subject: [PATCH 05/12] refactor: v5 --- examples/advanced-pytorch/client.py | 2 +- examples/advanced-pytorch/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index 197e1f7a5075..1749cf882e7e 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -24,7 +24,7 @@ def __init__( self.validation_split = validation_split def set_parameters(self, parameters, use_model): - """Loads a alexnet model and replaces it parameters with the ones given.""" + """Loads a alexnet or efficientnet model and replaces it parameters with the ones given.""" model = utils.load_efficientnet(classes=10) if use_model == "alexnet": model = utils.load_alexnet(classes=10) diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 4e4babdac1a7..dac34e7571e9 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -22,7 +22,7 @@ def fit_config(server_round: int, use_model: str): config = { "batch_size": 16, "local_epochs": 1 if server_round < 2 else 2, - "model": use_model if use_model == "alexnet" else "efficientnet", + "use_model": use_model if use_model == "alexnet" else "efficientnet", } return config From e44fe86bf6234304a1e9fd94e645ecae94ed211e Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sat, 20 Jan 2024 11:37:34 +0000 Subject: [PATCH 06/12] fixes to how model is specified; other minor changes --- examples/advanced-pytorch/client.py | 36 ++++++++++++++++++----------- examples/advanced-pytorch/server.py | 21 ++++++++--------- examples/advanced-pytorch/utils.py | 3 +-- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git 
a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index 3fe4ba1b0c0d..658002b00e79 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -17,28 +17,30 @@ def __init__( trainset: datasets.Dataset, testset: datasets.Dataset, device: torch.device, + model_str: str, validation_split: int = 0.1, ): self.device = device self.trainset = trainset self.testset = testset self.validation_split = validation_split + if model_str == "alexnet": + self.model = utils.load_alexnet(classes=10) + else: + self.model = utils.load_efficientnet(classes=10) - def set_parameters(self, parameters, use_model): + def set_parameters(self, parameters): """Loads a alexnet or efficientnet model and replaces it parameters with the ones given.""" - model = utils.load_efficientnet(classes=10) - if use_model == "alexnet": - model = utils.load_alexnet(classes=10) - params_dict = zip(model.state_dict().keys(), parameters) + + params_dict = zip(self.model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - model.load_state_dict(state_dict, strict=True) - return model + self.model.load_state_dict(state_dict, strict=True) def fit(self, parameters, config): """Train parameters on the locally held training set.""" # Update local model parameters - model = self.set_parameters(parameters, config['use_model']) + self.set_parameters(parameters) # Get hyperparameters for this round batch_size: int = config["batch_size"] @@ -51,9 +53,9 @@ def fit(self, parameters, config): train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(valset, batch_size=batch_size) - results = utils.train(model, train_loader, val_loader, epochs, self.device) + results = utils.train(self.model, train_loader, val_loader, epochs, self.device) - parameters_prime = utils.get_model_params(model) + parameters_prime = utils.get_model_params(self.model) num_examples_train = len(trainset) 
return parameters_prime, num_examples_train, results @@ -61,7 +63,7 @@ def fit(self, parameters, config): def evaluate(self, parameters, config): """Evaluate parameters on the locally held test set.""" # Update local model parameters - model = self.set_parameters(parameters, config['use_model']) + self.set_parameters(parameters) # Get config values steps: int = config["val_steps"] @@ -69,7 +71,7 @@ def evaluate(self, parameters, config): # Evaluate global model parameters on the local test data and return results testloader = DataLoader(self.testset, batch_size=16) - loss, accuracy = utils.test(model, testloader, steps, self.device) + loss, accuracy = utils.test(self.model, testloader, steps, self.device) return float(loss), len(self.testset), {"accuracy": float(accuracy)} @@ -123,6 +125,14 @@ def main() -> None: required=False, help="Set to true to use GPU. Default: False", ) + parser.add_argument( + "--model", + type=str, + default="efficientnet", + choices=['efficientnet', 'alexnet'], + help="Use either Efficientnet or Alexnet models. \ + If you want to achieve differential privacy, please use the Alexnet model", + ) args = parser.parse_args() @@ -140,7 +150,7 @@ def main() -> None: trainset = trainset.select(range(10)) testset = testset.select(range(10)) # Start Flower client - client = CifarClient(trainset, testset, device) + client = CifarClient(trainset, testset, device, args.model) fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index c04ce17fe863..37129ff17a25 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -15,7 +15,7 @@ warnings.filterwarnings("ignore") -def fit_config(server_round: int, use_model: str): +def fit_config(server_round: int): """Return training configuration dict for each round. 
Keep batch size fixed at 32, perform two rounds of training with one local epoch, @@ -24,7 +24,6 @@ def fit_config(server_round: int, use_model: str): config = { "batch_size": 16, "local_epochs": 1 if server_round < 2 else 2, - "use_model": use_model if use_model == "alexnet" else "efficientnet", } return config @@ -81,33 +80,33 @@ def main(): help="Set to true to use only 10 datasamples for validation. \ Useful for testing purposes. Default: False", ) - parser.add_argument( - "--use_model", + "--model", type=str, default="efficientnet", - required=False, + choices=['efficientnet', 'alexnet'], help="Use either Efficientnet or Alexnet models. \ If you want to achieve differential privacy, please use the Alexnet model", ) args = parser.parse_args() - model = utils.load_efficientnet(classes=10) - if args.use_model == "alexnet": + if args.model == "alexnet": model = utils.load_alexnet(classes=10) + else: + model = utils.load_efficientnet(classes=10) model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()] # Create strategy strategy = fl.server.strategy.FedAvg( - fraction_fit=0.2, - fraction_evaluate=0.2, + fraction_fit=1.0, + fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, - min_available_clients=10, + min_available_clients=2, evaluate_fn=get_evaluate_fn(model, args.toy), - on_fit_config_fn=fit_config(2, args.use_model), + on_fit_config_fn=fit_config, on_evaluate_config_fn=evaluate_config, initial_parameters=fl.common.ndarrays_to_parameters(model_parameters), ) diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 90c2900817b7..0db66c825099 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,7 +1,7 @@ import torch from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop from torch.utils.data import DataLoader - +from alexnet import AlexNet import warnings from flwr_datasets import FederatedDataset @@ -133,5 +133,4 @@ def 
get_model_params(model): def load_alexnet(classes): - from alexnet import AlexNet return AlexNet(classes) From 9db63e7ba9bd610d00eeebf8b67e901ce0a1720c Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sat, 20 Jan 2024 11:46:40 +0000 Subject: [PATCH 07/12] w/ previous --- examples/advanced-pytorch/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 37129ff17a25..2cf08010bcde 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -100,11 +100,11 @@ def main(): # Create strategy strategy = fl.server.strategy.FedAvg( - fraction_fit=1.0, + fraction_fit=0.2, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, - min_available_clients=2, + min_available_clients=10, evaluate_fn=get_evaluate_fn(model, args.toy), on_fit_config_fn=fit_config, on_evaluate_config_fn=evaluate_config, From f5fb3f7cbb9f46316567df2a049f0db83b630bc2 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sat, 20 Jan 2024 11:47:07 +0000 Subject: [PATCH 08/12] w/ previous --- examples/advanced-pytorch/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 2cf08010bcde..0561cddd5b48 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -100,7 +100,7 @@ def main(): # Create strategy strategy = fl.server.strategy.FedAvg( - fraction_fit=0.2, + fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, From 75e4043b33c4e78068afee76920b6e0f64c3e806 Mon Sep 17 00:00:00 2001 From: helin1 Date: Wed, 24 Jan 2024 20:22:21 +0800 Subject: [PATCH 09/12] refactor: remove the the parts related to differential privacy in readme --- examples/advanced-pytorch/README.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index 
26e97d84bd38..2527e8e4a820 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -69,11 +69,3 @@ but this can be changed by removing the `--toy` argument in the script. You can The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). - - -## About Differential Privacy -If you want to achieve differential privacy, please use `alexnet` model, because efficientnet is particularly affected by noise. - -If you add a little noise when using `efficientnet`, the loss will be Nan. 
This is issue:https://github.com/adap/flower/issues/2342 - -Therefore, if you want to achieve differential privacy, please use `alexnet` From 56326a63f13e956208709782bae6da6581efd4b7 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 24 Jan 2024 13:09:13 +0000 Subject: [PATCH 10/12] using torchvision alexnet; better logic for efficientnet; metnion model choices in readme; formatting --- examples/advanced-pytorch/README.md | 2 +- examples/advanced-pytorch/alexnet.py | 40 ------------------------- examples/advanced-pytorch/client.py | 6 ++-- examples/advanced-pytorch/run.sh | 5 ---- examples/advanced-pytorch/server.py | 8 ++--- examples/advanced-pytorch/utils.py | 44 +++++++--------------------- 6 files changed, 19 insertions(+), 86 deletions(-) delete mode 100644 examples/advanced-pytorch/alexnet.py diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index 2527e8e4a820..dd7ba56521d8 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -68,4 +68,4 @@ but this can be changed by removing the `--toy` argument in the script. You can The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). 
-You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). +You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficienNet` (default) or `AlexNet` (but all clients in the experiment should use the same). Switch between models using the `--model` flag when launching `client.py` and `server.py`. \ No newline at end of file diff --git a/examples/advanced-pytorch/alexnet.py b/examples/advanced-pytorch/alexnet.py deleted file mode 100644 index 4601eac7ae80..000000000000 --- a/examples/advanced-pytorch/alexnet.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -import torch.nn as nn - -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class AlexNet(nn.Module): # lr = 0.01 - def __init__(self, class_num=10): - super(AlexNet, self).__init__() - self.features = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - ) - - self.classifier = nn.Sequential( - nn.Dropout(0.75), - nn.Linear(256 * 4 * 4, 4096), - nn.ReLU(inplace=True), - nn.Dropout(0.75), - nn.Linear(4096, 4096), - nn.ReLU(inplace=True), - nn.Linear(4096, class_num), - ) - - def forward(self, x): 
- x = self.features(x) - x = x.view(x.size(0), 256 * 4 * 4) - x = self.classifier(x) - return x diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index 76ff13769a24..0eb457d68645 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -1,6 +1,5 @@ import utils from torch.utils.data import DataLoader -import torchvision.datasets import torch import flwr as fl import argparse @@ -30,7 +29,8 @@ def __init__( self.model = utils.load_efficientnet(classes=10) def set_parameters(self, parameters): - """Loads a alexnet or efficientnet model and replaces it parameters with the ones given.""" + """Loads a alexnet or efficientnet model and replaces it parameters with the + ones given.""" params_dict = zip(self.model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) @@ -129,7 +129,7 @@ def main() -> None: "--model", type=str, default="efficientnet", - choices=['efficientnet', 'alexnet'], + choices=["efficientnet", "alexnet"], help="Use either Efficientnet or Alexnet models. 
\ If you want to achieve differential privacy, please use the Alexnet model", ) diff --git a/examples/advanced-pytorch/run.sh b/examples/advanced-pytorch/run.sh index 3367e1680535..c3d52491b987 100755 --- a/examples/advanced-pytorch/run.sh +++ b/examples/advanced-pytorch/run.sh @@ -2,11 +2,6 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the EfficientNetB0 model -python -c "import torch; torch.hub.load( \ - 'NVIDIA/DeepLearningExamples:torchhub', \ - 'nvidia_efficientnet_b0', pretrained=True)" - python server.py --toy & sleep 10 # Sleep for 10s to give the server enough time to start and dowload the dataset diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index e9f0aafd4537..489694ab1ea1 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -51,9 +51,9 @@ def get_evaluate_fn(model: torch.nn.Module, toy: bool): # The `evaluate` function will be called after every round def evaluate( - server_round: int, - parameters: fl.common.NDArrays, - config: Dict[str, fl.common.Scalar], + server_round: int, + parameters: fl.common.NDArrays, + config: Dict[str, fl.common.Scalar], ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: # Update model with the latest parameters params_dict = zip(model.state_dict().keys(), parameters) @@ -84,7 +84,7 @@ def main(): "--model", type=str, default="efficientnet", - choices=['efficientnet', 'alexnet'], + choices=["efficientnet", "alexnet"], help="Use either Efficientnet or Alexnet models. 
\ If you want to achieve differential privacy, please use the Alexnet model", ) diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 8c24a85f868f..186f079010dc 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,7 +1,6 @@ import torch from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop -from torch.utils.data import DataLoader -from alexnet import AlexNet +from torchvision.models import efficientnet_b0, AlexNet import warnings from flwr_datasets import FederatedDataset @@ -49,7 +48,7 @@ def train( net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss().to(device) optimizer = torch.optim.SGD( - net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4 + net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4 ) net.train() for _ in range(epochs): @@ -99,35 +98,13 @@ def test( return loss, accuracy -def replace_classifying_layer(efficientnet_model, num_classes: int = 10): - """Replaces the final layer of the classifier.""" - num_features = efficientnet_model.classifier.fc.in_features - efficientnet_model.classifier.fc = torch.nn.Linear(num_features, num_classes) - - -def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = None): - """Loads pretrained efficientnet model from torch hub. Replaces final classifying - layer if classes is specified. - - Args: - entrypoint: EfficientNet model to download. - For supported entrypoints, please refer - https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/ - classes: Number of classes in final classifying layer. Leave as None to get - the downloaded - model untouched. 
- Returns: - EfficientNet Model - - Note: One alternative implementation can be found at - https://github.com/lukemelas/EfficientNet-PyTorch - """ - efficientnet = torch.hub.load( - "NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True - ) - - if classes is not None: - replace_classifying_layer(efficientnet, classes) +def load_efficientnet(classes: int = 10): + """Loads EfficienNetB0 from TorchVision.""" + efficientnet = efficientnet_b0(pretrained=True) + # Re-init output linear layer with the right number of classes + model_classes = efficientnet.classifier[1].in_features + if classes != model_classes: + efficientnet.classifier[1] = torch.nn.Linear(model_classes, classes) return efficientnet @@ -137,4 +114,5 @@ def get_model_params(model): def load_alexnet(classes): - return AlexNet(classes) + """Load AlexNet model from TorchVision.""" + return AlexNet(num_classes=classes) From ca6de6624867a1e8f5f0ab4fd4964509ad280dc7 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 24 Jan 2024 13:20:06 +0000 Subject: [PATCH 11/12] format readme --- examples/advanced-pytorch/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index dd7ba56521d8..ff2bbaac0395 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -66,6 +66,8 @@ but this can be changed by removing the `--toy` argument in the script. You can ./run.sh ``` +( $3 \times 4 = 12$ ) + The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). 
If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). -You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficienNet` (default) or `AlexNet` (but all clients in the experiment should use the same). Switch between models using the `--model` flag when launching `client.py` and `server.py`. \ No newline at end of file +You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficientNet` (default) or `AlexNet` (but all clients in the experiment should use the same). Switch between models using the `--model` flag when launching `client.py` and `server.py`. From a514829021a7feae54978c73160800b94de74808 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 24 Jan 2024 13:49:13 +0000 Subject: [PATCH 12/12] tidy up --- examples/advanced-pytorch/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index ff2bbaac0395..9101105b2618 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -66,8 +66,6 @@ but this can be changed by removing the `--toy` argument in the script. You can ./run.sh ``` -( $3 \times 4 = 12$ ) - The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows.
If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficientNet` (default) or `AlexNet` (but all clients in the experiment should use the same). Switch between models using the `--model` flag when launching `client.py` and `server.py`.