diff --git a/baselines/fedrep/README.md b/baselines/fedrep/README.md
index ece30edf0943..d67730e4065a 100644
--- a/baselines/fedrep/README.md
+++ b/baselines/fedrep/README.md
@@ -36,91 +36,90 @@ dataset: [CIFAR-10, CIFAR-100]
 These two models are modified from the [official repo](https://github.com/rahulv0205/fedrep_experiments)'s. To be clear that, in the official models, there is no BN layers. However, without BN layer helping, training will definitely collapse.
-Please see how models are implemented using a so called model_manager and model_split class since FedRep uses head and base layers in a neural network. These classes are defined in the `models.py` file and thereafter called when building new models in the directory `/implemented_models`. Please, extend and add new models as you wish.
+Please see how models are implemented using the so-called model manager and model split classes, since FedRep splits a neural network into head and base layers. These classes are defined in the `models.py` file. Please extend them and add new models as you wish.
 
 **Dataset:** CIFAR10, CIFAR-100. CIFAR10/100 will be partitioned based on number of classes for data that each client shall receive e.g. 4 allocated classes could be [1, 3, 5, 9].
 
-**Training Hyperparameters:** The hyperparameters can be found in `conf/base.yaml` file which is the configuration file for the main script.
-
-| Description           | Default Value                       |
-| --------------------- | ----------------------------------- |
-| `num_clients`         | `100`                               |
-| `num_rounds`          | `100`                               |
-| `num_local_epochs`    | `5`                                 |
-| `num_rep_epochs`      | `1`                                 |
-| `enable_finetune`     | `False`                             |
-| `num_finetune_epochs` | `5`                                 |
-| `use_cuda`            | `true`                              |
-| `specified_device`    | `null`                              |
-| `client resources`    | `{'num_cpus': 2, 'num_gpus': 0.5 }` |
-| `learning_rate`       | `0.01`                              |
-| `batch_size`          | `50`                                |
-| `model_name`          | `cnncifar10`                        |
-| `algorithm`           | `fedrep`                            |
+**Training Hyperparameters:** The hyperparameters can be found in the `pyproject.toml` file under the `[tool.flwr.app.config]` section.
+
+| Description             | Default Value                       |
+|-------------------------|-------------------------------------|
+| `num-server-rounds`     | `100`                               |
+| `num-local-epochs`      | `5`                                 |
+| `num-rep-epochs`        | `1`                                 |
+| `enable-finetune`       | `False`                             |
+| `num-finetune-epochs`   | `5`                                 |
+| `use-cuda`              | `true`                              |
+| `specified-cuda-device` | `null`                              |
+| `client-resources`      | `{'num-cpus': 2, 'num-gpus': 0.5 }` |
+| `learning-rate`         | `0.01`                              |
+| `batch-size`            | `50`                                |
+| `model-name`            | `cnncifar10`                        |
+| `algorithm`             | `fedrep`                            |
 
 ## Environment Setup
 
-To construct the Python environment follow these steps:
+Create a new Python environment using [pyenv](https://github.com/pyenv/pyenv) and the [virtualenv plugin](https://github.com/pyenv/pyenv-virtualenv), then install the baseline project:
 
 ```bash
-# Set Python 3.10
-pyenv local 3.10.12
-# Tell poetry to use python 3.10
-poetry env use 3.10.12
+# Create the environment
+pyenv virtualenv 3.10.12 fedrep-env
 
-# Install the base Poetry environment
-poetry install
+# Activate it
+pyenv activate fedrep-env
 
-# Activate the environment
-poetry shell
+# Then install the project
+pip install -e .
 ```
 
 ## Running the Experiments
 
 ```
-python -m fedrep.main # this will run using the default settings in the `conf/base.yaml`
+flwr run . # this will run using the default settings in the `pyproject.toml`
 ```
 
-While the config files contain a large number of settings, the ones below are the main ones you'd likely want to modify to .
+While the config files contain a large number of settings, the ones below are the main ones you'd likely want to modify.
 
 ```bash
-algorithm: fedavg, fedrep # these are currently supported
-dataset.name: cifar10, cifar100
-dataset.num_classes: 2, 5, 20 (only for CIFAR-100)
-model_name: cnncifar10, cnncifar100
+algorithm = "fedavg", "fedrep" # these are currently supported
+dataset-name = "cifar10", "cifar100"
+dataset-split-num-classes = 2, 5, 20 # 20 only for CIFAR-100
+model-name = "cnncifar10", "cnncifar100"
 ```
-
+See also the configuration files for CIFAR-10 and CIFAR-100 under the `conf` directory.
 
 ## Expected Results
 
+The default algorithm used by all configuration files is `fedrep`. To use `fedavg`, please change the `algorithm` property in the respective configuration file. The default federated environment consists of 100 clients.
+
+When the execution completes, a new directory `results` will be created with a JSON file that contains the running configurations and the results per round.
+
+> [!NOTE]
+> All plots shown below are generated using the `docs/make_plots.py` script. The script reads all JSON files generated by the baseline inside the `results` directory.
 
 ### CIFAR-10 (100, 2)
 
 ```
-python -m fedrep.main --config-name cifar10_100_2 algorithm=fedrep
-python -m fedrep.main --config-name cifar10_100_2 algorithm=fedavg
+flwr run . --run-config conf/cifar10_2.toml
 ```
 
 ### CIFAR-10 (100, 5)
 
 ```
-python -m fedrep.main --config-name cifar10_100_5 algorithm=fedrep
-python -m fedrep.main --config-name cifar10_100_5 algorithm=fedavg
+flwr run . --run-config conf/cifar10_5.toml
 ```
 
 ### CIFAR-100 (100, 5)
 
 ```
-python -m fedrep.main --config-name cifar100_100_5 algorithm=fedrep
-python -m fedrep.main --config-name cifar100_100_5 algorithm=fedavg
+flwr run . --run-config conf/cifar100_5.toml
 ```
 
 ### CIFAR-100 (100, 20)
 
 ```
-python -m fedrep.main --config-name cifar100_100_20 algorithm=fedrep
-python -m fedrep.main --config-name cifar100_100_20 algorithm=fedavg
+flwr run . --run-config conf/cifar100_20.toml
 ```
-
\ No newline at end of file
+
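As an aside (not part of this patch): besides pointing `--run-config` at one of the TOML files under `conf/`, `flwr run` can also take inline key/value overrides in that same flag. A minimal sketch, assuming the standard `flwr run` override syntax and using only option names defined in this baseline's `[tool.flwr.app.config]` section:

```bash
# Sketch only: inline overrides instead of a conf/*.toml file
# (assumes `flwr run` accepts TOML-style key=value pairs in --run-config).
flwr run . --run-config "algorithm='fedavg' num-server-rounds=20"
flwr run . --run-config "dataset-name='cifar100' model-name='cnncifar100' dataset-split-num-classes=5"
```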
diff --git a/baselines/fedrep/_static/cifar100_100_20.png b/baselines/fedrep/_static/cifar100_100_20.png
index 2421f15ac6c6..3f97d08a4dff 100644
Binary files a/baselines/fedrep/_static/cifar100_100_20.png and b/baselines/fedrep/_static/cifar100_100_20.png differ
diff --git a/baselines/fedrep/_static/cifar100_100_5.png b/baselines/fedrep/_static/cifar100_100_5.png
index 17f25eb480c4..f22ddcedaf61 100644
Binary files a/baselines/fedrep/_static/cifar100_100_5.png and b/baselines/fedrep/_static/cifar100_100_5.png differ
diff --git a/baselines/fedrep/_static/cifar10_100_2.png b/baselines/fedrep/_static/cifar10_100_2.png
index 75ee48b2c970..1e7321f85e01 100644
Binary files a/baselines/fedrep/_static/cifar10_100_2.png and b/baselines/fedrep/_static/cifar10_100_2.png differ
diff --git a/baselines/fedrep/_static/cifar10_100_5.png b/baselines/fedrep/_static/cifar10_100_5.png
index 1d20a953f9c4..54a9143d8fd9 100644
Binary files a/baselines/fedrep/_static/cifar10_100_5.png and b/baselines/fedrep/_static/cifar10_100_5.png differ
diff --git a/baselines/fedrep/conf/cifar100_20.toml b/baselines/fedrep/conf/cifar100_20.toml
new file mode 100644
index 000000000000..2b6bb5f5e1eb
--- /dev/null
+++ b/baselines/fedrep/conf/cifar100_20.toml
@@ -0,0 +1,11 @@
+algorithm = "fedrep"
+
+# model specs
+model-name = "cnncifar100"
+
+# dataset specs
+dataset-name = "cifar100"
+dataset-split = "sample"
+dataset-split-num-classes = 20
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar100_5.toml b/baselines/fedrep/conf/cifar100_5.toml
new file mode 100644
index 000000000000..9903ea374028
--- /dev/null
+++ b/baselines/fedrep/conf/cifar100_5.toml
@@ -0,0 +1,11 @@
+algorithm = "fedrep"
+
+# model specs
+model-name = "cnncifar100"
+
+# dataset specs
+dataset-name = "cifar100"
+dataset-split = "sample"
+dataset-split-num-classes = 5
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar10_2.toml b/baselines/fedrep/conf/cifar10_2.toml
new file mode 100644
index 000000000000..307ba7101bd1
--- /dev/null
+++ b/baselines/fedrep/conf/cifar10_2.toml
@@ -0,0 +1,8 @@
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 2
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar10_5.toml b/baselines/fedrep/conf/cifar10_5.toml
new file mode 100644
index 000000000000..cd09c9e07ec2
--- /dev/null
+++ b/baselines/fedrep/conf/cifar10_5.toml
@@ -0,0 +1,8 @@
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 5
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/docs/make_plots.py b/baselines/fedrep/docs/make_plots.py
new file mode 100644
index 000000000000..9474e87a2471
--- /dev/null
+++ b/baselines/fedrep/docs/make_plots.py
@@ -0,0 +1,50 @@
+"""Generate plots from json files."""
+
+import json
+import os
+from typing import List, Tuple
+
+import matplotlib.pyplot as plt
+
+# Get the directory of this script
+DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def read_from_results(path: str) -> Tuple[str, str, List[float], str, str]:
+    """Load the json file with recorded configurations and results."""
+    with open(path, "r", encoding="UTF-8") as fin:
+        data = json.load(fin)
+    algorithm = data["run_config"]["algorithm"]
+    model = data["run_config"]["model-name"]
+    accuracies = [res["accuracy"] * 100 for res in data["round_res"]]
+    dataset = data["run_config"]["dataset-name"]
+    num_classes = data["run_config"]["dataset-split-num-classes"]
+
+    return algorithm, model, accuracies, dataset, num_classes
+
+
+def make_plot(dir_path: str, plt_title: str) -> None:
+    """Given a directory with json files, generate a plot using the provided title."""
+    plt.figure()
+    with os.scandir(dir_path) as files:
+        for file in files:
+            file_name = os.path.join(dir_path, file.name)
+            print(file_name, flush=True)
+            algo, m, acc, d, n = read_from_results(file_name)
+            rounds = [i + 1 for i in range(len(acc))]
+            print(f"Max accuracy ({algo}): {max(acc):.2f}")
+            plt.plot(rounds, acc, label=f"{algo}-{d}-{n}classes")
+    plt.xlabel("Rounds")
+    plt.ylabel("Accuracy")
+    plt.title(plt_title)
+    plt.grid()
+    plt.legend()
+    plt.savefig(os.path.join(DIR, f"{plt_title}-{algo}"))
+
+
+if __name__ == "__main__":
+    # Plot results generated by the baseline.
+    # Combine them into a full file path.
+    res_dir = os.path.join(DIR, "../results/")
+    title = "Federated Accuracy over Rounds"
+    make_plot(res_dir, plt_title=title)
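For reference, `make_plots.py` assumes the layout of the JSON files that `server_app.py` writes under `results/`: a `run_config` copy of the run configuration plus one `round_res` entry per evaluated round. A minimal example of such a file, with made-up accuracy and loss values, would look roughly like this:

```json
{
    "run_config": {
        "algorithm": "fedrep",
        "model-name": "cnncifar10",
        "dataset-name": "cifar10",
        "dataset-split-num-classes": 2
    },
    "round_res": [
        {"accuracy": 0.41, "loss": 1.52},
        {"accuracy": 0.56, "loss": 1.20}
    ]
}
```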
diff --git a/baselines/fedrep/fedrep/__init__.py b/baselines/fedrep/fedrep/__init__.py
index a5e567b59135..f2dbc04ee34e 100644
--- a/baselines/fedrep/fedrep/__init__.py
+++ b/baselines/fedrep/fedrep/__init__.py
@@ -1 +1 @@
-"""Template baseline package."""
+"""fedrep: A Flower Baseline."""
diff --git a/baselines/fedrep/fedrep/base_model.py b/baselines/fedrep/fedrep/base_model.py
index e6a74c01bf9b..82a18cec2622 100644
--- a/baselines/fedrep/fedrep/base_model.py
+++ b/baselines/fedrep/fedrep/base_model.py
@@ -1,61 +1,23 @@
-"""Abstract class for splitting a model into body and head."""
+"""fedrep: A Flower Baseline."""
 
-import os
+import collections
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Dict, List, OrderedDict, Tuple, Union
 
-import numpy as np
 import torch
-import torch.nn as nn
-from omegaconf import DictConfig
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils.data import DataLoader
 
-from fedrep.constants import (
+from flwr.common import Context, NDArrays, ParametersRecord, array_from_numpy
+
+from .constants import (
     DEFAULT_FINETUNE_EPOCHS,
     DEFAULT_LOCAL_TRAIN_EPOCHS,
     DEFAULT_REPRESENTATION_EPOCHS,
+    FEDREP_HEAD_STATE,
 )
 
 
-def get_device(
-    use_cuda: bool = True, specified_device: Optional[int] = None
-) -> torch.device:
-    """Get the tensor device.
-
-    Args:
-        use_cuda: Flag indicates whether to use CUDA or not. Defaults to True.
-        specified_device: Specified cuda device to use. Defaults to None.
-
-    Raises
-    ------
-    ValueError: Specified device not in CUDA_VISIBLE_DEVICES.
-
-    Returns
-    -------
-    The selected or fallbacked device.
-    """
-    device = torch.device("cpu")
-    if use_cuda and torch.cuda.is_available():
-        if specified_device is not None:
-            cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
-            if cuda_visible_devices is not None:
-                devices = [int(d) for d in cuda_visible_devices.split(",")]
-                if specified_device in devices:
-                    device = torch.device(f"cuda:{specified_device}")
-                else:
-                    raise ValueError(
-                        f"Specified device {specified_device}"
-                        " not in CUDA_VISIBLE_DEVICES"
-                    )
-            else:
-                print("CUDA_VISIBLE_DEVICES not exists, using torch.device('cuda').")
-        else:
-            device = torch.device("cuda")
-
-    return device
-
-
 class ModelSplit(ABC, nn.Module):
     """Abstract class for splitting a model into body and head."""
 
@@ -110,7 +72,7 @@ def head(self, state_dict: OrderedDict[str, Tensor]) -> None:
         """
         self._head.load_state_dict(state_dict, strict=True)
 
-    def get_parameters(self) -> List[np.ndarray]:
+    def get_parameters(self) -> NDArrays:
         """Get model parameters.
 
         Returns
@@ -164,65 +126,86 @@ class ModelManager(ABC):
 
     def __init__(
         self,
-        client_id: int,
-        config: DictConfig,
+        context: Context,
         trainloader: DataLoader,
         testloader: DataLoader,
-        client_save_path: Optional[str],
         model_split_class: Any,  # ModelSplit
     ):
         """Initialize the attributes of the model manager.
 
         Args:
-            client_id: The id of the client.
-            config: Dict containing the configurations to be used by the manager.
+            context: The context of the current run.
             trainloader: Client train dataloader.
             testloader: Client test dataloader.
-            client_save_path: Path to save the client model head state.
             model_split_class: Class to be used to split the model into body and head \
                 (concrete implementation of ModelSplit).
         """
         super().__init__()
-        self.config = config
-        self.client_id = client_id
+        self.context = context
         self.trainloader = trainloader
         self.testloader = testloader
-        self.device = get_device(
-            use_cuda=getattr(self.config, "use_cuda", True),
-            specified_device=getattr(self.config, "specified_device", None),
-        )
-        self.client_save_path = client_save_path
-        self.learning_rate = config.get("learning_rate", 0.01)
-        self.momentum = config.get("momentum", 0.5)
+        self.learning_rate = self.context.run_config.get("learning-rate", 0.01)
+        self.momentum = self.context.run_config.get("momentum", 0.5)
         self._model: ModelSplit = model_split_class(self._create_model())
 
     @abstractmethod
     def _create_model(self) -> nn.Module:
-        """Return model to be splitted into head and body."""
+        """Return model to be split into head and body."""
 
     @property
     def model(self) -> ModelSplit:
        """Return model."""
        return self._model
 
-    def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
+    def _load_client_state(self) -> None:
+        """Load client model head state from context state; used only by FedRep."""
+        # First, check if the fedrep head state is set in the context state.
+        if self.context.state.parameters_records.get(FEDREP_HEAD_STATE):
+            state_dict = collections.OrderedDict(
+                {
+                    k: torch.from_numpy(v.numpy())
+                    for k, v in self.context.state.parameters_records[
+                        FEDREP_HEAD_STATE
+                    ].items()
+                }
+            )
+            # Second, check if the parameters records have values stored and load
+            # the state; this check is useful for the first time the model is
+            # tested and the head state might be empty.
+            if state_dict:
+                self._model.head.load_state_dict(state_dict)
+
+    def _save_client_state(self) -> None:
+        """Save client model head state inside context state; used only by FedRep."""
+        # Check if the fedrep head state is set in the context state.
+        if FEDREP_HEAD_STATE in self.context.state.parameters_records:
+            head_state = self._model.head.state_dict()
+            head_state_np = {k: v.detach().cpu().numpy() for k, v in head_state.items()}
+            head_state_arr = collections.OrderedDict(
+                {k: array_from_numpy(v) for k, v in head_state_np.items()}
+            )
+            head_state_prec = ParametersRecord(head_state_arr)
+            self.context.state.parameters_records[FEDREP_HEAD_STATE] = head_state_prec
+
+    def train(
+        self, device: torch.device
+    ) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
         """Train the model maintained in self.model.
 
         Returns
         -------
             Dict containing the train metrics.
         """
-        # Load client state (head) if client_save_path is not None and it is not empty
-        if self.client_save_path is not None and os.path.isfile(self.client_save_path):
-            self._model.head.load_state_dict(torch.load(self.client_save_path))
+        # Load state.
+        self._load_client_state()
 
         num_local_epochs = DEFAULT_LOCAL_TRAIN_EPOCHS
-        if hasattr(self.config, "num_local_epochs"):
-            num_local_epochs = int(self.config.num_local_epochs)
+        if "num-local-epochs" in self.context.run_config:
+            num_local_epochs = int(self.context.run_config["num-local-epochs"])
 
         num_rep_epochs = DEFAULT_REPRESENTATION_EPOCHS
-        if hasattr(self.config, "num_rep_epochs"):
-            num_rep_epochs = int(self.config.num_rep_epochs)
+        if "num-rep-epochs" in self.context.run_config:
+            num_rep_epochs = int(self.context.run_config["num-rep-epochs"])
 
         criterion = torch.nn.CrossEntropyLoss()
         weights = [v for k, v in self._model.named_parameters() if "weight" in k]
@@ -238,6 +221,7 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
 
         correct, total = 0, 0
         loss: torch.Tensor = 0.0
 
+        self._model.to(device)
         self._model.train()
         for i in range(num_local_epochs + num_rep_epochs):
             if i < num_local_epochs:
@@ -247,10 +231,9 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
                 self._model.enable_body()
                 self._model.disable_head()
             for batch in self.trainloader:
-                images = batch["img"]
-                labels = batch["label"]
-                outputs = self._model(images.to(self.device))
-                labels = labels.to(self.device)
+                images = batch["img"].to(device)
+                labels = batch["label"].to(device)
+                outputs = self._model(images)
                 loss = criterion(outputs, labels)
                 optimizer.zero_grad()
                 loss.backward()
@@ -258,35 +241,35 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
                 total += labels.size(0)
                 correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
 
-        # Save client state (head)
-        if self.client_save_path is not None:
-            torch.save(self._model.head.state_dict(), self.client_save_path)
+        # Save state.
+        self._save_client_state()
 
         return {"loss": loss.item(), "accuracy": correct / total}
 
-    def test(self) -> Dict[str, float]:
+    def test(self, device: torch.device) -> Dict[str, float]:
         """Test the model maintained in self.model.
 
         Returns
         -------
             Dict containing the test metrics.
         """
-        # Load client state (head)
-        if self.client_save_path is not None and os.path.isfile(self.client_save_path):
-            self._model.head.load_state_dict(torch.load(self.client_save_path))
+        # Load state.
+ self._load_client_state() num_finetune_epochs = DEFAULT_FINETUNE_EPOCHS - if hasattr(self.config, "num_finetune_epochs"): - num_finetune_epochs = int(self.config.num_finetune_epochs) + if "num-finetune-epochs" in self.context.run_config: + num_finetune_epochs = int(self.context.run_config["num-finetune-epochs"]) - if num_finetune_epochs > 0 and self.config.get("enable_finetune", False): + if num_finetune_epochs > 0 and self.context.run_config.get( + "enable-finetune", False + ): optimizer = torch.optim.SGD(self._model.parameters(), lr=self.learning_rate) criterion = torch.nn.CrossEntropyLoss() self._model.train() for _ in range(num_finetune_epochs): for batch in self.trainloader: - images = batch["img"].to(self.device) - labels = batch["label"].to(self.device) + images = batch["img"] + labels = batch["label"] outputs = self._model(images) loss = criterion(outputs, labels) optimizer.zero_grad() @@ -296,11 +279,12 @@ def test(self) -> Dict[str, float]: criterion = torch.nn.CrossEntropyLoss() correct, total, loss = 0, 0, 0.0 + self._model.to(device) self._model.eval() with torch.no_grad(): for batch in self.testloader: - images = batch["img"].to(self.device) - labels = batch["label"].to(self.device) + images = batch["img"].to(device) + labels = batch["label"].to(device) outputs = self._model(images) loss += criterion(outputs, labels).item() total += labels.size(0) diff --git a/baselines/fedrep/fedrep/client.py b/baselines/fedrep/fedrep/client.py deleted file mode 100644 index f857fd2cf82a..000000000000 --- a/baselines/fedrep/fedrep/client.py +++ /dev/null @@ -1,319 +0,0 @@ -"""Client implementation - can call FedPep and FedAvg clients.""" - -from collections import OrderedDict -from pathlib import Path -from typing import Callable, Dict, List, Tuple, Type, Union - -import numpy as np -import torch -from flwr.client import Client, NumPyClient -from flwr.common import NDArrays, Scalar -from flwr_datasets import FederatedDataset -from flwr_datasets.partitioner import PathologicalPartitioner -from flwr_datasets.preprocessor import Merger -from omegaconf import DictConfig -from torch.utils.data import DataLoader -from torchvision import transforms - -from fedrep.constants import MEAN, STD, Algorithm -from fedrep.models import CNNCifar10ModelManager, CNNCifar100ModelManager - -PROJECT_DIR = Path(__file__).parent.parent.absolute() - -FEDERATED_DATASET = None - - -class BaseClient(NumPyClient): - """Implementation of Federated Averaging (FedAvg) Client.""" - - # pylint: disable=R0913 - def __init__( - self, - client_id: int, - trainloader: DataLoader, - testloader: DataLoader, - config: DictConfig, - model_manager_class: Union[ - Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager] - ], - client_state_save_path: str = "", - ): - """Initialize client attributes. - - Args: - client_id: The client ID. - trainloader: Client train data loader. - testloader: Client test data loader. - config: dictionary containing the client configurations. - model_manager_class: class to be used as the model manager. - client_state_save_path: Path for saving model head parameters. - (Just for FedRep). Defaults to "". 
- """ - super().__init__() - - self.client_id = client_id - self.client_state_save_path = ( - (client_state_save_path + f"/client_{self.client_id}") - if client_state_save_path != "" - else None - ) - self.model_manager = model_manager_class( - client_id=self.client_id, - config=config, - trainloader=trainloader, - testloader=testloader, - client_save_path=self.client_state_save_path, - ) - - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: - """Return the current local model parameters.""" - return self.model_manager.model.get_parameters() - - def set_parameters( - self, parameters: List[np.ndarray], evaluate: bool = False - ) -> None: - """Set the local model parameters to the received parameters. - - Args: - parameters: parameters to set the model to. - """ - _ = evaluate - model_keys = [ - k - for k in self.model_manager.model.state_dict().keys() - if k.startswith("_body") or k.startswith("_head") - ] - params_dict = zip(model_keys, parameters) - - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - - self.model_manager.model.set_parameters(state_dict) - - def perform_train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]: - """Perform local training to the whole model. - - Returns - ------- - Dict with the train metrics. - """ - self.model_manager.model.enable_body() - self.model_manager.model.enable_head() - - return self.model_manager.train() - - def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Union[bool, bytes, float, int, str]]]: - """Train the provided parameters using the locally held dataset. - - Args: - parameters: The current (global) model parameters. - config: configuration parameters for training sent by the server. - - Returns - ------- - Tuple containing the locally updated model parameters, \ - the number of examples used for training and \ - the training metrics. - """ - self.set_parameters(parameters) - - train_results = self.perform_train() - - # Update train history - print("<------- TRAIN RESULTS -------> :", train_results) - - return self.get_parameters(config), self.model_manager.train_dataset_size(), {} - - def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Union[bool, bytes, float, int, str]]]: - """Evaluate the provided global parameters using the locally held dataset. - - Args: - parameters: The current (global) model parameters. - config: configuration parameters for training sent by the server. - - Returns - ------- - Tuple containing the test loss, \ - the number of examples used for evaluation and \ - the evaluation metrics. - """ - self.set_parameters(parameters, evaluate=True) - - # Test the model - test_results = self.model_manager.test() - print("<------- TEST RESULTS -------> :", test_results) - - return ( - test_results.get("loss", 0.0), - self.model_manager.test_dataset_size(), - {k: v for k, v in test_results.items() if not isinstance(v, (dict, list))}, - ) - - -class FedRepClient(BaseClient): - """Implementation of Federated Personalization (FedRep) Client.""" - - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: - """Return the current local body parameters.""" - return [ - val.cpu().numpy() - for val in self.model_manager.model.body.state_dict().values() - ] - - def set_parameters(self, parameters: List[np.ndarray], evaluate=False) -> None: - """Set the local body parameters to the received parameters. - - Args: - parameters: parameters to set the body to. 
- evaluate: whether the client is evaluating or not. - """ - model_keys = [ - k - for k in self.model_manager.model.state_dict().keys() - if k.startswith("_body") - ] - - if not evaluate: - # Only update client's local head if it hasn't trained yet - model_keys.extend( - [ - k - for k in self.model_manager.model.state_dict().keys() - if k.startswith("_head") - ] - ) - - state_dict = OrderedDict( - (k, torch.from_numpy(v)) for k, v in zip(model_keys, parameters) - ) - - self.model_manager.model.set_parameters(state_dict) - - -# pylint: disable=E1101, W0603 -def get_client_fn_simulation( - config: DictConfig, client_state_save_path: str = "" -) -> Callable[[str], Client]: - """Generate the client function that creates the Flower Clients. - - Parameters - ---------- - model : DictConfig - The model configuration. - cleint_state_save_path : str - The path to save the client state. - - Returns - ------- - Tuple[Callable[[str], FlowerClient], DataLoader] - A tuple containing the client function that creates Flower Clients and - the DataLoader that will be used for testing - """ - assert config.model_name.lower() in [ - "cnncifar10", - "cnncifar100", - ], f"Model {config.model_name} not implemented" - - # - you can define your own data transformation strategy here - - # These transformations are from the official repo - train_data_transform = transforms.Compose( - [ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(MEAN[config.dataset.name], STD[config.dataset.name]), - ] - ) - test_data_transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(MEAN[config.dataset.name], STD[config.dataset.name]), - ] - ) - - use_fine_label = False - if config.dataset.name.lower() == "cifar100": - use_fine_label = True - - partitioner = PathologicalPartitioner( - num_partitions=config.num_clients, - partition_by="fine_label" if use_fine_label else "label", - num_classes_per_partition=config.dataset.num_classes, - class_assignment_mode="random", - shuffle=True, - seed=config.dataset.seed, - ) - - global FEDERATED_DATASET - if FEDERATED_DATASET is None: - FEDERATED_DATASET = FederatedDataset( - dataset=config.dataset.name.lower(), - partitioners={"all": partitioner}, - preprocessor=Merger({"all": ("train", "test")}), - ) - - def apply_train_transforms(batch): - """Apply transforms for train data to the partition from FederatedDataset.""" - batch["img"] = [train_data_transform(img) for img in batch["img"]] - if use_fine_label: - batch["label"] = batch["fine_label"] - return batch - - def apply_test_transforms(batch): - """Apply transforms for test data to the partition from FederatedDataset.""" - batch["img"] = [test_data_transform(img) for img in batch["img"]] - if use_fine_label: - batch["label"] = batch["fine_label"] - return batch - - # pylint: disable=E1101 - def client_fn(cid: str) -> Client: - """Create a Flower client representing a single organization.""" - cid_use = int(cid) - - partition = FEDERATED_DATASET.load_partition(cid_use, split="all") - - partition_train_test = partition.train_test_split( - train_size=config.dataset.fraction, shuffle=True, seed=config.dataset.seed - ) - - trainset = partition_train_test["train"].with_transform(apply_train_transforms) - testset = partition_train_test["test"].with_transform(apply_test_transforms) - - trainloader = DataLoader(trainset, config.batch_size, shuffle=True) - testloader = DataLoader(testset, config.batch_size) - - model_manager_class: Union[ - 
Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager] - ] - if config.model_name.lower() == "cnncifar10": - model_manager_class = CNNCifar10ModelManager - elif config.model_name.lower() == "cnncifar100": - model_manager_class = CNNCifar100ModelManager - else: - raise NotImplementedError( - f"Model {config.model_name} not implemented, check name." - ) - - if config.algorithm.lower() == Algorithm.FEDREP.value: - return FedRepClient( # type: ignore[attr-defined] - client_id=cid_use, - trainloader=trainloader, - testloader=testloader, - config=config, - model_manager_class=model_manager_class, - client_state_save_path=client_state_save_path, - ).to_client() - return BaseClient( # type: ignore[attr-defined] - client_id=cid_use, - trainloader=trainloader, - testloader=testloader, - config=config, - model_manager_class=model_manager_class, - client_state_save_path=client_state_save_path, - ).to_client() - - return client_fn diff --git a/baselines/fedrep/fedrep/client_app.py b/baselines/fedrep/fedrep/client_app.py new file mode 100644 index 000000000000..78755f4484aa --- /dev/null +++ b/baselines/fedrep/fedrep/client_app.py @@ -0,0 +1,186 @@ +"""fedrep: A Flower Baseline.""" + +from collections import OrderedDict +from typing import Dict, List, Tuple, Union + +import torch + +from flwr.client import ClientApp, NumPyClient +from flwr.client.client import Client +from flwr.common import Context, NDArrays, ParametersRecord, Scalar + +from .constants import FEDREP_HEAD_STATE, Algorithm +from .dataset import load_data +from .models import CNNCifar10ModelManager, CNNCifar100ModelManager +from .utils import get_model_manager_class + + +class BaseClient(NumPyClient): + """Implementation of Federated Averaging (FedAvg) Client.""" + + # pylint: disable=R0913 + def __init__( + self, model_manager: Union[CNNCifar10ModelManager, CNNCifar100ModelManager] + ): + """Initialize client attributes. + + Args: + model_manager: the model manager object + """ + super().__init__() + self.model_manager = model_manager + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + """Return the current local model parameters.""" + return self.model_manager.model.get_parameters() + + def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None: + """Set the local model parameters to the received parameters. + + Args: + parameters: parameters to set the model to. + evaluate: whether to evaluate or not. + """ + _ = evaluate + model_keys = [ + k + for k in self.model_manager.model.state_dict().keys() + if k.startswith("_body") or k.startswith("_head") + ] + params_dict = zip(model_keys, parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + self.model_manager.model.set_parameters(state_dict) + + def perform_train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]: + """Perform local training to the whole model. + + Returns + ------- + Dict with the train metrics. + """ + self.model_manager.model.enable_body() + self.model_manager.model.enable_head() + + return self.model_manager.train(self.device) + + def fit( + self, parameters: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[NDArrays, int, Dict[str, Union[bool, bytes, float, int, str]]]: + """Train the provided parameters using the locally held dataset. + + Args: + parameters: The current (global) model parameters. + config: configuration parameters for training sent by the server. 
+ + Returns + ------- + Tuple containing the locally updated model parameters, \ + the number of examples used for training and \ + the training metrics. + """ + self.set_parameters(parameters) + self.perform_train() + + return self.get_parameters(config), self.model_manager.train_dataset_size(), {} + + def evaluate( + self, parameters: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[float, int, Dict[str, Union[bool, bytes, float, int, str]]]: + """Evaluate the provided global parameters using the locally held dataset. + + Args: + parameters: The current (global) model parameters. + config: configuration parameters for training sent by the server. + + Returns + ------- + Tuple containing the test loss, \ + the number of examples used for evaluation and \ + the evaluation metrics. + """ + self.set_parameters(parameters, evaluate=True) + + # Test the model + test_results = self.model_manager.test(self.device) + + return ( + test_results.get("loss", 0.0), + self.model_manager.test_dataset_size(), + {k: v for k, v in test_results.items() if not isinstance(v, (dict, list))}, + ) + + +class FedRepClient(BaseClient): + """Implementation of Federated Personalization (FedRep) Client.""" + + def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + """Return the current local body parameters.""" + return [ + val.cpu().numpy() + for val in self.model_manager.model.body.state_dict().values() + ] + + def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None: + """Set the local body parameters to the received parameters. + + Args: + parameters: parameters to set the body to. + evaluate: whether the client is evaluating or not. + """ + model_keys = [ + k + for k in self.model_manager.model.state_dict().keys() + if k.startswith("_body") + ] + + if not evaluate: + # Only update client's local head if it hasn't trained yet + model_keys.extend( + [ + k + for k in self.model_manager.model.state_dict().keys() + if k.startswith("_head") + ] + ) + + state_dict = OrderedDict( + (k, torch.from_numpy(v)) for k, v in zip(model_keys, parameters) + ) + + self.model_manager.model.set_parameters(state_dict) + + +def client_fn(context: Context) -> Client: + """Construct a Client that will be run in a ClientApp.""" + model_manager_class = get_model_manager_class(context) + algorithm = str(context.run_config["algorithm"]).lower() + partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) + trainloader, valloader = load_data( + partition_id, num_partitions, context + ) # load the data + if algorithm == Algorithm.FEDAVG.value: + client_class = BaseClient + elif algorithm == Algorithm.FEDREP.value: + # This state variable will only be used by the FedRep algorithm. + # We only need to initialize once, since client_fn will be called + # again at every invocation of the ClientApp. + if FEDREP_HEAD_STATE not in context.state.parameters_records: + context.state.parameters_records[FEDREP_HEAD_STATE] = ParametersRecord() + client_class = FedRepClient + else: + raise RuntimeError(f"Unknown algorithm {algorithm}.") + + model_manager_obj = model_manager_class( + context=context, trainloader=trainloader, testloader=valloader + ) + + # Return client object. 
+ client = client_class(model_manager_obj).to_client() + return client + + +# Flower ClientApp +app = ClientApp(client_fn) diff --git a/baselines/fedrep/fedrep/conf/base.yaml b/baselines/fedrep/fedrep/conf/base.yaml deleted file mode 100644 index 0d74c4fe78b6..000000000000 --- a/baselines/fedrep/fedrep/conf/base.yaml +++ /dev/null @@ -1,46 +0,0 @@ ---- -num_clients: 100 # total number of clients -num_local_epochs: 5 # number of local epochs -num_rep_epochs: 1 # number of representation epochs (only for FedRep) -enable_finetune: false -# num_finetune_epochs: 10 -batch_size: 50 -num_rounds: 100 -learning_rate: 0.01 -momentum: 0.5 -algorithm: fedrep -model_name: cnncifar10 - -client_resources: - num_cpus: 2 - num_gpus: 0.5 - -use_cuda: true -specified_device: null # the ID of cuda device, if null, then use defaults torch.device("cuda") - -dataset: - name: cifar10 - split: sample - num_classes: 2 - seed: 42 - num_clients: ${num_clients} - fraction: 0.83 - -model: - _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar10 - -fit_config: - drop_client: false - epochs: ${num_local_epochs} - batch_size: ${batch_size} - -strategy: - _target_: fedrep.strategy.FedRep - fraction_fit: 0.1 - fraction_evaluate: 0.1 - min_fit_clients: 2 - min_evaluate_clients: 2 - min_available_clients: 2 - evaluate_fn: null - on_fit_config_fn: null - on_evaluate_config_fn: null diff --git a/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml b/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml deleted file mode 100644 index 30f9fd209d58..000000000000 --- a/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml +++ /dev/null @@ -1,44 +0,0 @@ ---- -num_clients: 100 # total number of clients -num_local_epochs: 5 # number of local epochs -num_rep_epochs: 1 # number of representation epochs (only for FedRep) -enable_finetune: false -# num_finetune_epochs: 10 -batch_size: 50 -num_rounds: 100 -learning_rate: 0.01 -momentum: 0.5 -algorithm: fedrep -model_name: cnncifar100 - -client_resources: - num_cpus: 2 - num_gpus: 0.5 - -use_cuda: true -specified_device: null - -dataset: - name: cifar100 - num_classes: 20 - seed: 42 - fraction: 0.83 - -model: - _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar100 - -fit_config: - drop_client: false - epochs: ${num_local_epochs} - batch_size: ${batch_size} - -strategy: - _target_: fedrep.strategy.FedRep - fraction_fit: 0.1 - fraction_evaluate: 0.1 - min_fit_clients: 2 - min_evaluate_clients: 2 - min_available_clients: 2 - evaluate_fn: null - on_fit_config_fn: null - on_evaluate_config_fn: null diff --git a/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml b/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml deleted file mode 100644 index e0add8f03b45..000000000000 --- a/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml +++ /dev/null @@ -1,44 +0,0 @@ ---- -num_clients: 100 # total number of clients -num_local_epochs: 5 # number of local epochs -num_rep_epochs: 1 # number of representation epochs (only for FedRep) -enable_finetune: false -# num_finetune_epochs: 10 -batch_size: 50 -num_rounds: 100 -learning_rate: 0.01 -momentum: 0.5 -algorithm: fedrep -model_name: cnncifar100 - -client_resources: - num_cpus: 2 - num_gpus: 0.5 - -use_cuda: true -specified_device: null - -dataset: - name: cifar100 - num_classes: 5 - seed: 42 - fraction: 0.83 - -model: - _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar100 - -fit_config: - drop_client: false - epochs: ${num_local_epochs} - batch_size: ${batch_size} - -strategy: - _target_: fedrep.strategy.FedRep - fraction_fit: 0.1 - fraction_evaluate: 0.1 - 
min_fit_clients: 2 - min_evaluate_clients: 2 - min_available_clients: 2 - evaluate_fn: null - on_fit_config_fn: null - on_evaluate_config_fn: null diff --git a/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml b/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml deleted file mode 100644 index 83ee34a298ae..000000000000 --- a/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml +++ /dev/null @@ -1,44 +0,0 @@ ---- -num_clients: 100 # total number of clients -num_local_epochs: 5 # number of local epochs -num_rep_epochs: 1 # number of representation epochs (only for FedRep) -enable_finetune: false -# num_finetune_epochs: 10 -batch_size: 50 -num_rounds: 100 -learning_rate: 0.01 -momentum: 0.5 -algorithm: fedrep -model_name: cnncifar10 - -client_resources: - num_cpus: 2 - num_gpus: 0.5 - -use_cuda: true -specified_device: null - -dataset: - name: cifar10 - num_classes: 2 - seed: 42 - fraction: 0.83 - -model: - _target_: fedrep.implemented_models.cnn_cifar10.CNNCifar10 - -fit_config: - drop_client: false - epochs: ${num_local_epochs} - batch_size: ${batch_size} - -strategy: - _target_: fedrep.strategy.FedRep - fraction_fit: 0.1 - fraction_evaluate: 0.1 - min_fit_clients: 2 - min_evaluate_clients: 2 - min_available_clients: 2 - evaluate_fn: null - on_fit_config_fn: null - on_evaluate_config_fn: null diff --git a/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml b/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml deleted file mode 100644 index 0cbd104406f0..000000000000 --- a/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml +++ /dev/null @@ -1,44 +0,0 @@ ---- -num_clients: 100 # total number of clients -num_local_epochs: 5 # number of local epochs -num_rep_epochs: 1 # number of representation epochs (only for FedRep) -enable_finetune: false -# num_finetune_epochs: 10 -batch_size: 50 -num_rounds: 100 -learning_rate: 0.01 -momentum: 0.5 -algorithm: fedrep -model_name: cnncifar10 - -client_resources: - num_cpus: 2 - num_gpus: 0.5 - -use_cuda: true -specified_device: null - -dataset: - name: cifar10 - num_classes: 5 - seed: 42 - fraction: 0.83 - -model: - _target_: fedrep.implemented_models.cnn_cifar10.CNNCifar10 - -fit_config: - drop_client: false - epochs: ${num_local_epochs} - batch_size: ${batch_size} - -strategy: - _target_: fedrep.strategy.FedRep - fraction_fit: 0.1 - fraction_evaluate: 0.1 - min_fit_clients: 2 - min_evaluate_clients: 2 - min_available_clients: 2 - evaluate_fn: null - on_fit_config_fn: null - on_evaluate_config_fn: null diff --git a/baselines/fedrep/fedrep/constants.py b/baselines/fedrep/fedrep/constants.py index 27e68f2b786c..a4527e1f36d4 100644 --- a/baselines/fedrep/fedrep/constants.py +++ b/baselines/fedrep/fedrep/constants.py @@ -1,7 +1,16 @@ -"""Constants used in machine learning pipeline.""" +"""fedrep: A Flower Baseline.""" from enum import Enum +DEFAULT_LOCAL_TRAIN_EPOCHS: int = 10 +DEFAULT_FINETUNE_EPOCHS: int = 5 +DEFAULT_REPRESENTATION_EPOCHS: int = 1 +FEDREP_HEAD_STATE = "fedrep_head_state" + +MEAN = {"cifar10": [0.485, 0.456, 0.406], "cifar100": [0.507, 0.487, 0.441]} + +STD = {"cifar10": [0.229, 0.224, 0.225], "cifar100": [0.267, 0.256, 0.276]} + class Algorithm(Enum): """Algorithm names.""" @@ -10,10 +19,8 @@ class Algorithm(Enum): FEDAVG = "fedavg" -DEFAULT_LOCAL_TRAIN_EPOCHS: int = 10 -DEFAULT_FINETUNE_EPOCHS: int = 5 -DEFAULT_REPRESENTATION_EPOCHS: int = 1 +class ModelDatasetName(Enum): + """Dataset names.""" -MEAN = {"cifar10": [0.485, 0.456, 0.406], "cifar100": [0.507, 0.487, 0.441]} - -STD = {"cifar10": [0.229, 0.224, 0.225], "cifar100": [0.267, 0.256, 0.276]} + 
CNN_CIFAR_10 = "cnncifar10" + CNN_CIFAR_100 = "cnncifar100" diff --git a/baselines/fedrep/fedrep/dataset.py b/baselines/fedrep/fedrep/dataset.py index a616e38ae220..621baf98249a 100644 --- a/baselines/fedrep/fedrep/dataset.py +++ b/baselines/fedrep/fedrep/dataset.py @@ -1 +1,107 @@ -"""FedRep uses flwr-datasets.""" +"""fedrep: A Flower Baseline.""" + +from typing import Tuple + +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import PathologicalPartitioner +from flwr_datasets.preprocessor import Merger +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + +from flwr.common import Context + +from .constants import MEAN, STD + +FDS = None # Cache FederatedDataset + + +def load_data( + partition_id: int, num_partitions: int, context: Context +) -> Tuple[DataLoader, DataLoader]: + """Split the data and return training and testing data for the specified partition. + + Parameters + ---------- + partition_id : int + Partition number for which to retrieve the corresponding data. + num_partitions : int + Total number of partitions. + context: Context + the context of the current run. + + Returns + ------- + data : Tuple[DataLoader, DataLoader] + A tuple with the training and testing data for the current partition_id. + """ + batch_size = int(context.run_config["batch-size"]) + dataset_name = str(context.run_config["dataset-name"]).lower() + dataset_split_num_classes = int(context.run_config["dataset-split-num-classes"]) + dataset_split_seed = int(context.run_config["dataset-split-seed"]) + dataset_split_fraction = float(context.run_config["dataset-split-fraction"]) + + # - you can define your own data transformation strategy here - + # These transformations are from the official repo + train_data_transform = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(MEAN[dataset_name], STD[dataset_name]), + ] + ) + test_data_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(MEAN[dataset_name], STD[dataset_name]), + ] + ) + + use_fine_label = False + if dataset_name == "cifar100": + use_fine_label = True + + partitioner = PathologicalPartitioner( + num_partitions=num_partitions, + partition_by="fine_label" if use_fine_label else "label", + num_classes_per_partition=dataset_split_num_classes, + class_assignment_mode="random", + shuffle=True, + seed=dataset_split_seed, + ) + + global FDS # pylint: disable=global-statement + if FDS is None: + FDS = FederatedDataset( + dataset=dataset_name, + partitioners={"all": partitioner}, + preprocessor=Merger({"all": ("train", "test")}), + ) + + def apply_train_transforms(batch: Dataset) -> Dataset: + """Apply transforms for train data to the partition from FederatedDataset.""" + batch["img"] = [train_data_transform(img) for img in batch["img"]] + if use_fine_label: + batch["label"] = batch["fine_label"] + return batch + + def apply_test_transforms(batch: Dataset) -> Dataset: + """Apply transforms for test data to the partition from FederatedDataset.""" + batch["img"] = [test_data_transform(img) for img in batch["img"]] + if use_fine_label: + batch["label"] = batch["fine_label"] + return batch + + partition = FDS.load_partition(partition_id, split="all") + + partition_train_test = partition.train_test_split( + train_size=dataset_split_fraction, shuffle=True, seed=dataset_split_seed + ) + + trainset = partition_train_test["train"].with_transform(apply_train_transforms) + testset = 
partition_train_test["test"].with_transform(apply_test_transforms) + + trainloader = DataLoader(trainset, batch_size, shuffle=True) + testloader = DataLoader(testset, batch_size) + + return trainloader, testloader diff --git a/baselines/fedrep/fedrep/dataset_preparation.py b/baselines/fedrep/fedrep/dataset_preparation.py deleted file mode 100644 index a616e38ae220..000000000000 --- a/baselines/fedrep/fedrep/dataset_preparation.py +++ /dev/null @@ -1 +0,0 @@ -"""FedRep uses flwr-datasets.""" diff --git a/baselines/fedrep/fedrep/main.py b/baselines/fedrep/fedrep/main.py deleted file mode 100644 index 223b98aa21fa..000000000000 --- a/baselines/fedrep/fedrep/main.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Create and connect the building blocks for your experiments; start the simulation. - -It includes processioning the dataset, instantiate strategy, specify how the global -model is going to be evaluated, etc. At the end, this script saves the results. -""" - -from pathlib import Path -from typing import List, Tuple - -import flwr as fl -import hydra -from flwr.common.parameter import ndarrays_to_parameters -from flwr.common.typing import Metrics -from hydra.core.hydra_config import HydraConfig -from hydra.utils import instantiate -from omegaconf import DictConfig, OmegaConf - -from fedrep.utils import ( - get_client_fn, - get_create_model_fn, - plot_metric_from_history, - save_results_as_pickle, - set_client_state_save_path, - set_client_strategy, -) - - -@hydra.main(config_path="conf", config_name="base", version_base=None) -def main(cfg: DictConfig) -> None: - """Run the baseline. - - Parameterss - ---------- - cfg : DictConfig - An omegaconf object that stores the hydra config. - """ - # Print parsed config - print(OmegaConf.to_yaml(cfg)) - - # set client strategy - cfg = set_client_strategy(cfg) - - # Create directory to store client states if it does not exist - # Client state has subdirectories with the name of current time - client_state_save_path = set_client_state_save_path() - - # Define your clients - # Get client function - client_fn = get_client_fn(config=cfg, client_state_save_path=client_state_save_path) - - # get a function that will be used to construct the config that the client's - # fit() method will received - def get_on_fit_config(): - def fit_config_fn(server_round: int): - # resolve and convert to python dict - fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True) - _ = server_round - return fit_config - - return fit_config_fn - - # get a function that will be used to construct the model - create_model, split = get_create_model_fn(cfg) - - model = split(create_model()) - - def evaluate_metrics_aggregation_fn( - eval_metrics: List[Tuple[int, Metrics]] - ) -> Metrics: - weights, accuracies = [], [] - for num_examples, metric in eval_metrics: - weights.append(num_examples) - accuracies.append(metric["accuracy"] * num_examples) - accuracy = sum(accuracies) / sum(weights) # type: ignore[arg-type] - return {"accuracy": accuracy} - - # Define your strategy - strategy = instantiate( - cfg.strategy, - initial_parameters=ndarrays_to_parameters(model.get_parameters()), - on_fit_config_fn=get_on_fit_config(), - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, - ) - - # Start Simulation - history = fl.simulation.start_simulation( - client_fn=client_fn, - num_clients=cfg.num_clients, - config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), - client_resources={ - "num_cpus": cfg.client_resources.num_cpus, - "num_gpus": cfg.client_resources.num_gpus, - }, - 
strategy=strategy, - ) - - # Experiment completed. Now we save the results and - # generate plots using the `history` - print("................") - print(history) - - # Save your results - save_path = Path(HydraConfig.get().runtime.output_dir) - - # save results as a Python pickle using a file_path - # the directory created by Hydra for each run - save_results_as_pickle(history, file_path=save_path) - # plot results and include them in the readme - strategy_name = strategy.__class__.__name__ - file_suffix: str = ( - f"_{strategy_name}" - f"_C={cfg.num_clients}" - f"_B={cfg.batch_size}" - f"_E={cfg.num_local_epochs}" - f"_R={cfg.num_rounds}" - f"_lr={cfg.learning_rate}" - ) - - plot_metric_from_history(history, save_path, (file_suffix)) - - -if __name__ == "__main__": - main() diff --git a/baselines/fedrep/fedrep/models.py b/baselines/fedrep/fedrep/models.py index b230f4e49766..314d7cc91688 100644 --- a/baselines/fedrep/fedrep/models.py +++ b/baselines/fedrep/fedrep/models.py @@ -1,11 +1,11 @@ -"""Model, model manager and model split for CIFAR-10 and CIFAR-100.""" +"""fedrep: A Flower Baseline.""" from typing import Tuple import torch -import torch.nn as nn +from torch import nn -from fedrep.base_model import ModelManager, ModelSplit +from .base_model import ModelManager, ModelSplit # pylint: disable=W0223 @@ -58,17 +58,12 @@ class CNNCifar10ModelManager(ModelManager): """Manager for models with Body/Head split.""" def __init__(self, **kwargs): - """Initialize the attributes of the model manager. - - Args: - client_id: The id of the client. - config: Dict containing the configurations to be used by the manager. - """ + """Initialize the attributes of the model manager.""" super().__init__(model_split_class=CNNCifar10ModelSplit, **kwargs) def _create_model(self) -> nn.Module: - """Return CNNCifar10 model to be splitted into head and body.""" - return CNNCifar10().to(self.device) + """Return CNNCifar10 model to be split into head and body.""" + return CNNCifar10() # pylint: disable=W0223 @@ -104,6 +99,11 @@ def __init__(self): self.head = nn.Sequential(nn.Linear(128, 100)) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the model.""" + x = self.body(x) + return self.head(x) + class CNNCifar100ModelSplit(ModelSplit): """Split CNNCifar100 model into body and head.""" @@ -117,14 +117,9 @@ class CNNCifar100ModelManager(ModelManager): """Manager for models with Body/Head split.""" def __init__(self, **kwargs): - """Initialize the attributes of the model manager. - - Args: - client_id: The id of the client. - config: Dict containing the configurations to be used by the manager. 
- """ + """Initialize the attributes of the model manager.""" super().__init__(model_split_class=CNNCifar100ModelSplit, **kwargs) def _create_model(self) -> CNNCifar100: - """Return CNNCifar100 model to be splitted into head and body.""" - return CNNCifar100().to(self.device) + """Return CNNCifar100 model to be split into head and body.""" + return CNNCifar100() diff --git a/baselines/fedrep/fedrep/server.py b/baselines/fedrep/fedrep/server.py deleted file mode 100644 index 5b0c34035ae6..000000000000 --- a/baselines/fedrep/fedrep/server.py +++ /dev/null @@ -1 +0,0 @@ -"""Server strategies pipelines for FedRep.""" diff --git a/baselines/fedrep/fedrep/server_app.py b/baselines/fedrep/fedrep/server_app.py new file mode 100644 index 000000000000..0f98a3bc8876 --- /dev/null +++ b/baselines/fedrep/fedrep/server_app.py @@ -0,0 +1,80 @@ +"""fedrep: A Flower Baseline.""" + +import json +import os +import time +from typing import Dict, List, Tuple + +from flwr.common import Context, Metrics, ndarrays_to_parameters +from flwr.server import ServerApp, ServerAppComponents, ServerConfig + +from .utils import get_create_model_fn, get_server_strategy + +RESULTS_FILE = "result-{}.json" + + +def config_json_file(context: Context) -> None: + """Initialize the json file and write the run configurations.""" + # Initialize the execution results directory. + res_save_path = "./results" + if not os.path.exists(res_save_path): + os.makedirs(res_save_path) + res_save_name = time.strftime("%Y-%m-%d-%H-%M-%S") + # Set the date and full path of the file to save the results. + global RESULTS_FILE # pylint: disable=global-statement + RESULTS_FILE = RESULTS_FILE.format(res_save_name) + RESULTS_FILE = f"{res_save_path}/{RESULTS_FILE}" + data = { + "run_config": dict(context.run_config.items()), + "round_res": [], + } + with open(RESULTS_FILE, "w+", encoding="UTF-8") as fout: + json.dump(data, fout, indent=4) + + +def write_res(new_res: Dict[str, float]) -> None: + """Load the json file, append result and re-write json collection.""" + with open(RESULTS_FILE, "r", encoding="UTF-8") as fin: + data = json.load(fin) + data["round_res"].append(new_res) + + # Write the updated data back to the JSON file + with open(RESULTS_FILE, "w", encoding="UTF-8") as fout: + json.dump(data, fout, indent=4) + + +def evaluate_metrics_aggregation_fn(eval_metrics: List[Tuple[int, Metrics]]) -> Metrics: + """Weighted metrics evaluation.""" + weights, accuracies, losses = [], [], [] + for num_examples, metric in eval_metrics: + weights.append(num_examples) + accuracies.append(float(metric["accuracy"]) * num_examples) + losses.append(float(metric["loss"]) * num_examples) + accuracy = sum(accuracies) / sum(weights) + loss = sum(losses) / sum(weights) + write_res({"accuracy": accuracy, "loss": loss}) + return {"accuracy": accuracy} + + +def server_fn(context: Context) -> ServerAppComponents: + """Construct components that set the ServerApp behaviour.""" + config_json_file(context) + # Read from config + num_rounds = context.run_config["num-server-rounds"] + + # Initialize model parameters + create_model_fn, split_class = get_create_model_fn(context) + net = split_class(create_model_fn()) + parameters = ndarrays_to_parameters(net.get_parameters()) + + # Define strategy + strategy = get_server_strategy( + context=context, params=parameters, eval_fn=evaluate_metrics_aggregation_fn + ) + config = ServerConfig(num_rounds=int(num_rounds)) + + return ServerAppComponents(strategy=strategy, config=config) + + +# Create ServerApp +app = 
ServerApp(server_fn=server_fn) diff --git a/baselines/fedrep/fedrep/strategy.py b/baselines/fedrep/fedrep/strategy.py index 3bee45326a6f..ba4d8b75724f 100644 --- a/baselines/fedrep/fedrep/strategy.py +++ b/baselines/fedrep/fedrep/strategy.py @@ -1,4 +1,4 @@ -"""FL server strategies.""" +"""fedrep: A Flower Baseline.""" from flwr.server.strategy import FedAvg diff --git a/baselines/fedrep/fedrep/utils.py b/baselines/fedrep/fedrep/utils.py index b706ebf1e041..086abde2c46c 100644 --- a/baselines/fedrep/fedrep/utils.py +++ b/baselines/fedrep/fedrep/utils.py @@ -1,204 +1,87 @@ -"""Utility functions for FedRep.""" +"""fedrep: A Flower Baseline.""" -import logging -import os -import pickle -import time -from pathlib import Path -from secrets import token_hex -from typing import Callable, Optional, Type, Union +from typing import Callable, Type, Union -import matplotlib.pyplot as plt -import numpy as np -from flwr.client import Client -from flwr.server.history import History -from omegaconf import DictConfig +from flwr.common import Context, Parameters +from flwr.server.strategy import FedAvg -from fedrep.base_model import get_device -from fedrep.client import get_client_fn_simulation -from fedrep.constants import Algorithm -from fedrep.models import ( +from .constants import Algorithm, ModelDatasetName +from .models import ( CNNCifar10, + CNNCifar10ModelManager, CNNCifar10ModelSplit, CNNCifar100, + CNNCifar100ModelManager, CNNCifar100ModelSplit, ) - - -def set_client_state_save_path() -> str: - """Set the client state save path.""" - client_state_save_path = time.strftime("%Y-%m-%d") - client_state_sub_path = time.strftime("%H-%M-%S") - client_state_save_path = ( - f"./client_states/{client_state_save_path}/{client_state_sub_path}" - ) - if not os.path.exists(client_state_save_path): - os.makedirs(client_state_save_path) - return client_state_save_path - - -# pylint: disable=W1202 -def set_client_strategy(cfg: DictConfig) -> DictConfig: - """Set the client strategy.""" - algorithm = cfg.algorithm.lower() - if algorithm == Algorithm.FEDREP.value: - cfg.strategy["_target_"] = "fedrep.strategy.FedRep" - elif algorithm == Algorithm.FEDAVG.value: - cfg.strategy["_target_"] = "flwr.server.strategy.FedAvg" - else: - logging.warning( - "Algorithm {} not implemented. 
Fallback to FedAvg.".format(algorithm) - ) - return cfg - - -def get_client_fn( - config: DictConfig, client_state_save_path: str = "" -) -> Callable[[str], Client]: - """Get client function.""" - # Get algorithm - algorithm = config.algorithm.lower() - # Get client fn - if algorithm == "fedrep": - client_fn = get_client_fn_simulation( - config=config, client_state_save_path=client_state_save_path - ) - elif algorithm == "fedavg": - client_fn = get_client_fn_simulation(config=config) - else: - raise NotImplementedError - return client_fn +from .strategy import FedRep def get_create_model_fn( - config: DictConfig, + context: Context, ) -> tuple[ - Callable[[], Union[type[CNNCifar10], type[CNNCifar100]]], - Union[type[CNNCifar10ModelSplit], type[CNNCifar100ModelSplit]], + Union[Callable[[], CNNCifar10], Callable[[], CNNCifar100]], + Union[Type[CNNCifar10ModelSplit], Type[CNNCifar100ModelSplit]], ]: """Get create model function.""" - device = get_device( - use_cuda=getattr(config, "use_cuda", True), - specified_device=getattr(config, "specified_device", None), - ) - split: Union[Type[CNNCifar10ModelSplit], Type[CNNCifar100ModelSplit]] = ( - CNNCifar10ModelSplit - ) - if config.model_name.lower() == "cnncifar10": + model_name = str(context.run_config["model-name"]) + if model_name == ModelDatasetName.CNN_CIFAR_10.value: + split = CNNCifar10ModelSplit - def create_model() -> Union[Type[CNNCifar10], Type[CNNCifar100]]: + def create_model() -> CNNCifar10: # type: ignore """Create initial CNNCifar10 model.""" - return CNNCifar10().to(device) + return CNNCifar10() - elif config.model_name.lower() == "cnncifar100": + elif model_name == ModelDatasetName.CNN_CIFAR_100.value: split = CNNCifar100ModelSplit - def create_model() -> Union[Type[CNNCifar10], Type[CNNCifar100]]: + def create_model() -> CNNCifar100: # type: ignore """Create initial CNNCifar100 model.""" - return CNNCifar100().to(device) + return CNNCifar100() else: - raise NotImplementedError("Model not implemented, check name. ") + raise NotImplementedError(f"Not a recognized model name {model_name}.") return create_model, split -def plot_metric_from_history( - hist: History, save_plot_path: Path, suffix: Optional[str] = "" -) -> None: - """Plot from Flower server History. - - Parameters - ---------- - hist : History - Object containing evaluation for all rounds. - save_plot_path : Path - Folder to save the plot to. - suffix: Optional[str] - Optional string to add at the end of the filename for the plot. 
- """ - metric_type = "distributed" - metric_dict = ( - hist.metrics_centralized - if metric_type == "centralized" - else hist.metrics_distributed - ) - try: - _, values = zip(*metric_dict["accuracy"]) - except KeyError: # If no available metric data - return - - # let's extract decentralized loss (main metric reported in FedProx paper) - rounds_loss, values_loss = zip(*hist.losses_distributed) - - _, axs = plt.subplots(nrows=2, ncols=1, sharex="row") - axs[0].plot(np.asarray(rounds_loss), np.asarray(values_loss)) # type: ignore - axs[1].plot(np.asarray(rounds_loss), np.asarray(values)) # type: ignore - - axs[0].set_ylabel("Loss") # type: ignore - axs[1].set_ylabel("Accuracy") # type: ignore - - axs[0].grid() # type: ignore - axs[1].grid() # type: ignore - # plt.title(f"{metric_type.capitalize()} Validation - MNIST") - plt.xlabel("Rounds") - # plt.legend(loc="lower right") - - plt.savefig(Path(save_plot_path) / Path(f"{metric_type}_metrics{suffix}.png")) - plt.close() - - -def save_results_as_pickle( - history: History, - file_path: Union[str, Path], - default_filename: Optional[str] = "results.pkl", -) -> None: - """Save results from simulation to pickle. - - Parameters - ---------- - history: History - History returned by start_simulation. - file_path: Union[str, Path] - Path to file to create and store both history and extra_results. - If path is a directory, the default_filename will be used. - path doesn't exist, it will be created. If file exists, a - randomly generated suffix will be added to the file name. This - is done to avoid overwritting results. - extra_results : Optional[Dict] - A dictionary containing additional results you would like - to be saved to disk. Default: {} (an empty dictionary) - default_filename: Optional[str] - File used by default if file_path points to a directory instead - to a file. Default: "results.pkl" - """ - path = Path(file_path) - - # ensure path exists - path.mkdir(exist_ok=True, parents=True) - - def _add_random_suffix(path_: Path): - """Add a random suffix to the file name.""" - print(f"File `{path_}` exists! ") - suffix = token_hex(4) - print(f"New results to be saved with suffix: {suffix}") - return path_.parent / (path_.stem + "_" + suffix + ".pkl") - - def _complete_path_with_default_name(path_: Path): - """Append the default file name to the path.""" - print("Using default filename") - if default_filename is None: - return path_ - return path_ / default_filename - - if path.is_dir(): - path = _complete_path_with_default_name(path) - - if path.is_file(): - path = _add_random_suffix(path) - - print(f"Results will be saved into: {path}") - # data = {"history": history, **extra_results} - data = {"history": history} - # save results to pickle - with open(str(path), "wb") as handle: - pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) +def get_model_manager_class( + context: Context, +) -> Union[Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]]: + """Depending on the model name type return the corresponding model manager.""" + model_name = str(context.run_config["model-name"]) + if model_name.lower() == ModelDatasetName.CNN_CIFAR_10.value: + model_manager_class = CNNCifar10ModelManager + elif model_name.lower() == ModelDatasetName.CNN_CIFAR_100.value: + model_manager_class = CNNCifar100ModelManager # type: ignore + else: + raise NotImplementedError( + f"Model {model_name} not implemented, please check model name." 
+ ) + return model_manager_class + + +def get_server_strategy( + context: Context, params: Parameters, eval_fn: Callable +) -> Union[FedAvg, FedRep]: + """Define server strategy based on input algorithm.""" + algorithm = str(context.run_config["algorithm"]).lower() + if algorithm == Algorithm.FEDAVG.value: + strategy = FedAvg + elif algorithm == Algorithm.FEDREP.value: + strategy = FedRep + else: + raise RuntimeError(f"Unknown algorithm {algorithm}.") + + # Read strategy config + fraction_fit = float(context.run_config["fraction-fit"]) + fraction_evaluate = float(context.run_config["fraction-evaluate"]) + min_available_clients = int(context.run_config["min-available-clients"]) + + strategy = strategy( + fraction_fit=float(fraction_fit), + fraction_evaluate=fraction_evaluate, + min_available_clients=min_available_clients, + initial_parameters=params, + evaluate_metrics_aggregation_fn=eval_fn, + ) # type: ignore + return strategy # type: ignore diff --git a/baselines/fedrep/pyproject.toml b/baselines/fedrep/pyproject.toml index e4c3551af19a..9a45199032ef 100644 --- a/baselines/fedrep/pyproject.toml +++ b/baselines/fedrep/pyproject.toml @@ -1,73 +1,39 @@ [build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] +[project] name = "fedrep" version = "1.0.0" -description = "Exploiting Shared Representations for Personalized Federated Learning" +description = "" license = "Apache-2.0" -authors = ["Jiahao Tan "] -readme = "README.md" -homepage = "https://flower.ai" -repository = "https://github.com/adap/flower" -documentation = "https://flower.ai" -classifiers = [ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Operating System :: MacOS :: MacOS X", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: Implementation :: CPython", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Typing :: Typed", +dependencies = [ + "flwr[simulation]>=1.13.1", + "flwr-datasets[vision]>=0.4.0", + "torch==2.2.1", + "torchvision==0.17.1", ] -[tool.poetry.dependencies] -python = ">=3.10.0, <3.11.0" # don't change this -flwr = { extras = ["simulation"], version = "1.9.0" } -hydra-core = "1.3.2" # don't change this -pandas = "^2.2.2" -matplotlib = "^3.9.0" -tqdm = "^4.66.4" -torch = "^2.2.2" -torchvision = "^0.17.2" -setuptools = "<70" -flwr-datasets = { extras = ["vision"], version = ">=0.3.0" } - -[tool.poetry.dev-dependencies] -isort = "==5.13.2" -black = "==24.2.0" -docformatter = "==1.7.5" -mypy = "==1.4.1" -pylint = "==2.8.2" -flake8 = "==3.9.2" -pytest = "==6.2.4" -pytest-watch = "==4.2.0" -ruff = "==0.0.272" -types-requests = "==2.27.7" -virtualenv = "==20.21.0" +[tool.hatch.metadata] +allow-direct-references = true + +[project.optional-dependencies] +dev = [ + "isort==5.13.2", + "black==24.2.0", + 
"docformatter==1.7.5", + "mypy==1.8.0", + "pylint==3.2.6", + "flake8==5.0.4", + "pytest==6.2.4", + "pytest-watch==4.2.0", + "ruff==0.1.9", + "types-requests==2.31.0.20240125", +] [tool.isort] -line_length = 88 -indent = " " -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true +profile = "black" +known_first_party = ["flwr"] [tool.black] line-length = 88 @@ -76,7 +42,9 @@ target-version = ["py38", "py39", "py310", "py311"] [tool.pytest.ini_options] minversion = "6.2" addopts = "-qq" -testpaths = ["flwr_baselines"] +testpaths = [ + "flwr_baselines", +] [tool.mypy] ignore_missing_imports = true @@ -84,14 +52,22 @@ strict = false plugins = "numpy.typing.mypy_plugin" [tool.pylint."MESSAGES CONTROL"] -good-names = "i,j,k,_,x,y,X,Y" -signature-mutators = "hydra.main.main" +disable = "duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y,K,N" +max-args = 10 +max-attributes = 15 +max-locals = 36 +max-branches = 20 +max-statements = 55 -[tool.pylint."TYPECHECK"] +[tool.pylint.typecheck] generated-members = "numpy.*, torch.*, tensorflow.*" [[tool.mypy.overrides]] -module = ["importlib.metadata.*", "importlib_metadata.*"] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] follow_imports = "skip" follow_imports_for_stubs = true disallow_untyped_calls = false @@ -137,3 +113,49 @@ exclude = [ [tool.ruff.pydocstyle] convention = "numpy" + +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.flwr.app] +publisher = "dimitris" + +[tool.flwr.app.components] +serverapp = "fedrep.server_app:app" +clientapp = "fedrep.client_app:app" + +[tool.flwr.app.config] +algorithm = "fedrep" + +# dataset specs +dataset-name = "cifar10" +dataset-split = "sample" +dataset-split-num-classes = 2 +dataset-split-seed = 42 +dataset-split-fraction = 0.83 + +# model specs +model-name = "cnncifar10" +batch-size = 50 +learning-rate = 0.01 +momentum = 0.5 +enable-finetune = false +num-finetune-epochs = 5 +num-local-epochs = 5 # number of local epochs +num-rep-epochs = 1 # number of representation epochs (only for FedRep) + +# server specs +num-server-rounds = 100 +fraction-fit = 0.1 +fraction-evaluate = 0.1 +min-available-clients = 2 +min-evaluate-clients = 2 +min-fit-clients = 2 + +[tool.flwr.federations] +default = "local-sim-100" + +[tool.flwr.federations.local-sim-100] +options.num-supernodes = 100 +options.backend.client-resources.num-cpus = 2 +options.backend.client-resources.num-gpus = 0.5 # GPU fraction allocated to each client