diff --git a/baselines/fedrep/README.md b/baselines/fedrep/README.md
index ece30edf0943..d67730e4065a 100644
--- a/baselines/fedrep/README.md
+++ b/baselines/fedrep/README.md
@@ -36,91 +36,90 @@ dataset: [CIFAR-10, CIFAR-100]
These two models are modified from the [official repo](https://github.com/rahulv0205/fedrep_experiments)'s. Note that the official models contain no BN (batch normalization) layers; without the help of BN layers, however, training collapses.
-Please see how models are implemented using a so called model_manager and model_split class since FedRep uses head and base layers in a neural network. These classes are defined in the `models.py` file and thereafter called when building new models in the directory `/implemented_models`. Please, extend and add new models as you wish.
+Please see how the models are implemented using the so-called model manager and model split classes, since FedRep splits a neural network into head and base (body) layers. These classes are defined in the `base_model.py` and `models.py` files. Feel free to extend them and add new models.
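+
+For a rough picture of what the split looks like (a minimal, hypothetical sketch only; the baseline's actual classes live in `base_model.py` and `models.py`), a model simply exposes a body (base) and a head, and FedRep shares the body globally while each client keeps its head local:
+
+```python
+import torch
+from torch import nn
+
+
+class TinySplitModel(nn.Module):
+    """Hypothetical toy model with an explicit body (base) / head split."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Body (base): shared representation layers, aggregated by the server.
+        self.body = nn.Sequential(
+            nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU()
+        )
+        # Head: personalized classifier, kept locally on each client.
+        self.head = nn.Linear(128, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Run the body first, then the head."""
+        return self.head(self.body(x))
+```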
**Dataset:** CIFAR-10, CIFAR-100. CIFAR-10/100 are partitioned based on the number of classes each client receives, e.g. a client allocated 4 classes could receive the classes [1, 3, 5, 9].
-**Training Hyperparameters:** The hyperparameters can be found in `conf/base.yaml` file which is the configuration file for the main script.
-
-| Description | Default Value |
-| --------------------- | ----------------------------------- |
-| `num_clients` | `100` |
-| `num_rounds` | `100` |
-| `num_local_epochs` | `5` |
-| `num_rep_epochs` | `1` |
-| `enable_finetune` | `False` |
-| `num_finetune_epochs` | `5` |
-| `use_cuda` | `true` |
-| `specified_device` | `null` |
-| `client resources` | `{'num_cpus': 2, 'num_gpus': 0.5 }` |
-| `learning_rate` | `0.01` |
-| `batch_size` | `50` |
-| `model_name` | `cnncifar10` |
-| `algorithm` | `fedrep` |
+**Training Hyperparameters:** The hyperparameters can be found in the `pyproject.toml` file, under the `[tool.flwr.app.config]` section.
+
+| Parameter               | Default Value                       |
+|-------------------------|-------------------------------------|
+| `num-server-rounds` | `100` |
+| `num-local-epochs` | `5` |
+| `num-rep-epochs` | `1` |
+| `enable-finetune` | `False` |
+| `num-finetune-epochs` | `5` |
+| `use-cuda` | `true` |
+| `specified-cuda-device` | `null` |
+| `client-resources` | `{'num-cpus': 2, 'num-gpus': 0.5 }` |
+| `learning-rate` | `0.01` |
+| `batch-size` | `50` |
+| `model-name` | `cnncifar10` |
+| `algorithm` | `fedrep` |
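+
+For reference, these keys map to entries like the following in `pyproject.toml` (an illustrative excerpt; see the shipped `pyproject.toml` for the authoritative list of keys and values):
+
+```toml
+[tool.flwr.app.config]
+algorithm = "fedrep"
+model-name = "cnncifar10"
+num-server-rounds = 100
+num-local-epochs = 5
+num-rep-epochs = 1
+enable-finetune = false
+num-finetune-epochs = 5
+learning-rate = 0.01
+batch-size = 50
+```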
## Environment Setup
-To construct the Python environment follow these steps:
+Create a new Python environment using [pyenv](https://github.com/pyenv/pyenv) and the [virtualenv plugin](https://github.com/pyenv/pyenv-virtualenv), then install the baseline project:
```bash
-# Set Python 3.10
-pyenv local 3.10.12
-# Tell poetry to use python 3.10
-poetry env use 3.10.12
+# Create the environment
+pyenv virtualenv 3.10.12 fedrep-env
-# Install the base Poetry environment
-poetry install
+# Activate it
+pyenv activate fedrep-env
-# Activate the environment
-poetry shell
+# Then install the project
+pip install -e .
```
## Running the Experiments
```
-python -m fedrep.main # this will run using the default settings in the `conf/base.yaml`
+flwr run . # this will run using the default settings in the `pyproject.toml`
```
-While the config files contain a large number of settings, the ones below are the main ones you'd likely want to modify to .
+While the config files contain a large number of settings, the ones below are the main ones you'd likely want to modify.
```bash
-algorithm: fedavg, fedrep # these are currently supported
-dataset.name: cifar10, cifar100
-dataset.num_classes: 2, 5, 20 (only for CIFAR-100)
-model_name: cnncifar10, cnncifar100
+algorithm = "fedavg", "fedrep" # these are currently supported
+dataset-name = "cifar10", "cifar100"
+dataset-split-num-classes = 2, 5, 20 (20 only for CIFAR-100)
+model-name = "cnncifar10", "cnncifar100"
```
-
+See, for instance, the configuration files for CIFAR-10 and CIFAR-100 under the `conf` directory.
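+
+Individual values can also be overridden directly on the command line via `--run-config` (the example below assumes the standard Flower CLI syntax, where string values use TOML-style quoting):
+
+```bash
+# Example: switch algorithm, model and dataset at launch time
+flwr run . --run-config 'algorithm="fedavg" model-name="cnncifar100" dataset-name="cifar100"'
+```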
## Expected Results
+The default algorithm used by all configuration files is `fedrep`. To use `fedavg`, change the `algorithm` property in the respective configuration file. The default federated environment consists of 100 clients.
+
+When the execution completes, a new `results` directory will be created, containing a JSON file with the run configuration and the per-round results.
+
+> [!NOTE]
+> All plots shown below are generated using the `docs/make_plots.py` script. The script reads all JSON files that the baseline generates inside the `results` directory.
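+
+Each results file is a small JSON document written by `server_app.py`; its rough shape is sketched below (field values are illustrative):
+
+```json
+{
+    "run_config": {"algorithm": "fedrep", "dataset-name": "cifar10", "num-server-rounds": 100},
+    "round_res": [
+        {"accuracy": 0.31, "loss": 1.92},
+        {"accuracy": 0.47, "loss": 1.58}
+    ]
+}
+```
+
+To regenerate the plots, run `python docs/make_plots.py` from the baseline's root directory.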
### CIFAR-10 (100, 2)
```
-python -m fedrep.main --config-name cifar10_100_2 algorithm=fedrep
-python -m fedrep.main --config-name cifar10_100_2 algorithm=fedavg
+flwr run . --run-config conf/cifar10_2.toml
```
### CIFAR-10 (100, 5)
```
-python -m fedrep.main --config-name cifar10_100_5 algorithm=fedrep
-python -m fedrep.main --config-name cifar10_100_5 algorithm=fedavg
+flwr run . --run-config conf/cifar10_5.toml
```
### CIFAR-100 (100, 5)
```
-python -m fedrep.main --config-name cifar100_100_5 algorithm=fedrep
-python -m fedrep.main --config-name cifar100_100_5 algorithm=fedavg
+flwr run . --run-config conf/cifar100_5.toml
```
### CIFAR-100 (100, 20)
```
-python -m fedrep.main --config-name cifar100_100_20 algorithm=fedrep
-python -m fedrep.main --config-name cifar100_100_20 algorithm=fedavg
+flwr run . --run-config conf/cifar100_20.toml
```
-
\ No newline at end of file
+
diff --git a/baselines/fedrep/_static/cifar100_100_20.png b/baselines/fedrep/_static/cifar100_100_20.png
index 2421f15ac6c6..3f97d08a4dff 100644
Binary files a/baselines/fedrep/_static/cifar100_100_20.png and b/baselines/fedrep/_static/cifar100_100_20.png differ
diff --git a/baselines/fedrep/_static/cifar100_100_5.png b/baselines/fedrep/_static/cifar100_100_5.png
index 17f25eb480c4..f22ddcedaf61 100644
Binary files a/baselines/fedrep/_static/cifar100_100_5.png and b/baselines/fedrep/_static/cifar100_100_5.png differ
diff --git a/baselines/fedrep/_static/cifar10_100_2.png b/baselines/fedrep/_static/cifar10_100_2.png
index 75ee48b2c970..1e7321f85e01 100644
Binary files a/baselines/fedrep/_static/cifar10_100_2.png and b/baselines/fedrep/_static/cifar10_100_2.png differ
diff --git a/baselines/fedrep/_static/cifar10_100_5.png b/baselines/fedrep/_static/cifar10_100_5.png
index 1d20a953f9c4..54a9143d8fd9 100644
Binary files a/baselines/fedrep/_static/cifar10_100_5.png and b/baselines/fedrep/_static/cifar10_100_5.png differ
diff --git a/baselines/fedrep/conf/cifar100_20.toml b/baselines/fedrep/conf/cifar100_20.toml
new file mode 100644
index 000000000000..2b6bb5f5e1eb
--- /dev/null
+++ b/baselines/fedrep/conf/cifar100_20.toml
@@ -0,0 +1,11 @@
+algorithm = "fedrep"
+
+# model specs
+model-name = "cnncifar100"
+
+# dataset specs
+dataset-name = "cifar100"
+dataset-split = "sample"
+dataset-split-num-classes = 20
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar100_5.toml b/baselines/fedrep/conf/cifar100_5.toml
new file mode 100644
index 000000000000..9903ea374028
--- /dev/null
+++ b/baselines/fedrep/conf/cifar100_5.toml
@@ -0,0 +1,11 @@
+algorithm = "fedrep"
+
+# model specs
+model-name = "cnncifar100"
+
+# dataset specs
+dataset-name = "cifar100"
+dataset-split = "sample"
+dataset-split-num-classes = 5
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar10_2.toml b/baselines/fedrep/conf/cifar10_2.toml
new file mode 100644
index 000000000000..307ba7101bd1
--- /dev/null
+++ b/baselines/fedrep/conf/cifar10_2.toml
@@ -0,0 +1,8 @@
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 2
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar10_5.toml b/baselines/fedrep/conf/cifar10_5.toml
new file mode 100644
index 000000000000..cd09c9e07ec2
--- /dev/null
+++ b/baselines/fedrep/conf/cifar10_5.toml
@@ -0,0 +1,8 @@
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 5
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/docs/make_plots.py b/baselines/fedrep/docs/make_plots.py
new file mode 100644
index 000000000000..9474e87a2471
--- /dev/null
+++ b/baselines/fedrep/docs/make_plots.py
@@ -0,0 +1,50 @@
+"""Generate plots from json files."""
+
+import json
+import os
+from typing import List, Tuple
+
+import matplotlib.pyplot as plt
+
+# Get the current working directory
+DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def read_from_results(path: str) -> Tuple[str, str, List[float], str, str]:
+ """Load the json file with recorded configurations and results."""
+ with open(path, "r", encoding="UTF-8") as fin:
+ data = json.load(fin)
+ algorithm = data["run_config"]["algorithm"]
+ model = data["run_config"]["model-name"]
+ accuracies = [res["accuracy"] * 100 for res in data["round_res"]]
+ dataset = data["run_config"]["dataset-name"]
+ num_classes = data["run_config"]["dataset-split-num-classes"]
+
+ return algorithm, model, accuracies, dataset, num_classes
+
+
+def make_plot(dir_path: str, plt_title: str) -> None:
+ """Given a directory with json files, generated a plot using the provided title."""
+ plt.figure()
+ with os.scandir(dir_path) as files:
+ for file in files:
+ file_name = os.path.join(dir_path, file.name)
+ print(file_name, flush=True)
+ algo, m, acc, d, n = read_from_results(file_name)
+ rounds = [i + 1 for i in range(len(acc))]
+ print(f"Max accuracy ({algo}): {max(acc):.2f}")
+ plt.plot(rounds, acc, label=f"{algo}-{d}-{n}classes")
+ plt.xlabel("Rounds")
+ plt.ylabel("Accuracy")
+ plt.title(plt_title)
+ plt.grid()
+ plt.legend()
+ plt.savefig(os.path.join(DIR, f"{plt_title}-{algo}"))
+
+
+if __name__ == "__main__":
+ # Plot results generated by the baseline.
+    # Resolve the results directory relative to this script.
+ res_dir = os.path.join(DIR, "../results/")
+ title = "Federated Accuracy over Rounds"
+ make_plot(res_dir, plt_title=title)
diff --git a/baselines/fedrep/fedrep/__init__.py b/baselines/fedrep/fedrep/__init__.py
index a5e567b59135..f2dbc04ee34e 100644
--- a/baselines/fedrep/fedrep/__init__.py
+++ b/baselines/fedrep/fedrep/__init__.py
@@ -1 +1 @@
-"""Template baseline package."""
+"""fedrep: A Flower Baseline."""
diff --git a/baselines/fedrep/fedrep/base_model.py b/baselines/fedrep/fedrep/base_model.py
index e6a74c01bf9b..82a18cec2622 100644
--- a/baselines/fedrep/fedrep/base_model.py
+++ b/baselines/fedrep/fedrep/base_model.py
@@ -1,61 +1,23 @@
-"""Abstract class for splitting a model into body and head."""
+"""fedrep: A Flower Baseline."""
-import os
+import collections
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Dict, List, OrderedDict, Tuple, Union
-import numpy as np
import torch
-import torch.nn as nn
-from omegaconf import DictConfig
-from torch import Tensor
+from torch import Tensor, nn
from torch.utils.data import DataLoader
-from fedrep.constants import (
+from flwr.common import Context, NDArrays, ParametersRecord, array_from_numpy
+
+from .constants import (
DEFAULT_FINETUNE_EPOCHS,
DEFAULT_LOCAL_TRAIN_EPOCHS,
DEFAULT_REPRESENTATION_EPOCHS,
+ FEDREP_HEAD_STATE,
)
-def get_device(
- use_cuda: bool = True, specified_device: Optional[int] = None
-) -> torch.device:
- """Get the tensor device.
-
- Args:
- use_cuda: Flag indicates whether to use CUDA or not. Defaults to True.
- specified_device: Specified cuda device to use. Defaults to None.
-
- Raises
- ------
- ValueError: Specified device not in CUDA_VISIBLE_DEVICES.
-
- Returns
- -------
- The selected or fallbacked device.
- """
- device = torch.device("cpu")
- if use_cuda and torch.cuda.is_available():
- if specified_device is not None:
- cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
- if cuda_visible_devices is not None:
- devices = [int(d) for d in cuda_visible_devices.split(",")]
- if specified_device in devices:
- device = torch.device(f"cuda:{specified_device}")
- else:
- raise ValueError(
- f"Specified device {specified_device}"
- " not in CUDA_VISIBLE_DEVICES"
- )
- else:
- print("CUDA_VISIBLE_DEVICES not exists, using torch.device('cuda').")
- else:
- device = torch.device("cuda")
-
- return device
-
-
class ModelSplit(ABC, nn.Module):
"""Abstract class for splitting a model into body and head."""
@@ -110,7 +72,7 @@ def head(self, state_dict: OrderedDict[str, Tensor]) -> None:
"""
self._head.load_state_dict(state_dict, strict=True)
- def get_parameters(self) -> List[np.ndarray]:
+ def get_parameters(self) -> NDArrays:
"""Get model parameters.
Returns
@@ -164,65 +126,86 @@ class ModelManager(ABC):
def __init__(
self,
- client_id: int,
- config: DictConfig,
+ context: Context,
trainloader: DataLoader,
testloader: DataLoader,
- client_save_path: Optional[str],
model_split_class: Any, # ModelSplit
):
"""Initialize the attributes of the model manager.
Args:
- client_id: The id of the client.
- config: Dict containing the configurations to be used by the manager.
+ context: The context of the current run.
trainloader: Client train dataloader.
testloader: Client test dataloader.
- client_save_path: Path to save the client model head state.
model_split_class: Class to be used to split the model into body and head \
(concrete implementation of ModelSplit).
"""
super().__init__()
- self.config = config
- self.client_id = client_id
+ self.context = context
self.trainloader = trainloader
self.testloader = testloader
- self.device = get_device(
- use_cuda=getattr(self.config, "use_cuda", True),
- specified_device=getattr(self.config, "specified_device", None),
- )
- self.client_save_path = client_save_path
- self.learning_rate = config.get("learning_rate", 0.01)
- self.momentum = config.get("momentum", 0.5)
+ self.learning_rate = self.context.run_config.get("learning-rate", 0.01)
+ self.momentum = self.context.run_config.get("momentum", 0.5)
self._model: ModelSplit = model_split_class(self._create_model())
@abstractmethod
def _create_model(self) -> nn.Module:
- """Return model to be splitted into head and body."""
+ """Return model to be split into head and body."""
@property
def model(self) -> ModelSplit:
"""Return model."""
return self._model
- def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
+ def _load_client_state(self) -> None:
+ """Load client model head state from context state; used only by FedRep."""
+ # First, check if the fedrep head state is set in the context state.
+ if self.context.state.parameters_records.get(FEDREP_HEAD_STATE):
+ state_dict = collections.OrderedDict(
+ {
+ k: torch.from_numpy(v.numpy())
+ for k, v in self.context.state.parameters_records[
+ FEDREP_HEAD_STATE
+ ].items()
+ }
+ )
+            # Only load the state if it actually holds values; the head state
+            # may still be empty the first time the model is tested, i.e.
+            # before any local training has taken place.
+ if state_dict:
+ self._model.head.load_state_dict(state_dict)
+
+ def _save_client_state(self) -> None:
+ """Save client model head state inside context state; used only by FedRep."""
+ # Check if the fedrep head state is set in the context state.
+ if FEDREP_HEAD_STATE in self.context.state.parameters_records:
+ head_state = self._model.head.state_dict()
+ head_state_np = {k: v.detach().cpu().numpy() for k, v in head_state.items()}
+ head_state_arr = collections.OrderedDict(
+ {k: array_from_numpy(v) for k, v in head_state_np.items()}
+ )
+ head_state_prec = ParametersRecord(head_state_arr)
+ self.context.state.parameters_records[FEDREP_HEAD_STATE] = head_state_prec
+
+ def train(
+ self, device: torch.device
+ ) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
"""Train the model maintained in self.model.
Returns
-------
Dict containing the train metrics.
"""
- # Load client state (head) if client_save_path is not None and it is not empty
- if self.client_save_path is not None and os.path.isfile(self.client_save_path):
- self._model.head.load_state_dict(torch.load(self.client_save_path))
+ # Load state.
+ self._load_client_state()
num_local_epochs = DEFAULT_LOCAL_TRAIN_EPOCHS
- if hasattr(self.config, "num_local_epochs"):
- num_local_epochs = int(self.config.num_local_epochs)
+ if "num-local-epochs" in self.context.run_config:
+ num_local_epochs = int(self.context.run_config["num-local-epochs"])
num_rep_epochs = DEFAULT_REPRESENTATION_EPOCHS
- if hasattr(self.config, "num_rep_epochs"):
- num_rep_epochs = int(self.config.num_rep_epochs)
+        if "num-rep-epochs" in self.context.run_config:
+ num_rep_epochs = int(self.context.run_config["num-rep-epochs"])
criterion = torch.nn.CrossEntropyLoss()
weights = [v for k, v in self._model.named_parameters() if "weight" in k]
@@ -238,6 +221,7 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
correct, total = 0, 0
loss: torch.Tensor = 0.0
+ self._model.to(device)
self._model.train()
for i in range(num_local_epochs + num_rep_epochs):
if i < num_local_epochs:
@@ -247,10 +231,9 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
self._model.enable_body()
self._model.disable_head()
for batch in self.trainloader:
- images = batch["img"]
- labels = batch["label"]
- outputs = self._model(images.to(self.device))
- labels = labels.to(self.device)
+ images = batch["img"].to(device)
+ labels = batch["label"].to(device)
+ outputs = self._model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
@@ -258,35 +241,35 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
- # Save client state (head)
- if self.client_save_path is not None:
- torch.save(self._model.head.state_dict(), self.client_save_path)
+ # Save state.
+ self._save_client_state()
return {"loss": loss.item(), "accuracy": correct / total}
- def test(self) -> Dict[str, float]:
+ def test(self, device: torch.device) -> Dict[str, float]:
"""Test the model maintained in self.model.
Returns
-------
Dict containing the test metrics.
"""
- # Load client state (head)
- if self.client_save_path is not None and os.path.isfile(self.client_save_path):
- self._model.head.load_state_dict(torch.load(self.client_save_path))
+ # Load state.
+ self._load_client_state()
num_finetune_epochs = DEFAULT_FINETUNE_EPOCHS
- if hasattr(self.config, "num_finetune_epochs"):
- num_finetune_epochs = int(self.config.num_finetune_epochs)
+ if "num-finetune-epochs" in self.context.run_config:
+ num_finetune_epochs = int(self.context.run_config["num-finetune-epochs"])
- if num_finetune_epochs > 0 and self.config.get("enable_finetune", False):
+ if num_finetune_epochs > 0 and self.context.run_config.get(
+ "enable-finetune", False
+ ):
optimizer = torch.optim.SGD(self._model.parameters(), lr=self.learning_rate)
criterion = torch.nn.CrossEntropyLoss()
self._model.train()
for _ in range(num_finetune_epochs):
for batch in self.trainloader:
- images = batch["img"].to(self.device)
- labels = batch["label"].to(self.device)
+ images = batch["img"]
+ labels = batch["label"]
outputs = self._model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
@@ -296,11 +279,12 @@ def test(self) -> Dict[str, float]:
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
+ self._model.to(device)
self._model.eval()
with torch.no_grad():
for batch in self.testloader:
- images = batch["img"].to(self.device)
- labels = batch["label"].to(self.device)
+ images = batch["img"].to(device)
+ labels = batch["label"].to(device)
outputs = self._model(images)
loss += criterion(outputs, labels).item()
total += labels.size(0)
diff --git a/baselines/fedrep/fedrep/client.py b/baselines/fedrep/fedrep/client.py
deleted file mode 100644
index f857fd2cf82a..000000000000
--- a/baselines/fedrep/fedrep/client.py
+++ /dev/null
@@ -1,319 +0,0 @@
-"""Client implementation - can call FedPep and FedAvg clients."""
-
-from collections import OrderedDict
-from pathlib import Path
-from typing import Callable, Dict, List, Tuple, Type, Union
-
-import numpy as np
-import torch
-from flwr.client import Client, NumPyClient
-from flwr.common import NDArrays, Scalar
-from flwr_datasets import FederatedDataset
-from flwr_datasets.partitioner import PathologicalPartitioner
-from flwr_datasets.preprocessor import Merger
-from omegaconf import DictConfig
-from torch.utils.data import DataLoader
-from torchvision import transforms
-
-from fedrep.constants import MEAN, STD, Algorithm
-from fedrep.models import CNNCifar10ModelManager, CNNCifar100ModelManager
-
-PROJECT_DIR = Path(__file__).parent.parent.absolute()
-
-FEDERATED_DATASET = None
-
-
-class BaseClient(NumPyClient):
- """Implementation of Federated Averaging (FedAvg) Client."""
-
- # pylint: disable=R0913
- def __init__(
- self,
- client_id: int,
- trainloader: DataLoader,
- testloader: DataLoader,
- config: DictConfig,
- model_manager_class: Union[
- Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]
- ],
- client_state_save_path: str = "",
- ):
- """Initialize client attributes.
-
- Args:
- client_id: The client ID.
- trainloader: Client train data loader.
- testloader: Client test data loader.
- config: dictionary containing the client configurations.
- model_manager_class: class to be used as the model manager.
- client_state_save_path: Path for saving model head parameters.
- (Just for FedRep). Defaults to "".
- """
- super().__init__()
-
- self.client_id = client_id
- self.client_state_save_path = (
- (client_state_save_path + f"/client_{self.client_id}")
- if client_state_save_path != ""
- else None
- )
- self.model_manager = model_manager_class(
- client_id=self.client_id,
- config=config,
- trainloader=trainloader,
- testloader=testloader,
- client_save_path=self.client_state_save_path,
- )
-
- def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
- """Return the current local model parameters."""
- return self.model_manager.model.get_parameters()
-
- def set_parameters(
- self, parameters: List[np.ndarray], evaluate: bool = False
- ) -> None:
- """Set the local model parameters to the received parameters.
-
- Args:
- parameters: parameters to set the model to.
- """
- _ = evaluate
- model_keys = [
- k
- for k in self.model_manager.model.state_dict().keys()
- if k.startswith("_body") or k.startswith("_head")
- ]
- params_dict = zip(model_keys, parameters)
-
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
-
- self.model_manager.model.set_parameters(state_dict)
-
- def perform_train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
- """Perform local training to the whole model.
-
- Returns
- -------
- Dict with the train metrics.
- """
- self.model_manager.model.enable_body()
- self.model_manager.model.enable_head()
-
- return self.model_manager.train()
-
- def fit(
- self, parameters: NDArrays, config: Dict[str, Scalar]
- ) -> Tuple[NDArrays, int, Dict[str, Union[bool, bytes, float, int, str]]]:
- """Train the provided parameters using the locally held dataset.
-
- Args:
- parameters: The current (global) model parameters.
- config: configuration parameters for training sent by the server.
-
- Returns
- -------
- Tuple containing the locally updated model parameters, \
- the number of examples used for training and \
- the training metrics.
- """
- self.set_parameters(parameters)
-
- train_results = self.perform_train()
-
- # Update train history
- print("<------- TRAIN RESULTS -------> :", train_results)
-
- return self.get_parameters(config), self.model_manager.train_dataset_size(), {}
-
- def evaluate(
- self, parameters: NDArrays, config: Dict[str, Scalar]
- ) -> Tuple[float, int, Dict[str, Union[bool, bytes, float, int, str]]]:
- """Evaluate the provided global parameters using the locally held dataset.
-
- Args:
- parameters: The current (global) model parameters.
- config: configuration parameters for training sent by the server.
-
- Returns
- -------
- Tuple containing the test loss, \
- the number of examples used for evaluation and \
- the evaluation metrics.
- """
- self.set_parameters(parameters, evaluate=True)
-
- # Test the model
- test_results = self.model_manager.test()
- print("<------- TEST RESULTS -------> :", test_results)
-
- return (
- test_results.get("loss", 0.0),
- self.model_manager.test_dataset_size(),
- {k: v for k, v in test_results.items() if not isinstance(v, (dict, list))},
- )
-
-
-class FedRepClient(BaseClient):
- """Implementation of Federated Personalization (FedRep) Client."""
-
- def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
- """Return the current local body parameters."""
- return [
- val.cpu().numpy()
- for val in self.model_manager.model.body.state_dict().values()
- ]
-
- def set_parameters(self, parameters: List[np.ndarray], evaluate=False) -> None:
- """Set the local body parameters to the received parameters.
-
- Args:
- parameters: parameters to set the body to.
- evaluate: whether the client is evaluating or not.
- """
- model_keys = [
- k
- for k in self.model_manager.model.state_dict().keys()
- if k.startswith("_body")
- ]
-
- if not evaluate:
- # Only update client's local head if it hasn't trained yet
- model_keys.extend(
- [
- k
- for k in self.model_manager.model.state_dict().keys()
- if k.startswith("_head")
- ]
- )
-
- state_dict = OrderedDict(
- (k, torch.from_numpy(v)) for k, v in zip(model_keys, parameters)
- )
-
- self.model_manager.model.set_parameters(state_dict)
-
-
-# pylint: disable=E1101, W0603
-def get_client_fn_simulation(
- config: DictConfig, client_state_save_path: str = ""
-) -> Callable[[str], Client]:
- """Generate the client function that creates the Flower Clients.
-
- Parameters
- ----------
- model : DictConfig
- The model configuration.
- cleint_state_save_path : str
- The path to save the client state.
-
- Returns
- -------
- Tuple[Callable[[str], FlowerClient], DataLoader]
- A tuple containing the client function that creates Flower Clients and
- the DataLoader that will be used for testing
- """
- assert config.model_name.lower() in [
- "cnncifar10",
- "cnncifar100",
- ], f"Model {config.model_name} not implemented"
-
- # - you can define your own data transformation strategy here -
- # These transformations are from the official repo
- train_data_transform = transforms.Compose(
- [
- transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(MEAN[config.dataset.name], STD[config.dataset.name]),
- ]
- )
- test_data_transform = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize(MEAN[config.dataset.name], STD[config.dataset.name]),
- ]
- )
-
- use_fine_label = False
- if config.dataset.name.lower() == "cifar100":
- use_fine_label = True
-
- partitioner = PathologicalPartitioner(
- num_partitions=config.num_clients,
- partition_by="fine_label" if use_fine_label else "label",
- num_classes_per_partition=config.dataset.num_classes,
- class_assignment_mode="random",
- shuffle=True,
- seed=config.dataset.seed,
- )
-
- global FEDERATED_DATASET
- if FEDERATED_DATASET is None:
- FEDERATED_DATASET = FederatedDataset(
- dataset=config.dataset.name.lower(),
- partitioners={"all": partitioner},
- preprocessor=Merger({"all": ("train", "test")}),
- )
-
- def apply_train_transforms(batch):
- """Apply transforms for train data to the partition from FederatedDataset."""
- batch["img"] = [train_data_transform(img) for img in batch["img"]]
- if use_fine_label:
- batch["label"] = batch["fine_label"]
- return batch
-
- def apply_test_transforms(batch):
- """Apply transforms for test data to the partition from FederatedDataset."""
- batch["img"] = [test_data_transform(img) for img in batch["img"]]
- if use_fine_label:
- batch["label"] = batch["fine_label"]
- return batch
-
- # pylint: disable=E1101
- def client_fn(cid: str) -> Client:
- """Create a Flower client representing a single organization."""
- cid_use = int(cid)
-
- partition = FEDERATED_DATASET.load_partition(cid_use, split="all")
-
- partition_train_test = partition.train_test_split(
- train_size=config.dataset.fraction, shuffle=True, seed=config.dataset.seed
- )
-
- trainset = partition_train_test["train"].with_transform(apply_train_transforms)
- testset = partition_train_test["test"].with_transform(apply_test_transforms)
-
- trainloader = DataLoader(trainset, config.batch_size, shuffle=True)
- testloader = DataLoader(testset, config.batch_size)
-
- model_manager_class: Union[
- Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]
- ]
- if config.model_name.lower() == "cnncifar10":
- model_manager_class = CNNCifar10ModelManager
- elif config.model_name.lower() == "cnncifar100":
- model_manager_class = CNNCifar100ModelManager
- else:
- raise NotImplementedError(
- f"Model {config.model_name} not implemented, check name."
- )
-
- if config.algorithm.lower() == Algorithm.FEDREP.value:
- return FedRepClient( # type: ignore[attr-defined]
- client_id=cid_use,
- trainloader=trainloader,
- testloader=testloader,
- config=config,
- model_manager_class=model_manager_class,
- client_state_save_path=client_state_save_path,
- ).to_client()
- return BaseClient( # type: ignore[attr-defined]
- client_id=cid_use,
- trainloader=trainloader,
- testloader=testloader,
- config=config,
- model_manager_class=model_manager_class,
- client_state_save_path=client_state_save_path,
- ).to_client()
-
- return client_fn
diff --git a/baselines/fedrep/fedrep/client_app.py b/baselines/fedrep/fedrep/client_app.py
new file mode 100644
index 000000000000..78755f4484aa
--- /dev/null
+++ b/baselines/fedrep/fedrep/client_app.py
@@ -0,0 +1,186 @@
+"""fedrep: A Flower Baseline."""
+
+from collections import OrderedDict
+from typing import Dict, List, Tuple, Union
+
+import torch
+
+from flwr.client import ClientApp, NumPyClient
+from flwr.client.client import Client
+from flwr.common import Context, NDArrays, ParametersRecord, Scalar
+
+from .constants import FEDREP_HEAD_STATE, Algorithm
+from .dataset import load_data
+from .models import CNNCifar10ModelManager, CNNCifar100ModelManager
+from .utils import get_model_manager_class
+
+
+class BaseClient(NumPyClient):
+ """Implementation of Federated Averaging (FedAvg) Client."""
+
+ # pylint: disable=R0913
+ def __init__(
+ self, model_manager: Union[CNNCifar10ModelManager, CNNCifar100ModelManager]
+ ):
+ """Initialize client attributes.
+
+ Args:
+ model_manager: the model manager object
+ """
+ super().__init__()
+ self.model_manager = model_manager
+
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+ """Return the current local model parameters."""
+ return self.model_manager.model.get_parameters()
+
+ def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None:
+ """Set the local model parameters to the received parameters.
+
+ Args:
+ parameters: parameters to set the model to.
+ evaluate: whether to evaluate or not.
+ """
+ _ = evaluate
+ model_keys = [
+ k
+ for k in self.model_manager.model.state_dict().keys()
+ if k.startswith("_body") or k.startswith("_head")
+ ]
+ params_dict = zip(model_keys, parameters)
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+ self.model_manager.model.set_parameters(state_dict)
+
+ def perform_train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
+ """Perform local training to the whole model.
+
+ Returns
+ -------
+ Dict with the train metrics.
+ """
+ self.model_manager.model.enable_body()
+ self.model_manager.model.enable_head()
+
+ return self.model_manager.train(self.device)
+
+ def fit(
+ self, parameters: NDArrays, config: Dict[str, Scalar]
+ ) -> Tuple[NDArrays, int, Dict[str, Union[bool, bytes, float, int, str]]]:
+ """Train the provided parameters using the locally held dataset.
+
+ Args:
+ parameters: The current (global) model parameters.
+ config: configuration parameters for training sent by the server.
+
+ Returns
+ -------
+ Tuple containing the locally updated model parameters, \
+ the number of examples used for training and \
+ the training metrics.
+ """
+ self.set_parameters(parameters)
+ self.perform_train()
+
+ return self.get_parameters(config), self.model_manager.train_dataset_size(), {}
+
+ def evaluate(
+ self, parameters: NDArrays, config: Dict[str, Scalar]
+ ) -> Tuple[float, int, Dict[str, Union[bool, bytes, float, int, str]]]:
+ """Evaluate the provided global parameters using the locally held dataset.
+
+ Args:
+ parameters: The current (global) model parameters.
+ config: configuration parameters for training sent by the server.
+
+ Returns
+ -------
+ Tuple containing the test loss, \
+ the number of examples used for evaluation and \
+ the evaluation metrics.
+ """
+ self.set_parameters(parameters, evaluate=True)
+
+ # Test the model
+ test_results = self.model_manager.test(self.device)
+
+ return (
+ test_results.get("loss", 0.0),
+ self.model_manager.test_dataset_size(),
+ {k: v for k, v in test_results.items() if not isinstance(v, (dict, list))},
+ )
+
+
+class FedRepClient(BaseClient):
+ """Implementation of Federated Personalization (FedRep) Client."""
+
+ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+ """Return the current local body parameters."""
+ return [
+ val.cpu().numpy()
+ for val in self.model_manager.model.body.state_dict().values()
+ ]
+
+ def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None:
+ """Set the local body parameters to the received parameters.
+
+ Args:
+ parameters: parameters to set the body to.
+ evaluate: whether the client is evaluating or not.
+ """
+ model_keys = [
+ k
+ for k in self.model_manager.model.state_dict().keys()
+ if k.startswith("_body")
+ ]
+
+ if not evaluate:
+ # Only update client's local head if it hasn't trained yet
+ model_keys.extend(
+ [
+ k
+ for k in self.model_manager.model.state_dict().keys()
+ if k.startswith("_head")
+ ]
+ )
+
+ state_dict = OrderedDict(
+ (k, torch.from_numpy(v)) for k, v in zip(model_keys, parameters)
+ )
+
+ self.model_manager.model.set_parameters(state_dict)
+
+
+def client_fn(context: Context) -> Client:
+ """Construct a Client that will be run in a ClientApp."""
+ model_manager_class = get_model_manager_class(context)
+ algorithm = str(context.run_config["algorithm"]).lower()
+ partition_id = int(context.node_config["partition-id"])
+ num_partitions = int(context.node_config["num-partitions"])
+ trainloader, valloader = load_data(
+ partition_id, num_partitions, context
+ ) # load the data
+ if algorithm == Algorithm.FEDAVG.value:
+ client_class = BaseClient
+ elif algorithm == Algorithm.FEDREP.value:
+ # This state variable will only be used by the FedRep algorithm.
+        # Since client_fn is called at every invocation of the ClientApp,
+        # only initialize the record if it is not already present.
+ if FEDREP_HEAD_STATE not in context.state.parameters_records:
+ context.state.parameters_records[FEDREP_HEAD_STATE] = ParametersRecord()
+ client_class = FedRepClient
+ else:
+ raise RuntimeError(f"Unknown algorithm {algorithm}.")
+
+ model_manager_obj = model_manager_class(
+ context=context, trainloader=trainloader, testloader=valloader
+ )
+
+ # Return client object.
+ client = client_class(model_manager_obj).to_client()
+ return client
+
+
+# Flower ClientApp
+app = ClientApp(client_fn)
diff --git a/baselines/fedrep/fedrep/conf/base.yaml b/baselines/fedrep/fedrep/conf/base.yaml
deleted file mode 100644
index 0d74c4fe78b6..000000000000
--- a/baselines/fedrep/fedrep/conf/base.yaml
+++ /dev/null
@@ -1,46 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar10
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null # the ID of cuda device, if null, then use defaults torch.device("cuda")
-
-dataset:
- name: cifar10
- split: sample
- num_classes: 2
- seed: 42
- num_clients: ${num_clients}
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar10
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml b/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml
deleted file mode 100644
index 30f9fd209d58..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar100
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar100
- num_classes: 20
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar100
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml b/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml
deleted file mode 100644
index e0add8f03b45..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar100
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar100
- num_classes: 5
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar100
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml b/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml
deleted file mode 100644
index 83ee34a298ae..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar10
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar10
- num_classes: 2
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar10.CNNCifar10
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml b/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml
deleted file mode 100644
index 0cbd104406f0..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar10
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar10
- num_classes: 5
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar10.CNNCifar10
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/constants.py b/baselines/fedrep/fedrep/constants.py
index 27e68f2b786c..a4527e1f36d4 100644
--- a/baselines/fedrep/fedrep/constants.py
+++ b/baselines/fedrep/fedrep/constants.py
@@ -1,7 +1,16 @@
-"""Constants used in machine learning pipeline."""
+"""fedrep: A Flower Baseline."""
from enum import Enum
+DEFAULT_LOCAL_TRAIN_EPOCHS: int = 10
+DEFAULT_FINETUNE_EPOCHS: int = 5
+DEFAULT_REPRESENTATION_EPOCHS: int = 1
+FEDREP_HEAD_STATE = "fedrep_head_state"
+
+MEAN = {"cifar10": [0.485, 0.456, 0.406], "cifar100": [0.507, 0.487, 0.441]}
+
+STD = {"cifar10": [0.229, 0.224, 0.225], "cifar100": [0.267, 0.256, 0.276]}
+
class Algorithm(Enum):
"""Algorithm names."""
@@ -10,10 +19,8 @@ class Algorithm(Enum):
FEDAVG = "fedavg"
-DEFAULT_LOCAL_TRAIN_EPOCHS: int = 10
-DEFAULT_FINETUNE_EPOCHS: int = 5
-DEFAULT_REPRESENTATION_EPOCHS: int = 1
+class ModelDatasetName(Enum):
+ """Dataset names."""
-MEAN = {"cifar10": [0.485, 0.456, 0.406], "cifar100": [0.507, 0.487, 0.441]}
-
-STD = {"cifar10": [0.229, 0.224, 0.225], "cifar100": [0.267, 0.256, 0.276]}
+ CNN_CIFAR_10 = "cnncifar10"
+ CNN_CIFAR_100 = "cnncifar100"
diff --git a/baselines/fedrep/fedrep/dataset.py b/baselines/fedrep/fedrep/dataset.py
index a616e38ae220..621baf98249a 100644
--- a/baselines/fedrep/fedrep/dataset.py
+++ b/baselines/fedrep/fedrep/dataset.py
@@ -1 +1,107 @@
-"""FedRep uses flwr-datasets."""
+"""fedrep: A Flower Baseline."""
+
+from typing import Tuple
+
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import PathologicalPartitioner
+from flwr_datasets.preprocessor import Merger
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+from flwr.common import Context
+
+from .constants import MEAN, STD
+
+FDS = None # Cache FederatedDataset
+
+
+def load_data(
+ partition_id: int, num_partitions: int, context: Context
+) -> Tuple[DataLoader, DataLoader]:
+ """Split the data and return training and testing data for the specified partition.
+
+ Parameters
+ ----------
+ partition_id : int
+ Partition number for which to retrieve the corresponding data.
+ num_partitions : int
+ Total number of partitions.
+ context: Context
+ the context of the current run.
+
+ Returns
+ -------
+ data : Tuple[DataLoader, DataLoader]
+ A tuple with the training and testing data for the current partition_id.
+ """
+ batch_size = int(context.run_config["batch-size"])
+ dataset_name = str(context.run_config["dataset-name"]).lower()
+ dataset_split_num_classes = int(context.run_config["dataset-split-num-classes"])
+ dataset_split_seed = int(context.run_config["dataset-split-seed"])
+ dataset_split_fraction = float(context.run_config["dataset-split-fraction"])
+
+ # - you can define your own data transformation strategy here -
+ # These transformations are from the official repo
+ train_data_transform = transforms.Compose(
+ [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN[dataset_name], STD[dataset_name]),
+ ]
+ )
+ test_data_transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN[dataset_name], STD[dataset_name]),
+ ]
+ )
+
+ use_fine_label = False
+ if dataset_name == "cifar100":
+ use_fine_label = True
+
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="fine_label" if use_fine_label else "label",
+ num_classes_per_partition=dataset_split_num_classes,
+ class_assignment_mode="random",
+ shuffle=True,
+ seed=dataset_split_seed,
+ )
+
+ global FDS # pylint: disable=global-statement
+ if FDS is None:
+ FDS = FederatedDataset(
+ dataset=dataset_name,
+ partitioners={"all": partitioner},
+ preprocessor=Merger({"all": ("train", "test")}),
+ )
+
+ def apply_train_transforms(batch: Dataset) -> Dataset:
+ """Apply transforms for train data to the partition from FederatedDataset."""
+ batch["img"] = [train_data_transform(img) for img in batch["img"]]
+ if use_fine_label:
+ batch["label"] = batch["fine_label"]
+ return batch
+
+ def apply_test_transforms(batch: Dataset) -> Dataset:
+ """Apply transforms for test data to the partition from FederatedDataset."""
+ batch["img"] = [test_data_transform(img) for img in batch["img"]]
+ if use_fine_label:
+ batch["label"] = batch["fine_label"]
+ return batch
+
+ partition = FDS.load_partition(partition_id, split="all")
+
+ partition_train_test = partition.train_test_split(
+ train_size=dataset_split_fraction, shuffle=True, seed=dataset_split_seed
+ )
+
+ trainset = partition_train_test["train"].with_transform(apply_train_transforms)
+ testset = partition_train_test["test"].with_transform(apply_test_transforms)
+
+ trainloader = DataLoader(trainset, batch_size, shuffle=True)
+ testloader = DataLoader(testset, batch_size)
+
+ return trainloader, testloader
diff --git a/baselines/fedrep/fedrep/dataset_preparation.py b/baselines/fedrep/fedrep/dataset_preparation.py
deleted file mode 100644
index a616e38ae220..000000000000
--- a/baselines/fedrep/fedrep/dataset_preparation.py
+++ /dev/null
@@ -1 +0,0 @@
-"""FedRep uses flwr-datasets."""
diff --git a/baselines/fedrep/fedrep/main.py b/baselines/fedrep/fedrep/main.py
deleted file mode 100644
index 223b98aa21fa..000000000000
--- a/baselines/fedrep/fedrep/main.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Create and connect the building blocks for your experiments; start the simulation.
-
-It includes processioning the dataset, instantiate strategy, specify how the global
-model is going to be evaluated, etc. At the end, this script saves the results.
-"""
-
-from pathlib import Path
-from typing import List, Tuple
-
-import flwr as fl
-import hydra
-from flwr.common.parameter import ndarrays_to_parameters
-from flwr.common.typing import Metrics
-from hydra.core.hydra_config import HydraConfig
-from hydra.utils import instantiate
-from omegaconf import DictConfig, OmegaConf
-
-from fedrep.utils import (
- get_client_fn,
- get_create_model_fn,
- plot_metric_from_history,
- save_results_as_pickle,
- set_client_state_save_path,
- set_client_strategy,
-)
-
-
-@hydra.main(config_path="conf", config_name="base", version_base=None)
-def main(cfg: DictConfig) -> None:
- """Run the baseline.
-
- Parameterss
- ----------
- cfg : DictConfig
- An omegaconf object that stores the hydra config.
- """
- # Print parsed config
- print(OmegaConf.to_yaml(cfg))
-
- # set client strategy
- cfg = set_client_strategy(cfg)
-
- # Create directory to store client states if it does not exist
- # Client state has subdirectories with the name of current time
- client_state_save_path = set_client_state_save_path()
-
- # Define your clients
- # Get client function
- client_fn = get_client_fn(config=cfg, client_state_save_path=client_state_save_path)
-
- # get a function that will be used to construct the config that the client's
- # fit() method will received
- def get_on_fit_config():
- def fit_config_fn(server_round: int):
- # resolve and convert to python dict
- fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True)
- _ = server_round
- return fit_config
-
- return fit_config_fn
-
- # get a function that will be used to construct the model
- create_model, split = get_create_model_fn(cfg)
-
- model = split(create_model())
-
- def evaluate_metrics_aggregation_fn(
- eval_metrics: List[Tuple[int, Metrics]]
- ) -> Metrics:
- weights, accuracies = [], []
- for num_examples, metric in eval_metrics:
- weights.append(num_examples)
- accuracies.append(metric["accuracy"] * num_examples)
- accuracy = sum(accuracies) / sum(weights) # type: ignore[arg-type]
- return {"accuracy": accuracy}
-
- # Define your strategy
- strategy = instantiate(
- cfg.strategy,
- initial_parameters=ndarrays_to_parameters(model.get_parameters()),
- on_fit_config_fn=get_on_fit_config(),
- evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
- )
-
- # Start Simulation
- history = fl.simulation.start_simulation(
- client_fn=client_fn,
- num_clients=cfg.num_clients,
- config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
- client_resources={
- "num_cpus": cfg.client_resources.num_cpus,
- "num_gpus": cfg.client_resources.num_gpus,
- },
- strategy=strategy,
- )
-
- # Experiment completed. Now we save the results and
- # generate plots using the `history`
- print("................")
- print(history)
-
- # Save your results
- save_path = Path(HydraConfig.get().runtime.output_dir)
-
- # save results as a Python pickle using a file_path
- # the directory created by Hydra for each run
- save_results_as_pickle(history, file_path=save_path)
- # plot results and include them in the readme
- strategy_name = strategy.__class__.__name__
- file_suffix: str = (
- f"_{strategy_name}"
- f"_C={cfg.num_clients}"
- f"_B={cfg.batch_size}"
- f"_E={cfg.num_local_epochs}"
- f"_R={cfg.num_rounds}"
- f"_lr={cfg.learning_rate}"
- )
-
- plot_metric_from_history(history, save_path, (file_suffix))
-
-
-if __name__ == "__main__":
- main()
diff --git a/baselines/fedrep/fedrep/models.py b/baselines/fedrep/fedrep/models.py
index b230f4e49766..314d7cc91688 100644
--- a/baselines/fedrep/fedrep/models.py
+++ b/baselines/fedrep/fedrep/models.py
@@ -1,11 +1,11 @@
-"""Model, model manager and model split for CIFAR-10 and CIFAR-100."""
+"""fedrep: A Flower Baseline."""
from typing import Tuple
import torch
-import torch.nn as nn
+from torch import nn
-from fedrep.base_model import ModelManager, ModelSplit
+from .base_model import ModelManager, ModelSplit
# pylint: disable=W0223
@@ -58,17 +58,12 @@ class CNNCifar10ModelManager(ModelManager):
"""Manager for models with Body/Head split."""
def __init__(self, **kwargs):
- """Initialize the attributes of the model manager.
-
- Args:
- client_id: The id of the client.
- config: Dict containing the configurations to be used by the manager.
- """
+ """Initialize the attributes of the model manager."""
super().__init__(model_split_class=CNNCifar10ModelSplit, **kwargs)
def _create_model(self) -> nn.Module:
- """Return CNNCifar10 model to be splitted into head and body."""
- return CNNCifar10().to(self.device)
+ """Return CNNCifar10 model to be split into head and body."""
+ return CNNCifar10()
# pylint: disable=W0223
@@ -104,6 +99,11 @@ def __init__(self):
self.head = nn.Sequential(nn.Linear(128, 100))
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass of the model."""
+ x = self.body(x)
+ return self.head(x)
+
class CNNCifar100ModelSplit(ModelSplit):
"""Split CNNCifar100 model into body and head."""
@@ -117,14 +117,9 @@ class CNNCifar100ModelManager(ModelManager):
"""Manager for models with Body/Head split."""
def __init__(self, **kwargs):
- """Initialize the attributes of the model manager.
-
- Args:
- client_id: The id of the client.
- config: Dict containing the configurations to be used by the manager.
- """
+ """Initialize the attributes of the model manager."""
super().__init__(model_split_class=CNNCifar100ModelSplit, **kwargs)
def _create_model(self) -> CNNCifar100:
- """Return CNNCifar100 model to be splitted into head and body."""
- return CNNCifar100().to(self.device)
+ """Return CNNCifar100 model to be split into head and body."""
+ return CNNCifar100()
diff --git a/baselines/fedrep/fedrep/server.py b/baselines/fedrep/fedrep/server.py
deleted file mode 100644
index 5b0c34035ae6..000000000000
--- a/baselines/fedrep/fedrep/server.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Server strategies pipelines for FedRep."""
diff --git a/baselines/fedrep/fedrep/server_app.py b/baselines/fedrep/fedrep/server_app.py
new file mode 100644
index 000000000000..0f98a3bc8876
--- /dev/null
+++ b/baselines/fedrep/fedrep/server_app.py
@@ -0,0 +1,80 @@
+"""fedrep: A Flower Baseline."""
+
+import json
+import os
+import time
+from typing import Dict, List, Tuple
+
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+
+from .utils import get_create_model_fn, get_server_strategy
+
+RESULTS_FILE = "result-{}.json"
+
+
+def config_json_file(context: Context) -> None:
+ """Initialize the json file and write the run configurations."""
+ # Initialize the execution results directory.
+ res_save_path = "./results"
+ if not os.path.exists(res_save_path):
+ os.makedirs(res_save_path)
+ res_save_name = time.strftime("%Y-%m-%d-%H-%M-%S")
+ # Set the date and full path of the file to save the results.
+ global RESULTS_FILE # pylint: disable=global-statement
+ RESULTS_FILE = RESULTS_FILE.format(res_save_name)
+ RESULTS_FILE = f"{res_save_path}/{RESULTS_FILE}"
+ data = {
+ "run_config": dict(context.run_config.items()),
+ "round_res": [],
+ }
+ with open(RESULTS_FILE, "w+", encoding="UTF-8") as fout:
+ json.dump(data, fout, indent=4)
+
+
+def write_res(new_res: Dict[str, float]) -> None:
+ """Load the json file, append result and re-write json collection."""
+ with open(RESULTS_FILE, "r", encoding="UTF-8") as fin:
+ data = json.load(fin)
+ data["round_res"].append(new_res)
+
+ # Write the updated data back to the JSON file
+ with open(RESULTS_FILE, "w", encoding="UTF-8") as fout:
+ json.dump(data, fout, indent=4)
+
+
+def evaluate_metrics_aggregation_fn(eval_metrics: List[Tuple[int, Metrics]]) -> Metrics:
+ """Weighted metrics evaluation."""
+ weights, accuracies, losses = [], [], []
+ for num_examples, metric in eval_metrics:
+ weights.append(num_examples)
+ accuracies.append(float(metric["accuracy"]) * num_examples)
+ losses.append(float(metric["loss"]) * num_examples)
+ accuracy = sum(accuracies) / sum(weights)
+ loss = sum(losses) / sum(weights)
+ write_res({"accuracy": accuracy, "loss": loss})
+ return {"accuracy": accuracy}
+
+
+def server_fn(context: Context) -> ServerAppComponents:
+ """Construct components that set the ServerApp behaviour."""
+ config_json_file(context)
+ # Read from config
+ num_rounds = context.run_config["num-server-rounds"]
+
+ # Initialize model parameters
+ create_model_fn, split_class = get_create_model_fn(context)
+ net = split_class(create_model_fn())
+ parameters = ndarrays_to_parameters(net.get_parameters())
+
+ # Define strategy
+ strategy = get_server_strategy(
+ context=context, params=parameters, eval_fn=evaluate_metrics_aggregation_fn
+ )
+ config = ServerConfig(num_rounds=int(num_rounds))
+
+ return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Create ServerApp
+app = ServerApp(server_fn=server_fn)
diff --git a/baselines/fedrep/fedrep/strategy.py b/baselines/fedrep/fedrep/strategy.py
index 3bee45326a6f..ba4d8b75724f 100644
--- a/baselines/fedrep/fedrep/strategy.py
+++ b/baselines/fedrep/fedrep/strategy.py
@@ -1,4 +1,4 @@
-"""FL server strategies."""
+"""fedrep: A Flower Baseline."""
from flwr.server.strategy import FedAvg
diff --git a/baselines/fedrep/fedrep/utils.py b/baselines/fedrep/fedrep/utils.py
index b706ebf1e041..086abde2c46c 100644
--- a/baselines/fedrep/fedrep/utils.py
+++ b/baselines/fedrep/fedrep/utils.py
@@ -1,204 +1,87 @@
-"""Utility functions for FedRep."""
+"""fedrep: A Flower Baseline."""
-import logging
-import os
-import pickle
-import time
-from pathlib import Path
-from secrets import token_hex
-from typing import Callable, Optional, Type, Union
+from typing import Callable, Type, Union
-import matplotlib.pyplot as plt
-import numpy as np
-from flwr.client import Client
-from flwr.server.history import History
-from omegaconf import DictConfig
+from flwr.common import Context, Parameters
+from flwr.server.strategy import FedAvg
-from fedrep.base_model import get_device
-from fedrep.client import get_client_fn_simulation
-from fedrep.constants import Algorithm
-from fedrep.models import (
+from .constants import Algorithm, ModelDatasetName
+from .models import (
CNNCifar10,
+ CNNCifar10ModelManager,
CNNCifar10ModelSplit,
CNNCifar100,
+ CNNCifar100ModelManager,
CNNCifar100ModelSplit,
)
-
-
-def set_client_state_save_path() -> str:
- """Set the client state save path."""
- client_state_save_path = time.strftime("%Y-%m-%d")
- client_state_sub_path = time.strftime("%H-%M-%S")
- client_state_save_path = (
- f"./client_states/{client_state_save_path}/{client_state_sub_path}"
- )
- if not os.path.exists(client_state_save_path):
- os.makedirs(client_state_save_path)
- return client_state_save_path
-
-
-# pylint: disable=W1202
-def set_client_strategy(cfg: DictConfig) -> DictConfig:
- """Set the client strategy."""
- algorithm = cfg.algorithm.lower()
- if algorithm == Algorithm.FEDREP.value:
- cfg.strategy["_target_"] = "fedrep.strategy.FedRep"
- elif algorithm == Algorithm.FEDAVG.value:
- cfg.strategy["_target_"] = "flwr.server.strategy.FedAvg"
- else:
- logging.warning(
- "Algorithm {} not implemented. Fallback to FedAvg.".format(algorithm)
- )
- return cfg
-
-
-def get_client_fn(
- config: DictConfig, client_state_save_path: str = ""
-) -> Callable[[str], Client]:
- """Get client function."""
- # Get algorithm
- algorithm = config.algorithm.lower()
- # Get client fn
- if algorithm == "fedrep":
- client_fn = get_client_fn_simulation(
- config=config, client_state_save_path=client_state_save_path
- )
- elif algorithm == "fedavg":
- client_fn = get_client_fn_simulation(config=config)
- else:
- raise NotImplementedError
- return client_fn
+from .strategy import FedRep
def get_create_model_fn(
- config: DictConfig,
+ context: Context,
) -> tuple[
- Callable[[], Union[type[CNNCifar10], type[CNNCifar100]]],
- Union[type[CNNCifar10ModelSplit], type[CNNCifar100ModelSplit]],
+ Union[Callable[[], CNNCifar10], Callable[[], CNNCifar100]],
+ Union[Type[CNNCifar10ModelSplit], Type[CNNCifar100ModelSplit]],
]:
"""Get create model function."""
- device = get_device(
- use_cuda=getattr(config, "use_cuda", True),
- specified_device=getattr(config, "specified_device", None),
- )
- split: Union[Type[CNNCifar10ModelSplit], Type[CNNCifar100ModelSplit]] = (
- CNNCifar10ModelSplit
- )
- if config.model_name.lower() == "cnncifar10":
+ model_name = str(context.run_config["model-name"])
+ if model_name == ModelDatasetName.CNN_CIFAR_10.value:
+ split = CNNCifar10ModelSplit
- def create_model() -> Union[Type[CNNCifar10], Type[CNNCifar100]]:
+ def create_model() -> CNNCifar10: # type: ignore
"""Create initial CNNCifar10 model."""
- return CNNCifar10().to(device)
+ return CNNCifar10()
- elif config.model_name.lower() == "cnncifar100":
+ elif model_name == ModelDatasetName.CNN_CIFAR_100.value:
split = CNNCifar100ModelSplit
- def create_model() -> Union[Type[CNNCifar10], Type[CNNCifar100]]:
+ def create_model() -> CNNCifar100: # type: ignore
"""Create initial CNNCifar100 model."""
- return CNNCifar100().to(device)
+ return CNNCifar100()
else:
- raise NotImplementedError("Model not implemented, check name. ")
+ raise NotImplementedError(f"Not a recognized model name {model_name}.")
return create_model, split
-def plot_metric_from_history(
- hist: History, save_plot_path: Path, suffix: Optional[str] = ""
-) -> None:
- """Plot from Flower server History.
-
- Parameters
- ----------
- hist : History
- Object containing evaluation for all rounds.
- save_plot_path : Path
- Folder to save the plot to.
- suffix: Optional[str]
- Optional string to add at the end of the filename for the plot.
- """
- metric_type = "distributed"
- metric_dict = (
- hist.metrics_centralized
- if metric_type == "centralized"
- else hist.metrics_distributed
- )
- try:
- _, values = zip(*metric_dict["accuracy"])
- except KeyError: # If no available metric data
- return
-
- # let's extract decentralized loss (main metric reported in FedProx paper)
- rounds_loss, values_loss = zip(*hist.losses_distributed)
-
- _, axs = plt.subplots(nrows=2, ncols=1, sharex="row")
- axs[0].plot(np.asarray(rounds_loss), np.asarray(values_loss)) # type: ignore
- axs[1].plot(np.asarray(rounds_loss), np.asarray(values)) # type: ignore
-
- axs[0].set_ylabel("Loss") # type: ignore
- axs[1].set_ylabel("Accuracy") # type: ignore
-
- axs[0].grid() # type: ignore
- axs[1].grid() # type: ignore
- # plt.title(f"{metric_type.capitalize()} Validation - MNIST")
- plt.xlabel("Rounds")
- # plt.legend(loc="lower right")
-
- plt.savefig(Path(save_plot_path) / Path(f"{metric_type}_metrics{suffix}.png"))
- plt.close()
-
-
-def save_results_as_pickle(
- history: History,
- file_path: Union[str, Path],
- default_filename: Optional[str] = "results.pkl",
-) -> None:
- """Save results from simulation to pickle.
-
- Parameters
- ----------
- history: History
- History returned by start_simulation.
- file_path: Union[str, Path]
- Path to file to create and store both history and extra_results.
- If path is a directory, the default_filename will be used.
- path doesn't exist, it will be created. If file exists, a
- randomly generated suffix will be added to the file name. This
- is done to avoid overwritting results.
- extra_results : Optional[Dict]
- A dictionary containing additional results you would like
- to be saved to disk. Default: {} (an empty dictionary)
- default_filename: Optional[str]
- File used by default if file_path points to a directory instead
- to a file. Default: "results.pkl"
- """
- path = Path(file_path)
-
- # ensure path exists
- path.mkdir(exist_ok=True, parents=True)
-
- def _add_random_suffix(path_: Path):
- """Add a random suffix to the file name."""
- print(f"File `{path_}` exists! ")
- suffix = token_hex(4)
- print(f"New results to be saved with suffix: {suffix}")
- return path_.parent / (path_.stem + "_" + suffix + ".pkl")
-
- def _complete_path_with_default_name(path_: Path):
- """Append the default file name to the path."""
- print("Using default filename")
- if default_filename is None:
- return path_
- return path_ / default_filename
-
- if path.is_dir():
- path = _complete_path_with_default_name(path)
-
- if path.is_file():
- path = _add_random_suffix(path)
-
- print(f"Results will be saved into: {path}")
- # data = {"history": history, **extra_results}
- data = {"history": history}
- # save results to pickle
- with open(str(path), "wb") as handle:
- pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
+def get_model_manager_class(
+ context: Context,
+) -> Union[Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]]:
+    """Return the model manager class corresponding to the configured model name."""
+ model_name = str(context.run_config["model-name"])
+ if model_name.lower() == ModelDatasetName.CNN_CIFAR_10.value:
+ model_manager_class = CNNCifar10ModelManager
+ elif model_name.lower() == ModelDatasetName.CNN_CIFAR_100.value:
+ model_manager_class = CNNCifar100ModelManager # type: ignore
+ else:
+ raise NotImplementedError(
+ f"Model {model_name} not implemented, please check model name."
+ )
+ return model_manager_class
+
+
+def get_server_strategy(
+ context: Context, params: Parameters, eval_fn: Callable
+) -> Union[FedAvg, FedRep]:
+    """Define the server strategy based on the configured algorithm."""
+ algorithm = str(context.run_config["algorithm"]).lower()
+ if algorithm == Algorithm.FEDAVG.value:
+ strategy = FedAvg
+ elif algorithm == Algorithm.FEDREP.value:
+ strategy = FedRep
+ else:
+ raise RuntimeError(f"Unknown algorithm {algorithm}.")
+
+ # Read strategy config
+ fraction_fit = float(context.run_config["fraction-fit"])
+ fraction_evaluate = float(context.run_config["fraction-evaluate"])
+ min_available_clients = int(context.run_config["min-available-clients"])
+
+ strategy = strategy(
+        fraction_fit=fraction_fit,
+ fraction_evaluate=fraction_evaluate,
+ min_available_clients=min_available_clients,
+ initial_parameters=params,
+ evaluate_metrics_aggregation_fn=eval_fn,
+ ) # type: ignore
+ return strategy # type: ignore
diff --git a/baselines/fedrep/pyproject.toml b/baselines/fedrep/pyproject.toml
index e4c3551af19a..9a45199032ef 100644
--- a/baselines/fedrep/pyproject.toml
+++ b/baselines/fedrep/pyproject.toml
@@ -1,73 +1,39 @@
[build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
-[tool.poetry]
+[project]
name = "fedrep"
version = "1.0.0"
-description = "Exploiting Shared Representations for Personalized Federated Learning"
+description = ""
license = "Apache-2.0"
-authors = ["Jiahao Tan "]
-readme = "README.md"
-homepage = "https://flower.ai"
-repository = "https://github.com/adap/flower"
-documentation = "https://flower.ai"
-classifiers = [
- "Development Status :: 3 - Alpha",
- "Intended Audience :: Developers",
- "Intended Audience :: Science/Research",
- "License :: OSI Approved :: Apache Software License",
- "Operating System :: MacOS :: MacOS X",
- "Operating System :: POSIX :: Linux",
- "Programming Language :: Python",
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: Implementation :: CPython",
- "Topic :: Scientific/Engineering",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Scientific/Engineering :: Mathematics",
- "Topic :: Software Development",
- "Topic :: Software Development :: Libraries",
- "Topic :: Software Development :: Libraries :: Python Modules",
- "Typing :: Typed",
+dependencies = [
+ "flwr[simulation]>=1.13.1",
+ "flwr-datasets[vision]>=0.4.0",
+ "torch==2.2.1",
+ "torchvision==0.17.1",
]
-[tool.poetry.dependencies]
-python = ">=3.10.0, <3.11.0" # don't change this
-flwr = { extras = ["simulation"], version = "1.9.0" }
-hydra-core = "1.3.2" # don't change this
-pandas = "^2.2.2"
-matplotlib = "^3.9.0"
-tqdm = "^4.66.4"
-torch = "^2.2.2"
-torchvision = "^0.17.2"
-setuptools = "<70"
-flwr-datasets = { extras = ["vision"], version = ">=0.3.0" }
-
-[tool.poetry.dev-dependencies]
-isort = "==5.13.2"
-black = "==24.2.0"
-docformatter = "==1.7.5"
-mypy = "==1.4.1"
-pylint = "==2.8.2"
-flake8 = "==3.9.2"
-pytest = "==6.2.4"
-pytest-watch = "==4.2.0"
-ruff = "==0.0.272"
-types-requests = "==2.27.7"
-virtualenv = "==20.21.0"
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[project.optional-dependencies]
+dev = [
+ "isort==5.13.2",
+ "black==24.2.0",
+ "docformatter==1.7.5",
+ "mypy==1.8.0",
+ "pylint==3.2.6",
+ "flake8==5.0.4",
+ "pytest==6.2.4",
+ "pytest-watch==4.2.0",
+ "ruff==0.1.9",
+ "types-requests==2.31.0.20240125",
+]
[tool.isort]
-line_length = 88
-indent = " "
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
+profile = "black"
+known_first_party = ["flwr"]
[tool.black]
line-length = 88
@@ -76,7 +42,9 @@ target-version = ["py38", "py39", "py310", "py311"]
[tool.pytest.ini_options]
minversion = "6.2"
addopts = "-qq"
-testpaths = ["flwr_baselines"]
+testpaths = [
+ "flwr_baselines",
+]
[tool.mypy]
ignore_missing_imports = true
@@ -84,14 +52,22 @@ strict = false
plugins = "numpy.typing.mypy_plugin"
[tool.pylint."MESSAGES CONTROL"]
-good-names = "i,j,k,_,x,y,X,Y"
-signature-mutators = "hydra.main.main"
+disable = "duplicate-code,too-few-public-methods,useless-import-alias"
+good-names = "i,j,k,_,x,y,X,Y,K,N"
+max-args = 10
+max-attributes = 15
+max-locals = 36
+max-branches = 20
+max-statements = 55
-[tool.pylint."TYPECHECK"]
+[tool.pylint.typecheck]
generated-members = "numpy.*, torch.*, tensorflow.*"
[[tool.mypy.overrides]]
-module = ["importlib.metadata.*", "importlib_metadata.*"]
+module = [
+ "importlib.metadata.*",
+ "importlib_metadata.*",
+]
follow_imports = "skip"
follow_imports_for_stubs = true
disallow_untyped_calls = false
@@ -137,3 +113,49 @@ exclude = [
[tool.ruff.pydocstyle]
convention = "numpy"
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "dimitris"
+
+[tool.flwr.app.components]
+serverapp = "fedrep.server_app:app"
+clientapp = "fedrep.client_app:app"
+
+[tool.flwr.app.config]
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 2
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
+
+# model specs
+model-name = "cnncifar10"
+batch-size = 50
+learning-rate = 0.01
+momentum = 0.5
+enable-finetune = false
+num-finetune-epochs = 5
+num-local-epochs = 5 # number of local epochs
+num-rep-epochs = 1 # number of representation epochs (only for FedRep)
+
+# server specs
+num-server-rounds = 100
+fraction-fit = 0.1
+fraction-evaluate = 0.1
+min-available-clients = 2
+min-evaluate-clients = 2
+min-fit-clients = 2
+
+[tool.flwr.federations]
+default = "local-sim-100"
+
+[tool.flwr.federations.local-sim-100]
+options.num-supernodes = 100
+options.backend.client-resources.num-cpus = 2
+options.backend.client-resources.num-gpus = 0.5 # GPU fraction allocated to each client
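+
+# A minimal sketch of a second, smaller CPU-only federation for quick smoke tests.
+# This block is an illustrative assumption, not part of the reported experiment
+# setup; it can be selected by name, e.g. `flwr run . local-sim-10`.
+[tool.flwr.federations.local-sim-10]
+options.num-supernodes = 10
+options.backend.client-resources.num-cpus = 1
+options.backend.client-resources.num-gpus = 0.0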