diff --git a/README.md b/README.md index 002d16066e78..4f148d2445d9 100644 --- a/README.md +++ b/README.md @@ -97,8 +97,6 @@ Flower Baselines is a collection of community-contributed experiments that repro - [MNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist) - [FedProx](https://arxiv.org/abs/1812.06127): - [MNIST](https://github.com/adap/flower/tree/main/baselines/fedprox/) -- [FedBN: Federated Learning on non-IID Features via Local Batch Normalization](https://arxiv.org/abs/2102.07623): - - [Convergence Rate](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate) - [Adaptive Federated Optimization](https://arxiv.org/abs/2003.00295): - [CIFAR-10/100](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization) diff --git a/baselines/README.md b/baselines/README.md index 3b30ff1a9eaf..a18c0553b2b4 100644 --- a/baselines/README.md +++ b/baselines/README.md @@ -1,14 +1,14 @@ # Flower Baselines -> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedProx](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedprox_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), [FedBN](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). +> We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines in the `flwr_baselines` directory. Currently, you can make use of baselines for [FedAvg](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/fedavg_mnist), [FedOpt](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/adaptive_federated_optimization), and [LEAF-FEMNIST](https://github.com/adap/flower/tree/main/baselines/flwr_baselines/flwr_baselines/publications/leaf/femnist). > The documentation below has been updated to reflect the new way of using Flower baselines. ## Structure -Each baseline in this directory is fully self-contained in terms of source code it its own directory. In addition, each baseline uses its very own Python environment as designed by the contributors of such baseline in order to replicate the experiments in the paper. Each baseline directory contains the following structure: +Each baseline in this directory is fully self-contained in terms of source code in its own directory. In addition, each baseline uses its very own Python environment as designed by the contributors of such baseline in order to replicate the experiments in the paper. Each baseline directory contains the following structure: ```bash baselines// diff --git a/baselines/doc/source/how-to-use-baselines.rst b/baselines/doc/source/how-to-use-baselines.rst index b89c17b17a98..ed47438ad5a9 100644 --- a/baselines/doc/source/how-to-use-baselines.rst +++ b/baselines/doc/source/how-to-use-baselines.rst @@ -3,7 +3,7 @@ Use Baselines .. warning:: We are changing the way we structure the Flower baselines. While we complete the transition to the new format, you can still find the existing baselines and use them: `baselines (old) `_. - Currently, you can make use of baselines for `FedAvg `_, `FedProx `_, `FedOpt `_, `FedBN `_, and `LEAF-FEMNIST `_. + Currently, you can make use of baselines for `FedAvg `_, `FedOpt `_, and `LEAF-FEMNIST `_. The documentation below has been updated to reflect the new way of using Flower baselines. diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/README.md b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/README.md deleted file mode 100644 index 2b718afd7dd4..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/README.md +++ /dev/null @@ -1,189 +0,0 @@ -# FedBN Baseline - Convergence Rate - -## Experiment Introduction - -The **FedBN - Convergence Rate** baseline is based on the paper [FEDBN: FEDERATED LEARNING ON NON-IID FEATURES VIA LOCAL BATCH NORMALIZATION](https://arxiv.org/pdf/2102.07623.pdf) and reproduces the results presented in *Chapter 5 - Convergence Rate (Figure 3)*. The implementation is based on the Flower federated learning framework. This experiment uses 5 completely different image datasets of digits to emulate a non-IID data distribution over the different clients. The experiment therefore uses 5 clients for the training. The local training is set up to perform 1 epoch and it uses a CNN model together with the SGD optimizer (cross-entropy loss). - -## Dataset - -### General Overview - -The following 5 different datasets are used to simulate a non-IID data distribution over 5 clients: - -* [MNIST](https://ieeexplore.ieee.org/document/726791) -* [MNIST-M]((https://arxiv.org/pdf/1505.07818.pdf)) -* [SVHN](http://ufldl.stanford.edu/housenumbers/nips2011_housenumbers.pdf) -* [USPS](https://ieeexplore.ieee.org/document/291440) -* [SynthDigits](https://arxiv.org/pdf/1505.07818.pdf) - -A more detailed explanation of the datasets is given in the following table. - -| | MNIST | MNIST-M | SVHN | USPS | SynthDigits | -|--- |--- |--- |--- |--- |--- | -| data type| handwritten digits| MNIST modification randomly colored with colored patches| Street view house numbers | handwritten digits from envelopes by the U.S. Postal Service | Syntehtic digits Windows TM font varying the orientation, blur and stroke colors | -| color | greyscale | RGB | RGB | greyscale | RGB | -| pixelsize | 28x28 | 28 x 28 | 32 x32 | 16 x16 | 32 x32 | -| labels | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 | -| number of trainset | 60.000 | 60.000 | 73.257 | 9,298 | 50.000 | -| number of testset| 10.000 | 10.000 | 26.032 | - | - | -| image shape | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) | - -### Dataset Download - -The [FedBN](https://arxiv.org/pdf/2102.07623.pdf) authors prepared a preprocessed dataset on their GitHub repository. It is available [here](https://github.com/med-air/FedBN). Please download the dataset, save it in a `/data` directory, and unzip it. - -The training data contains only 7438 samples and it is split into 10 files, but only one file is used for the **FedBN: Convergence Rate** baseline. Therefore, 743 samples are used for the local training. - -Run the following commands to download and preprocess the original data: - -```bash -# download data (will create a directory in ./path) -python utils/data_download.py - -# preprocess -python utils/data_preprocess.py -``` - -All the datasets (with the exception of SynthDigits) can be downloaded from the original sources: - -```bash -# download -python utils/data_download_raw.py - -# preprocess -python utils/data_preprocess.py -``` - -## Training Setup - -### CNN Architecture - -The CNN architecture is detailed in the paper and used to create the **FedBN - Convergence Rate** baseline. - -| Layer | Details| -| ----- | ------ | -| 1 | Conv2D(3,64, 5,1,2)
BN(64), ReLU, MaxPool2D(2,2) | -| 2 | Conv2D(64, 64, 5, 1, 2)
BN(64), ReLU, MaxPool2D(2,2) | -| 3 | Conv2D(64, 128, 5, 1, 2)
BN(128), ReLU | -| 4 | FC(6272, 2048)
BN(2048), ReLU | -| 5 | FC(2048, 512)
BN(512), ReLU | -| 6 | FC(512, 10) | - -### Training Paramaters - -| Description | Value | -| ----------- | ----- | -| training samples | 743 | -| mu | 10E-2 | -| local epochs | 1 | -| loss | cross entropy loss | -| optimizer | SGD | - -## Running the Experiment - -Before running any part of the experiment, [please get the required data](dataset-download), and place it in the `/data` directory. - -As soon as you have downloaded the data you are ready to start the baseline experiment. The baseline implementation is split across different files: - -* cnn_model.py -* client.py -* server.py -* run.sh -* utils/data_utils.py -* evaluation_plot.py - -In order to run the experiment, you simply make `run.sh` executable and run it: - -```bash -chmod +x run.sh -# train the CNN with FedAvg strategy -./run.sh fedavg -# train the CNN with FedBN strategy -./run.sh fedbn -``` - -The `run.sh` script first creates the files in which the evaluation results are saved and starts `server.py` and 5 clients in parallel with `client.py`. As explained before, each client loads a different dataset. The clients save the evaluation results after the parameters were sent from the server to the client and right before the local training. The saved parameters are included in a dict with the following information: - -```python -test_dict = {"dataset": self.num_examples["dataset"], "fl_round" : fl_round, "strategy": self.mode, "test_loss": loss, "test_accuracy": accuracy} -``` - -The `utils/data_utils.py` script prepares/loads the data for the training and `cnn_model.py` defines the [CNN model architecture](#cnn-architecture). This baseline only uses one single file with 743 samples from the downloaded dataset. - -If you want to compare the results of both strategy runs (FedAvg and FedBN) you can run: - -```bash -# create the evalutation plot -python evaluation_plot.py -``` - -This will create a plot `convergence_rate.png` including the train loss after the server aggregation from a certain client for the FedAvg and FedBN stretegy. A noticeable difference between both stretagies is visible for the SVHN dataset. The train loss has a more stable and steeper decrease when using FedBN over FedAvg. - -![Illustration of convergence rate on the SVHN client trained with Flower using FedBN and FedAvg.](convergence_rate_FedBN_FedAvg_comparison.png) - -This baseline was created to reproduce the [FedBN implementation available on GitHub](https://github.com/med-air/FedBN). The following diagram shows the convergence rate of the SVHN client with same strategy, FedBN, using the Flower implementation and the original FedBN implementation. The loss decrease is very similar in both cases. - -![Illustration of convergence rate on the SVHN client trained with Flower and the traditional FedBN code.](convergence_rate_Flower_FedBN_comparison.png) - -### Server - -This baseline compares Federated Averaging (FedAvg) with Federated Batch Normalization (FedBN). In both cases, we are using FedAvg on the server-side. All model parameters are sent from the client to the server and aggregated. However, in the case of FedBN, we are setting up the client to exclude batch norm layers from the transmission to the server. FedBN is therefore a strategy that only requires changes on the client-side. - -The server is kept very simple and remains the same in both settings. We are using FedAvg on the server-side with the parameters `min_fit_clients`, `min_eval_clients`, and `min_available_clients`. All are set to the value `5` since we have five clients to be trained and evaluated in each FL round. All in all, the *FedBN* paper runs 600 FL rounds that can be implemented as follows. - -```python -import flwr as fl - -if __name__ == "__main__": - strategy = fl.server.strategy.FedAvg( - min_fit_clients=5, - min_eval_clients=5, - min_available_clients=5, - ) - fl.server.start_server(server_address="[::]:8080", config={"num_rounds": 100}, strategy=strategy) - -``` - -### Client - -The client is a little more complex. However, it can be separated into different parts. The main parts are: - -* `load_partition()` - * load the right dataset -* `train()` - * perfom the local training -* `test()` - * evaluate the training results -* `FlowerClient(fl.client.NumPyClient)` - * start the Flower client -* `main()` - * start all previous process in a main() file - -The `load_partition()` function loads the datasets saved in the `/data` dierctory. - -You can directly see that the training and evaluation process is defined within the client. We are using PyTorch to train and evaluate the model with the parameters described in the section [Training Setup](#training-setup). - -The Flower client `FlowerClient(fl.client.NumPyClient)` implements the usual methods: - -* get_paramaters() -* set_parameters() -* fit() -* evaluate() - -Let us take a closer look at `set_parameters()` in order to understand the difference between FedAvg and FedBN: - -```python -def set_parameters(self, parameters: List[np.ndarray])-> None: - self.model.train() - if self.mode == 'fedbn': - keys = [k for k in self.model.state_dict().keys() if "bn" not in k] - params_dict = zip(keys, parameters) - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - self.model.load_state_dict(state_dict, strict=False) - else: - params_dict = zip(self.model.state_dict().keys(), parameters) - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - self.model.load_state_dict(state_dict, strict=True) -``` - -You can see that the clients update all layers of the PyTorch model with the values received from the server when using the plain FedAvg strategy (including batch norm layers). However, in the case of FedBN, the parameters for the batch norm layers (`bn`) are excluded. diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/__init__.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/client.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/client.py deleted file mode 100644 index db2ab647e1ba..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/client.py +++ /dev/null @@ -1,402 +0,0 @@ -"""FedBN client.""" - - -import argparse -import json -from collections import OrderedDict -from typing import Dict, Tuple - -import flwr as fl -import torch -from flwr.common.typing import NDArrays, Scalar -from torch import nn -from torchvision import transforms - -from .utils.cnn_model import CNNModel -from .utils.data_utils import DigitsDataset - -FL_ROUND = 0 - -eval_list = [] - - -# pylint: disable=no-member -DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -# pylint: enable=no-member - - -# mypy: allow-any-generics -# pylint: disable= too-many-arguments, too-many-locals, global-statement -class FlowerClient(fl.client.NumPyClient): - """Flower client implementing image classification using PyTorch.""" - - def __init__( - self, - model: CNNModel, - trainloader: torch.utils.data.DataLoader, - testloader: torch.utils.data.DataLoader, - num_examples: Dict, - mode: str, - ) -> None: - self.model = model - self.trainloader = trainloader - self.testloader = testloader - self.num_examples = num_examples - self.mode = mode - - def get_parameters(self, config) -> NDArrays: - """Return model parameters as a list of NumPy ndarrays w or w/o using - BN layers.""" - self.model.train() - # pylint: disable = no-else-return - if self.mode == "fedbn": - # Excluding parameters of BN layers when using FedBN - return [ - val.cpu().numpy() - for name, val in self.model.state_dict().items() - if "bn" not in name - ] - else: - # Return all model parameters as a list of NumPy ndarrays - return [val.cpu().numpy() for _, val in self.model.state_dict().items()] - - def set_parameters(self, parameters: NDArrays) -> None: - """Set model parameters from a list of NumPy ndarrays Exclude the bn - layer if available.""" - self.model.train() - # pylint: disable=not-callable - if self.mode == "fedbn": - keys = [k for k in self.model.state_dict().keys() if "bn" not in k] - params_dict = zip(keys, parameters) - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - self.model.load_state_dict(state_dict, strict=False) - else: - params_dict = zip(self.model.state_dict().keys(), parameters) - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - self.model.load_state_dict(state_dict, strict=True) - # pylint: enable=not-callable - - def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict]: - """Set model parameters, train model, return updated model - parameters.""" - self.set_parameters(parameters) - test_loss, test_accuracy = test( - self.model, self.num_examples["dataset"], self.trainloader, device=DEVICE - ) - test_dict = { - "dataset": self.num_examples["dataset"], - "fl_round": FL_ROUND, - "strategy": self.mode, - "train_loss": test_loss, - "train_accuracy": test_accuracy, - } - loss, accuracy = train( - self.model, - self.trainloader, - self.num_examples["dataset"], - epochs=1, - device=DEVICE, - ) - eval_list.append(test_dict) - return ( - self.get_parameters({}), - self.num_examples["trainset"], - {"loss": loss, "accuracy": accuracy}, - ) - - def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict]: - """Set model parameters, evaluate model on local test dataset, return - result.""" - self.set_parameters(parameters) - global FL_ROUND - loss, accuracy = test( - self.model, self.num_examples["dataset"], self.testloader, device=DEVICE - ) - test_dict = { - "dataset": self.num_examples["dataset"], - "fl_round": FL_ROUND, - "strategy": self.mode, - "test_loss": loss, - "test_accuracy": accuracy, - } - eval_list.append(test_dict) - FL_ROUND += 1 - return ( - float(loss), - self.num_examples["testset"], - {"loss": loss, "accuracy": accuracy}, - ) - - -def load_partition( - dataset: str, -) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]: - """Load 'MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST-M' for the training - and test data to simulate a partition.""" - - if dataset == "MNIST": - print(f"Load {dataset} dataset") - - transform = transforms.Compose( - [ - transforms.Grayscale(num_output_channels=3), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - trainset = DigitsDataset( - data_path="data/MNIST", - channels=1, - percent=0.1, - train=True, - transform=transform, - ) - testset = DigitsDataset( - data_path="data/MNIST", - channels=1, - percent=0.1, - train=False, - transform=transform, - ) - - elif dataset == "SVHN": - print(f"Load {dataset} dataset") - - transform = transforms.Compose( - [ - transforms.Resize([28, 28]), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - trainset = DigitsDataset( - data_path="data/SVHN", - channels=3, - percent=0.1, - train=True, - transform=transform, - ) - testset = DigitsDataset( - data_path="data/SVHN", - channels=3, - percent=0.1, - train=False, - transform=transform, - ) - - elif dataset == "USPS": - print(f"Load {dataset} dataset") - - transform = transforms.Compose( - [ - transforms.Resize([28, 28]), - transforms.Grayscale(num_output_channels=3), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - trainset = DigitsDataset( - data_path="data/USPS", - channels=1, - percent=0.1, - train=True, - transform=transform, - ) - testset = DigitsDataset( - data_path="data/USPS", - channels=1, - percent=0.1, - train=False, - transform=transform, - ) - - elif dataset == "SynthDigits": - print(f"Load {dataset} dataset") - - transform = transforms.Compose( - [ - transforms.Resize([28, 28]), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - trainset = DigitsDataset( - data_path="data/SynthDigits/", - channels=3, - percent=0.1, - train=True, - transform=transform, - ) - testset = DigitsDataset( - data_path="data/SynthDigits/", - channels=3, - percent=0.1, - train=False, - transform=transform, - ) - - elif dataset == "MNIST-M": - print(f"Load {dataset} dataset") - - transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - trainset = DigitsDataset( - data_path="data/MNIST_M/", - channels=3, - percent=0.1, - train=True, - transform=transform, - ) - testset = DigitsDataset( - data_path="data/MNIST_M/", - channels=3, - percent=0.1, - train=False, - transform=transform, - ) - - else: - print("No valid dataset available") - - num_examples = { - "dataset": dataset, - "trainset": len(trainset), - "testset": len(testset), - } - - print(f"Loaded {dataset} dataset with {num_examples} samples. Good Luck!") - - trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True) - testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False) - - return trainloader, testloader, num_examples - - -def train(model, traindata, dataset, epochs, device) -> Tuple[float, float]: - """Train the network.""" - # Define loss and optimizer - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) - - print( - f"Training {dataset} dataset with {epochs} local epoch(s) w/ {len(traindata)} batches each" - ) - - # Train the network - model.to(device) - model.train() - for epoch in range(epochs): # loop over the dataset multiple times - running_loss = 0.0 - total = 0.0 - correct = 0 - for i, data in enumerate(traindata, 0): - images, labels = data[0].to(device), data[1].to(device) - - # zero the parameter gradients - optimizer.zero_grad() - - # forward + backward + optimize - outputs = model(images) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - # print statistics - running_loss += loss.item() - _, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member - total += labels.size(0) - correct += (predicted == labels).sum().item() - loss = running_loss - accuracy = correct / total - if i == len(traindata) - 1: # print every 100 mini-batches - accuracy = correct / total - loss_batch = running_loss / len(traindata) - print( - f"Train Dataset {dataset} with [{epoch+1}, {i+1}] \ - loss: {loss_batch} accuracy: {accuracy}" - ) - running_loss = 0.0 - loss = loss / len(traindata) - return loss, accuracy - - -def test(model, dataset, testdata, device) -> Tuple[float, float]: - """Validate the network on the entire test set.""" - # Define loss and metrics - criterion = nn.CrossEntropyLoss() - correct = 0 - total = 0 - loss = 0.0 - - # Evaluate the network - model.to(device) - model.eval() - with torch.no_grad(): - for data in testdata: - images, labels = data[0].to(device), data[1].to(device) - outputs = model(images) - loss += criterion(outputs, labels).item() - _, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member - total += labels.size(0) - correct += (predicted == labels).sum().item() - accuracy = correct / total - loss = loss / len(testdata) - print(f"Dataset {dataset} with evaluation loss: {loss}") - return loss, accuracy - - -def main() -> None: - """Load data, start FlowerClient.""" - - # Parse command line argument `partition` (type of dataset) and `mode` (type of strategy) - parser = argparse.ArgumentParser(description="Flower") - parser.add_argument( - "--partition", - type=str, - choices=["MNIST", "SVHN", "USPS", "SynthDigits", "MNIST-M"], - required=True, - ) - parser.add_argument( - "--mode", - type=str, - choices=["fedbn", "fedavg"], - required=True, - default="fedbn", - ) - args = parser.parse_args() - - # Load model - model = CNNModel().to(DEVICE).train() - - # Load data - trainloader, testloader, num_examples = load_partition(args.partition) - - # Perform a single forward pass to properly initialize BatchNorm - _ = model(next(iter(trainloader))[0].to(DEVICE)) - - # Start client - client = FlowerClient(model, trainloader, testloader, num_examples, args.mode) - print("Start client of dataset", num_examples["dataset"]) - fl.client.start_numpy_client(server_address="[::]:8000", client=client) - # Save train and evaluation loss and accuracy in json file - with open( - f"results/{args.partition}_{args.mode}_results.json", mode="r+" - ) as eval_file: - json.dump(eval_list, eval_file) - - -if __name__ == "__main__": - main() diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/convergence_rate_FedBN_FedAvg_comparison.png b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/convergence_rate_FedBN_FedAvg_comparison.png deleted file mode 100644 index 2c9219f9078b..000000000000 Binary files a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/convergence_rate_FedBN_FedAvg_comparison.png and /dev/null differ diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/convergence_rate_Flower_FedBN_comparison.png b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/convergence_rate_Flower_FedBN_comparison.png deleted file mode 100644 index d80398ec0331..000000000000 Binary files a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/convergence_rate_Flower_FedBN_comparison.png and /dev/null differ diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/evaluation_plot.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/evaluation_plot.py deleted file mode 100644 index 38429d002f02..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/evaluation_plot.py +++ /dev/null @@ -1,47 +0,0 @@ -"""This code takes the evalution results from the directory results/ and -creates a plot to compare e.g. fedbn with fedavg.""" - - -import json - -import matplotlib.pyplot as plt # type: ignore - -fedavg_step_number = [] -fedavg_loss = [] -fedbn_step_number = [] -fedbn_loss = [] - - -def get_evaluation_numbers() -> None: - """Open the json files to get the evaluation results.""" - with open("results/SVHN_fedavg_results.json") as fedavg_file: - fedavg_list = json.load(fedavg_file) - with open("results/SVHN_fedbn_results.json") as fedbn_file: - fedbn_list = json.load(fedbn_file) - for item in fedavg_list: - if "train_loss" in item: - fedavg_step_number.append(item["fl_round"]) - fedavg_loss.append(item["train_loss"]) - for item in fedbn_list: - if "train_loss" in item: - fedbn_step_number.append(item["fl_round"]) - fedbn_loss.append(item["train_loss"]) - - -def main() -> None: - """Plot evaluation results.""" - get_evaluation_numbers() - # pylint: disable= unused-variable, invalid-name - fig, ax = plt.subplots() - fedavg = ax.plot(fedavg_step_number, fedavg_loss, label="FedAvg") - fedbn = ax.plot(fedbn_step_number, fedbn_loss, label="FedBN") - ax.legend() - plt.axis([-3, 100, -0.1, 2.5]) - plt.ylabel("Training loss") - plt.xlabel("Number of FL round") - plt.title("SVHN") - plt.savefig("convergence_rate.png") - - -if __name__ == "__main__": - main() diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/run.sh b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/run.sh deleted file mode 100755 index 2a291781419d..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/run.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -mode=${1:-"fedavg"} - -echo "Start Training with $mode" -mkdir -p results - -python3 server.py & -sleep 5 # Sleep for 5s to give the server enough time to start - -for i in 'MNIST' 'SVHN' 'USPS' 'SynthDigits' 'MNIST-M'; do - touch "results/${i}_"$mode"_results.json" - sleep 5 -done - -for i in 'MNIST' 'SVHN' 'USPS' 'SynthDigits' 'MNIST-M' ; do - echo "Starting client $i" - python3 client.py --partition=${i} --mode="$mode" & - sleep 5 & -done - -# This will allow you to use CTRL+C to stop all background processes -trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM -# Wait for all background processes to complete -wait diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/server.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/server.py deleted file mode 100644 index cadb6713a24a..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/server.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Flower server example.""" - - -import flwr as fl - -if __name__ == "__main__": - strategy = fl.server.strategy.FedAvg( - min_fit_clients=5, - min_evaluate_clients=5, - min_available_clients=5, - ) - fl.server.start_server( - server_address="[::]:8000", - config=fl.server.ServerConfig(num_rounds=2), - strategy=strategy, - ) diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/__init__.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/cnn_model.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/cnn_model.py deleted file mode 100644 index 99bb80a3d116..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/cnn_model.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Define the model architecture.""" - - -from torch import nn - - -# pylint: disable=unsubscriptable-object,too-many-instance-attributes -class CNNModel(nn.Module): - """Model for benchmark experiment on Digits.""" - - def __init__(self, num_classes=10): - super().__init__() - self.conv1 = nn.Conv2d(3, 64, 5, 1, 2) - self.bn1 = nn.BatchNorm2d(64) - self.conv2 = nn.Conv2d(64, 64, 5, 1, 2) - self.bn2 = nn.BatchNorm2d(64) - self.conv3 = nn.Conv2d(64, 128, 5, 1, 2) - self.bn3 = nn.BatchNorm2d(128) - - self.fc1 = nn.Linear(6272, 2048) - self.bn4 = nn.BatchNorm1d(2048) - self.fc2 = nn.Linear(2048, 512) - self.bn5 = nn.BatchNorm1d(512) - self.fc3 = nn.Linear(512, num_classes) - - # pylint: disable=arguments-differ,invalid-name - def forward(self, x): - """Forward pass.""" - x = nn.functional.relu(self.bn1(self.conv1(x))) - x = nn.functional.max_pool2d(x, 2) - - x = nn.functional.relu(self.bn2(self.conv2(x))) - x = nn.functional.max_pool2d(x, 2) - - x = nn.functional.relu(self.bn3(self.conv3(x))) - - x = x.view(x.shape[0], -1) - - x = self.fc1(x) - x = self.bn4(x) - x = nn.functional.relu(x) - - x = self.fc2(x) - x = self.bn5(x) - x = nn.functional.relu(x) - - x = self.fc3(x) - return x diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_download.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_download.py deleted file mode 100644 index 5909b2d410a4..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_download.py +++ /dev/null @@ -1,67 +0,0 @@ -"""This code will download the required datasets from a Google Drive, save it -in the directory ./data/data.zip and extracts it.""" - - -import os -import zipfile -from pathlib import Path -from typing import Any - -import requests -from tqdm import tqdm # type: ignore - -# pylint: disable=invalid-name - - -def download_file_from_google_drive(file_id: Any, destination: str) -> None: - """Download zip from Google drive.""" - url = "https://docs.google.com/uc?export=download" - - session = requests.Session() - - response = session.get(url, params={"id": file_id}, stream=True) - token = get_confirm_token(response) - - if token: - params = {"id": file_id, "confirm": token} - print("ID", params["id"]) - response = session.get(url, params=params, stream=True) - total_length = response.headers.get("content-length") - print("Downloading...") - save_response_content(response, destination, total_length) # type:ignore - print("Dowload done") - - -# pylint: enable=invalid-name - - -def get_confirm_token(response: Any) -> Any: - """Conform Google cookies.""" - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - return value - return None - - -def save_response_content(response: Any, destination: str, total_length: float) -> None: - """save data in the given data file.""" - chunk_size = 32768 - - with open(destination, "wb") as download_file: - total_length = int(total_length) - for chunk in tqdm( - response.iter_content(chunk_size), total=int(total_length / chunk_size) - ): - if chunk: # filter out keep-alive new chunks - download_file.write(chunk) - - -if __name__ == "__main__": - Path("./data").mkdir(exist_ok=True) - FILE_ID = "1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I" - DESTINATION = "data/data.zip" - download_file_from_google_drive(FILE_ID, DESTINATION) - print("Extracting...") - with zipfile.ZipFile(DESTINATION, "r") as zip_ref: - for file in tqdm(iterable=zip_ref.namelist(), total=len(zip_ref.namelist())): - zip_ref.extract(member=file, path=os.path.dirname(DESTINATION)) diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_download_raw.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_download_raw.py deleted file mode 100644 index 369c943d7884..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_download_raw.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Download the raw datasets for MNIST, USPS, SVHN Create the MNISTM from MNIST -And download the Synth Data (Syntehtic digits Windows TM font varying the -orientation, blur and stroke colors). - -This dataset is already processed. -""" - - -import gzip -import pickle -import shutil -from pathlib import Path -from typing import Dict - -import numpy as np -import torch -import wget # type: ignore -from torchvision import datasets # type: ignore - -from .mnistm import create_mnistm # type: ignore - -# pylint: disable=invalid-name - - -def decompress(infile, tofile): - """Take data file and unzip it.""" - - with open(infile, "rb") as inf, open(tofile, "w", encoding="utf8") as tof: - decom_str = gzip.decompress(inf.read()).decode("utf-8") - tof.write(decom_str) - - -def download_all(data: Dict, out_dir: Path): - """Downloading datasets.""" - - for k, v in data.items(): - print(f"Downloading: {k}\n") - wget.download(v, out=str(out_dir / k)) - - -def get_synthDigits(out_dir: Path): - """get synth dataset.""" - - if out_dir.exists(): - print(f"Directory ({out_dir}) exists, skipping downloading SynthDigits.") - return - - # pylint: disable=line-too-long - out_dir.mkdir() - data = {} - data[ - "synth_train_32x32.mat" - ] = "https://github.com/domainadaptation/datasets/blob/master/synth/synth_train_32x32.mat?raw=true" - data[ - "synth_test_32x32.mat" - ] = "https://github.com/domainadaptation/datasets/blob/master/synth/synth_test_32x32.mat?raw=true" - download_all(data, out_dir) - # pylint: disable=line-too-long - - # How to proceed? It seems these `.mat` have no data. URLs found here: - # https://domainadaptation.org/api/salad.datasets.digits.html#module-salad.datasets.digits.synth - - -def get_MNISTM(out_dir: Path): - """Creates MNISTM dataset as done by https://github.com/pumpikano/tf- - dann#build-mnist-m-dataset.""" - # steps = 'https://github.com/pumpikano/tf-dann#build-mnist-m-dataset' - if out_dir.exists(): - print(f"> Directory ({out_dir}) exists, skipping downloading MNISTM.") - return - - out_dir.mkdir() - data = {} - data[ - "BSR_bsds500.tgz" - ] = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" - download_all(data, out_dir) - - train = torch.load("./data/MNIST/training.pt") - test = torch.load("./data/MNIST/test.pt") - print("Building train set...") - train_labels = train[1] - train = create_mnistm.create_mnistm(train[0]) - print("Building test set...") - test_labels = test[1] - test = create_mnistm.create_mnistm(test[0]) - val_labels = np.zeros(0) - val = np.zeros([0, 28, 28, 3], np.uint8) - - # Save dataset as pickle - with open(out_dir / "mnistm_data.pkl", "wb") as f: - pickle.dump( - { - "train": train, - "train_label": train_labels, - "test": test, - "test_label": test_labels, - "valid": val, - "valid_label": val_labels, - }, - f, - pickle.HIGHEST_PROTOCOL, - ) - - -def get_USPS(out_dir: Path): - """get USPS data (handwritten digits from envelopes by the U.S. - - Postal Service) - """ - - if out_dir.exists(): - print(f"> Directory ({out_dir}) exists, skipping downloading USPS.") - return - - out_dir.mkdir() - data = {} - data[ - "usps.bz2" - ] = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2" - data[ - "usps.t.bz2" - ] = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2" - - download_all(data, out_dir) - - -def get_SVHN(out_dir: Path): - """Get SVHN dataset (Street view house numbers)""" - if out_dir.exists(): - print(f"> Directory ({out_dir}) exists, skipping downloading SVHN.") - return - - out_dir.mkdir() - - data = {} - data["train_32x32.mat"] = "http://ufldl.stanford.edu/housenumbers/train_32x32.mat" - data["test_32x32.mat"] = "http://ufldl.stanford.edu/housenumbers/test_32x32.mat" - - download_all(data, out_dir) - - -def get_MNIST(out_dir: Path): - """Downloads MNIST using torchvision routines. - - Then, move processed files to directory expected by - `utils/data_processing.py`. Delete the rest. - """ - - if (out_dir / "MNIST").exists(): - print(f"> Directory ({out_dir}) exists, skipping downloading MNIST.") - print(type(out_dir)) - return - - datasets.MNIST(out_dir, train=True, download=True) - - datasets.MNIST(out_dir, train=False) - - train_file = "training.pt" - test_file = "test.pt" - shutil.move(out_dir / "MNIST" / "processed" / train_file, out_dir / "MNIST" / train_file) # type: ignore[arg-type] - shutil.move(out_dir / "MNIST" / "processed" / test_file, out_dir / "MNIST" / test_file) # type: ignore[arg-type] - shutil.rmtree(out_dir / "MNIST" / "raw") - shutil.rmtree(out_dir / "MNIST" / "processed") - - -def main(): - """Get all the datasets.""" - - data_dir = Path("./data") - - data_dir.mkdir(exist_ok=True) - - get_MNIST(data_dir) # type: ignore[arg-type] - - get_SVHN(data_dir / "SVHN") - - get_USPS(data_dir / "USPS") - - get_MNISTM(data_dir / "MNIST_M") - - get_synthDigits(data_dir / "SynthDigits") - - -if __name__ == "__main__": - main() diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_preprocess.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_preprocess.py deleted file mode 100644 index df2cd3b89b46..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_preprocess.py +++ /dev/null @@ -1,278 +0,0 @@ -"""This file is used to download and pre-process all data in Digit-5 dataset. - -i.e., splitted data into train&test set in a stratified way. The -function to process data into 10 partitions is also provided. -""" - - -import bz2 -import os -import pickle as pkl -from collections import Counter - -import numpy as np -import scipy.io as scio # type: ignore -import torch -from sklearn import model_selection # type: ignore - -# pylint: disable=invalid-name, too-many-locals, broad-except - - -def stratified_split(X, y): - """Provides train/test indices to split data in train/test sets.""" - - sss = model_selection.StratifiedShuffleSplit( - n_splits=1, test_size=0.2, random_state=0 - ) - - for train_index, test_index in sss.split(X, y): - X_train, X_test = X[train_index], X[test_index] - y_train, y_test = y[train_index], y[test_index] - print("Train:", Counter(y_train)) - print("Test:", Counter(y_test)) - - return (X_train, y_train), (X_test, y_test) - - -def process_mnist(): - """ - train: - (56000, 28, 28) - (56000,) - test: - (14000, 28, 28) - (14000,) - """ - mnist_train = "./data/MNIST/training.pt" - mnist_test = "./data/MNIST/test.pt" - train = torch.load(mnist_train) - test = torch.load(mnist_test) - - train_img = train[0].numpy() - train_tar = train[1].numpy() - - test_img = test[0].numpy() - test_tar = test[1].numpy() - - all_img = np.concatenate([train_img, test_img]) - all_tar = np.concatenate([train_tar, test_tar]) - - train_stratified, test_stratified = stratified_split(all_img, all_tar) - print("# After spliting:") - print("Train imgs:\t", train_stratified[0].shape) - print("Train labels:\t", train_stratified[1].shape) - print("Test imgs:\t", test_stratified[0].shape) - print("Test labels:\t", test_stratified[1].shape) - - with open("./data/MNIST/train.pkl", "wb") as mnist_train_file: - pkl.dump(train_stratified, mnist_train_file, pkl.HIGHEST_PROTOCOL) - - with open("./data/MNIST/test.pkl", "wb") as mnist_test_file: - pkl.dump(test_stratified, mnist_test_file, pkl.HIGHEST_PROTOCOL) - - -def process_svhn(): - """ - train: - (79431, 32, 32, 3) - (79431,) - test: - (19858, 32, 32, 3) - (19858,) - """ - train = scio.loadmat("./data/SVHN/train_32x32.mat") - test = scio.loadmat("./data/SVHN/test_32x32.mat") - - train_img = train["X"] - train_tar = train["y"].astype(np.int64).squeeze() - - test_img = test["X"] - test_tar = test["y"].astype(np.int64).squeeze() - - train_img = np.transpose(train_img, (3, 0, 1, 2)) - test_img = np.transpose(test_img, (3, 0, 1, 2)) - - np.place(train_tar, train_tar == 10, 0) - np.place(test_tar, test_tar == 10, 0) - - all_img = np.concatenate([train_img, test_img]) - all_tar = np.concatenate([train_tar, test_tar]) - - train_stratified, test_stratified = stratified_split(all_img, all_tar) - print("# After spliting:") - print("Train imgs:\t", train_stratified[0].shape) - print("Train labels:\t", train_stratified[1].shape) - print("Test imgs:\t", test_stratified[0].shape) - print("Test labels:\t", test_stratified[1].shape) - - with open("./data/SVHN/train.pkl", "wb") as svhn_train_file: - pkl.dump(train_stratified, svhn_train_file, pkl.HIGHEST_PROTOCOL) - - with open("./data/SVHN/test.pkl", "wb") as svhn_test_file: - pkl.dump(test_stratified, svhn_test_file, pkl.HIGHEST_PROTOCOL) - - -def process_usps(): - """ - train: - (7438, 16, 16) - (7438,) - test: - (1860, 16, 16) - (1860,) - :return: - """ - - train_path = "./data/USPS/usps.bz2" - with bz2.open(train_path) as fp: - raw_data = [l.decode().split() for l in fp.readlines()] - imgs = [[x.split(":")[-1] for x in data[1:]] for data in raw_data] - imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16)) - imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) - targets = [int(d[0]) - 1 for d in raw_data] - - train_img = imgs - train_tar = np.array(targets) - - test_path = "./data/USPS/usps.t.bz2" - with bz2.open(test_path) as fp: - raw_data = [l.decode().split() for l in fp.readlines()] - imgs = [[x.split(":")[-1] for x in data[1:]] for data in raw_data] - imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16)) - imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) - targets = [int(d[0]) - 1 for d in raw_data] - - test_img = imgs - test_tar = np.array(targets) - - all_img = np.concatenate([train_img, test_img]) - all_tar = np.concatenate([train_tar, test_tar]) - - train_stratified, test_stratified = stratified_split(all_img, all_tar) - print("# After spliting:") - print("Train imgs:\t", train_stratified[0].shape) - print("Train labels:\t", train_stratified[1].shape) - print("Test imgs:\t", test_stratified[0].shape) - print("Test labels:\t", test_stratified[1].shape) - - with open("./data/USPS/train.pkl", "wb") as usps_train_file: - pkl.dump(train_stratified, usps_train_file, pkl.HIGHEST_PROTOCOL) - - with open("./data/USPS/test.pkl", "wb") as usps_test_file: - pkl.dump(test_stratified, usps_test_file, pkl.HIGHEST_PROTOCOL) - - -def process_synth(): - """(391162, 32, 32, 3) (391162,) (97791, 32, 32, 3) (97791,)""" - train = scio.loadmat("./data/SynthDigits/synth_train_32x32.mat") - test = scio.loadmat("./data/SynthDigits/synth_test_32x32.mat") - - train_img = train["X"] - train_tar = train["y"].astype(np.int64).squeeze() - - test_img = test["X"] - test_tar = test["y"].astype(np.int64).squeeze() - - train_img = np.transpose(train_img, (3, 0, 1, 2)) - test_img = np.transpose(test_img, (3, 0, 1, 2)) - - all_img = np.concatenate([train_img, test_img]) - all_tar = np.concatenate([train_tar, test_tar]) - - train_stratified, test_stratified = stratified_split(all_img, all_tar) - print("# After spliting:") - print("Train imgs:\t", train_stratified[0].shape) - print("Train labels:\t", train_stratified[1].shape) - print("Test imgs:\t", test_stratified[0].shape) - print("Test labels:\t", test_stratified[1].shape) - - with open("./data/SynthDigits/train.pkl", "wb") as synthdigits_train_file: - pkl.dump(train_stratified, synthdigits_train_file, pkl.HIGHEST_PROTOCOL) - - with open("./data/SynthDigits/test.pkl", "wb") as synthdigits_test_file: - pkl.dump(test_stratified, synthdigits_test_file, pkl.HIGHEST_PROTOCOL) - - -def process_mnistm(): - """ - (56000, 28, 28, 3) - (56000,) - (14000, 28, 28, 3) - (14000,) - :return: - """ - data = np.load("./data/MNIST_M/mnistm_data.pkl", allow_pickle=True) - train_img = data["train"] - train_tar = data["train_label"] - valid_img = data["valid"] - valid_tar = data["valid_label"] - test_img = data["test"] - test_tar = data["test_label"] - - all_img = np.concatenate([train_img, valid_img, test_img]) - all_tar = np.concatenate([train_tar, valid_tar, test_tar]) - - train_stratified, test_stratified = stratified_split(all_img, all_tar) - print("# After spliting:") - print("Train imgs:\t", train_stratified[0].shape) - print("Train labels:\t", train_stratified[1].shape) - print("Test imgs:\t", test_stratified[0].shape) - print("Test labels:\t", test_stratified[1].shape) - - with open("./data/MNIST_M/train.pkl", "wb") as mnistm_train_file: - pkl.dump(train_stratified, mnistm_train_file, pkl.HIGHEST_PROTOCOL) - - with open("./data/MNIST_M/test.pkl", "wb") as mnistm_test_file: - pkl.dump(test_stratified, mnistm_test_file, pkl.HIGHEST_PROTOCOL) - - -def split(data_path, percentage=0.1): - """split each single dataset into multiple partitions for client scaling - training each part remain the same size according to the smallest datasize - (i.e. 743)""" - images, labels = np.load(os.path.join(data_path, "train.pkl"), allow_pickle=True) - part_len = 743.8 - part_num = int(1.0 / percentage) - - for num in range(part_num): - images_part = images[int(part_len * num) : int(part_len * (num + 1)), :, :] - labels_part = labels[int(part_len * num) : int(part_len * (num + 1))] - - save_path = os.path.join(data_path, "partitions") - if not os.path.exists(save_path): - os.makedirs(save_path) - with open( - os.path.join(save_path, "train_part{}.pkl".format(num)), "wb" - ) as file_split: - pkl.dump((images_part, labels_part), file_split, pkl.HIGHEST_PROTOCOL) - - -if __name__ == "__main__": - print("Processing...") - print("--------MNIST---------") - process_mnist() - print("--------SVHN---------") - process_svhn() - print("--------USPS---------") - process_usps() - print("--------MNIST-M---------") - process_mnistm() - print("-------SynthDigits-------") - try: - process_synth() - except Exception as exception: - print(f"unable to process SynthDigits: {exception}") - - base_paths = [ - "./data/MNIST", - "./data/SVHN", - "./data/USPS", - "./data/MNIST_M", - "./data/SynthDigits", - ] - for path in base_paths: - print(f"Spliting {os.path.basename(path)}") - try: - split(path) - except Exception as exception: - print(f"Failed to split: {path} --> {exception}") diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_utils.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_utils.py deleted file mode 100644 index c465e3ea10d1..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/data_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -"""This code creates 10 different partitions of each datasets.""" - - -import os -import sys - -import numpy as np -from PIL import Image # type: ignore -from torch.utils.data import Dataset - -base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(base_path) - - -class DigitsDataset(Dataset): - """Split datasets.""" - - # pylint: disable=too-many-arguments - def __init__( - self, - data_path, - channels, - percent=0.1, - filename=None, - train=True, - transform=None, - ): - if filename is None: - if train: - if percent >= 0.1: - for part in range(int(percent * 10)): - if part == 0: - self.images, self.labels = np.load( - os.path.join( - data_path, - f"partitions/train_part{part}.pkl", - ), - allow_pickle=True, - ) - else: - images, labels = np.load( - os.path.join( - data_path, - f"partitions/train_part{part}.pkl", - ), - allow_pickle=True, - ) - self.images = np.concatenate([self.images, images], axis=0) - self.labels = np.concatenate([self.labels, labels], axis=0) - else: - self.images, self.labels = np.load( - os.path.join(data_path, "partitions/train_part0.pkl"), - allow_pickle=True, - ) - data_len = int(self.images.shape[0] * percent * 10) - self.images = self.images[:data_len] - self.labels = self.labels[:data_len] - else: - self.images, self.labels = np.load( - os.path.join(data_path, "test.pkl"), allow_pickle=True - ) - else: - self.images, self.labels = np.load( - os.path.join(data_path, filename), allow_pickle=True - ) - - self.transform = transform - self.channels = channels - self.labels = self.labels.astype(np.long).squeeze() - - def __len__(self): - return self.images.shape[0] - - def __getitem__(self, idx): - image = self.images[idx] - label = self.labels[idx] - if self.channels == 1: - image = Image.fromarray(image, mode="L") - elif self.channels == 3: - image = Image.fromarray(image, mode="RGB") - else: - raise ValueError(f"{self.channels} channel is not allowed.") - - if self.transform is not None: - image = self.transform(image) - - return image, label diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/mnistm/__init__.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/mnistm/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/mnistm/create_mnistm.py b/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/mnistm/create_mnistm.py deleted file mode 100644 index 5af23d52122a..000000000000 --- a/baselines/flwr_baselines/flwr_baselines/publications/fedbn/convergence_rate/utils/mnistm/create_mnistm.py +++ /dev/null @@ -1,71 +0,0 @@ -"""! This script has been borrowed and adapted. Original script: -https://github.com/pumpikano/tf-dann/blob/master/create_mnistm.py. - -It creatse the MNIST-M dataset based on MNIST -""" - - -import tarfile -from typing import Any - -import numpy as np -import skimage # type: ignore -import skimage.io # type: ignore -import skimage.transform # type: ignore - - -# pylint: disable=invalid-name, disable=no-member, bare-except -def compose_image(digit: Any, background: Any) -> Any: - """Difference-blend a digit and a random patch from a background image.""" - w, h, _ = background.shape - dw, dh, _ = digit.shape - x = np.random.randint(0, w - dw) - y = np.random.randint(0, h - dh) - - bg = background[x : x + dw, y : y + dh] - return np.abs(bg - digit).astype(np.uint8) - - -def mnist_to_img(x: Any) -> Any: - """Binarize MNIST digit and convert to RGB.""" - x = (x > 0).float() - d = x.reshape([28, 28, 1]) * 255 - return np.concatenate([d, d, d], 2) - - -def create_mnistm(X: Any) -> Any: - """Give an array of MNIST digits, blend random background patches to build - the MNIST-M dataset as described in - http://jmlr.org/papers/volume17/15-239/15-239.pdf.""" - - bst_path = "./data/MNIST_M/BSR_bsds500.tgz" - - rand = np.random.RandomState(42) - train_files = [] - - with tarfile.open(bst_path, "r") as bsr_file: - for name in bsr_file.getnames(): - if name.startswith("BSR/BSDS500/data/images/train/"): - train_files.append(name) - - print("Loading BSR training images") - background_data = [] - for name in train_files: - try: - fp = bsr_file.extractfile(name) - bg_img = skimage.io.imread(fp) - background_data.append(bg_img) - except: - continue - - X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8) - for i in range(X.shape[0]): - if i % 1000 == 0: - print("Processing example", i) - - bg_img = rand.choice(background_data) - d = mnist_to_img(X[i]) - d = compose_image(d, bg_img) - X_[i] = d - - return X_