Merge branch 'main' into fds-rewrite-iid-partitioner
danieljanes authored Nov 15, 2023
2 parents 4316466 + e2116b0 commit 0d2c7bb
Showing 7 changed files with 125 additions and 6 deletions.
2 changes: 2 additions & 0 deletions baselines/flwr_baselines/pyproject.toml
@@ -48,6 +48,7 @@ matplotlib = "^3.5.1"
scikit-image = "^0.18.1"
scikit-learn = "^1.2.1"
wget = "^3.2"
virtualenv = "^20.24.6"
pandas = "^1.5.3"
pyhamcrest = "^2.0.4"

@@ -61,6 +62,7 @@ flake8 = "==3.9.2"
pytest = "==6.2.4"
pytest-watch = "==4.2.0"
types-requests = "==2.27.7"
pydantic = "==2.4.2"

[tool.isort]
line_length = 88
2 changes: 2 additions & 0 deletions baselines/flwr_baselines/requirements.txt
@@ -15,6 +15,7 @@ matplotlib >= 3.5.0
scikit-image >= 0.18.1
scikit-learn >= 0.24.2
wget >= 3.2
virtualenv >= 20.24.6

##### dev-dependencies
isort == 5.11.5
@@ -26,3 +27,4 @@ flake8 == 3.9.2
pytest == 6.2.4
pytest-watch == 4.2.0
types-requests == 2.27.7
pydantic ==2.4.2
15 changes: 15 additions & 0 deletions datasets/e2e/scikit-learn/pyproject.toml
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "fds-e2e-sklearn"
version = "0.1.0"
description = "Flower Datasets with scikit-learn"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
flwr-datasets = { path = "./../../", extras = ["vision"] }
scikit-learn = "^1.2.0"
parameterized = "==0.9.0"
94 changes: 94 additions & 0 deletions datasets/e2e/scikit-learn/sklearn_test.py
@@ -0,0 +1,94 @@
import unittest

import numpy as np
from parameterized import parameterized_class
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from flwr_datasets import FederatedDataset


# Using parameterized testing, two different sets of preprocessing:
# 1. Without scaling.
# 2. With standard scaling.
@parameterized_class(
    [
        {"dataset_name": "mnist", "preprocessing": None},
        {"dataset_name": "mnist", "preprocessing": StandardScaler()},
    ]
)
class FdsWithSKLearn(unittest.TestCase):
    """Test Flower Datasets with Scikit-learn's Logistic Regression."""

    dataset_name = ""
    preprocessing = None

    def _get_partition_data(self):
        """Retrieve partition data."""
        partition_id = 0
        fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
        partition = fds.load_partition(partition_id, "train")
        partition.set_format("numpy")
        partition_train_test = partition.train_test_split(test_size=0.2)
        X_train, y_train = (
            partition_train_test["train"]["image"],
            partition_train_test["train"]["label"],
        )
        X_test, y_test = (
            partition_train_test["test"]["image"],
            partition_train_test["test"]["label"],
        )
        X_train = X_train.reshape(-1, 28 * 28)
        X_test = X_test.reshape(-1, 28 * 28)
        if self.preprocessing:
            self.preprocessing.fit(X_train)
            X_train = self.preprocessing.transform(X_train)
            X_test = self.preprocessing.transform(X_test)

        return X_train, X_test, y_train, y_test

    def test_data_shape(self):
        """Test if the data shape is maintained after preprocessing."""
        X_train, _, _, _ = self._get_partition_data()
        self.assertEqual(X_train.shape, (4_800, 28 * 28))

    def test_X_train_type(self):
        """Test if the data type is correct."""
        X_train, _, _, _ = self._get_partition_data()
        self.assertIsInstance(X_train, np.ndarray)

    def test_y_train_type(self):
        """Test if the data type is correct."""
        _, _, y_train, _ = self._get_partition_data()
        self.assertIsInstance(y_train, np.ndarray)

    def test_X_test_type(self):
        """Test if the data type is correct."""
        _, X_test, _, _ = self._get_partition_data()
        self.assertIsInstance(X_test, np.ndarray)

    def test_y_test_type(self):
        """Test if the data type is correct."""
        _, _, _, y_test = self._get_partition_data()
        self.assertIsInstance(y_test, np.ndarray)

    def test_train_classifier(self):
        """Test if the classifier trains without errors."""
        X_train, X_test, y_train, y_test = self._get_partition_data()
        try:
            clf = LogisticRegression()
            clf.fit(X_train, y_train)
        except Exception as e:
            self.fail(f"Fitting Logistic Regression raised {type(e)} unexpectedly!")

    def test_predict_from_classifier(self):
        """Test if the classifier predicts without errors."""
        X_train, X_test, y_train, y_test = self._get_partition_data()
        clf = LogisticRegression()
        clf.fit(X_train, y_train)
        try:
            _ = clf.predict(X_test)
        except Exception as e:
            self.fail(
                f"Predicting using Logistic Regression model raised {type(e)} "
                f"unexpectedly!"
            )


if __name__ == '__main__':
unittest.main()
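
For reference, a minimal standalone sketch of what this e2e test exercises, assuming `flwr-datasets[vision]` and `scikit-learn` from the pyproject above are installed. The dataset name, partition count, and the `image`/`label` column names come directly from the test; the `max_iter` value and the final accuracy printout are illustrative choices, not part of the test itself:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from flwr_datasets import FederatedDataset

# Partition MNIST into 10 partitions and load the first one.
fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
partition = fds.load_partition(0, "train")
partition.set_format("numpy")

# Local train/test split; flatten 28x28 images into 784-dim vectors.
split = partition.train_test_split(test_size=0.2)
X_train = split["train"]["image"].reshape(-1, 28 * 28)
y_train = split["train"]["label"]
X_test = split["test"]["image"].reshape(-1, 28 * 28)
y_test = split["test"]["label"]

# Optional standard scaling, as in the second parameterized test case.
scaler = StandardScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)
print("Partition test accuracy:", clf.score(X_test, y_test))
```

Because the test module ends with `unittest.main()`, it can also be run directly with `python sklearn_test.py`.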
5 changes: 1 addition & 4 deletions examples/whisper-federated-finetuning/README.md
@@ -140,10 +140,7 @@ python sim.py # append --num_gpus=0 if you don't have GPUs on your system

# Once finished centralised evaluation loss/acc metrics will be shown

INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {'val_accuracy': [(0, 0.03977158885994791),
(1, 0.6940492887196954), (2, 0.5969745541975556), (3, 0.8794830695251452), (4, 0.9021238228811861), (5, 0.8943097575636145),
(6, 0.9047285113203767), (7, 0.9330795431777199), (8, 0.9446002805049089), (9, 0.9556201162091765)],
'test_accuracy': [(10, 0.9719836400817996)]}
INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {'val_accuracy': [(0, 0.03977158885994791), (1, 0.6940492887196954), (2, 0.5969745541975556), (3, 0.8794830695251452), (4, 0.9021238228811861), (5, 0.8943097575636145), (6, 0.9047285113203767), (7, 0.9330795431777199), (8, 0.9446002805049089), (9, 0.9556201162091765)], 'test_accuracy': [(10, 0.9719836400817996)]}
```

![Global validation accuracy FL with Whisper model](_static/whisper_flower_acc.png)
11 changes: 9 additions & 2 deletions src/cc/flwr/src/grpc_rere.cc
@@ -19,6 +19,14 @@ std::optional<flwr::proto::Node> get_node_from_store() {
return node->second;
}

void delete_node_from_store() {
  std::lock_guard<std::mutex> lock(node_store_mutex);
  auto node = node_store.find(KEY_NODE);
  if (node == node_store.end() || !node->second.has_value()) {
    return; // Nothing to erase; erasing end() would be undefined behavior
  }
  node_store.erase(node);
}

std::optional<flwr::proto::TaskIns> get_current_task_ins() {
  std::lock_guard<std::mutex> state_lock(state_mutex);
  auto current_task_ins = state.find(KEY_TASK_INS);
@@ -80,8 +88,7 @@ void delete_node(const std::unique_ptr<flwr::proto::Fleet::Stub> &stub) {
    delete_node_request.release_node(); // Release if status is ok
  }

  // TODO: Check if Node needs to be removed from local map
  // node_store.erase(node);
  delete_node_from_store();
}

std::optional<flwr::proto::TaskIns>
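The intent of the new `delete_node_from_store()` helper is to look up the cached node under `KEY_NODE` while holding the store mutex and erase the entry only if it actually holds a value. As a rough illustration of that pattern (not Flower's API; hypothetical names, Python standard library only):

```python
import threading
from typing import Any, Dict, Optional

# Illustrative counterparts of the C++ node_store / node_store_mutex / KEY_NODE.
node_store: Dict[str, Optional[Any]] = {}
node_store_lock = threading.Lock()
KEY_NODE = "node"


def delete_node_from_store() -> None:
    """Remove the cached node, but only if it exists and holds a value."""
    with node_store_lock:
        node = node_store.get(KEY_NODE)
        if node is None:
            return  # Nothing to erase; mirrors the missing/empty-optional guard
        del node_store[KEY_NODE]
```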
2 changes: 2 additions & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
@@ -136,6 +136,8 @@ def delete_node() -> None:
        delete_node_request = DeleteNodeRequest(node=node)
        stub.DeleteNode(request=delete_node_request)

        del node_store[KEY_NODE]

    def receive() -> Optional[TaskIns]:
        """Receive next task from server."""
        # Get Node
