Use flwr_datasets for e2e scikit_learn
chongshenng committed Nov 6, 2024
1 parent 24813a6 commit a38443c
Showing 4 changed files with 23 additions and 36 deletions.
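In short, the e2e scikit-learn example now pulls MNIST through flwr_datasets instead of downloading it from OpenML and partitioning it by hand. The new loading pattern, mirroring the load_data helper added to utils.py below, boils down to this sketch:

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# Partition the MNIST training split into 10 IID shards and load one shard
# as flat NumPy arrays, replacing the previous OpenML download.
partitioner = IidPartitioner(num_partitions=10)
fds = FederatedDataset(dataset="ylecun/mnist", partitioners={"train": partitioner})
partition = fds.load_partition(0, "train").with_format("numpy")
X, y = partition["image"].reshape((len(partition), -1)), partition["label"]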
3 changes: 0 additions & 3 deletions .github/workflows/e2e.yml
@@ -156,9 +156,6 @@ jobs:
          - directory: e2e-scikit-learn
            e2e: e2e_scikit_learn
            dataset: |
              import openml
              openml.datasets.get_dataset(554)

          - directory: e2e-fastai
            e2e: e2e_fastai
6 changes: 2 additions & 4 deletions e2e/e2e-scikit-learn/e2e_scikit_learn/client_app.py
@@ -8,12 +8,10 @@
from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import Context

# Load MNIST dataset from https://www.openml.org/d/554
(X_train, y_train), (X_test, y_test) = utils.load_mnist()

# Split train set into 10 partitions and randomly use one for training.
partition_id = np.random.choice(10)
(X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]
X_train, X_test, y_train, y_test = utils.load_data(partition_id, num_partitions=10)


# Create LogisticRegression Model
model = LogisticRegression(
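The client_app.py diff is truncated right at the model definition. For context, here is a rough sketch of how the partition returned by utils.load_data typically feeds a LogisticRegression-based NumPyClient; the hyperparameters and the utils.set_model_params / utils.get_model_parameters helpers are assumptions borrowed from Flower's scikit-learn examples and are not part of this diff:

from sklearn.metrics import log_loss  # in addition to the imports shown above

model = LogisticRegression(max_iter=1, warm_start=True)  # keep weights across rounds
utils.set_initial_params(model)  # defined in utils.py (see next file)


class FlowerClient(NumPyClient):
    def fit(self, parameters, config):
        utils.set_model_params(model, parameters)  # assumed helper
        model.fit(X_train, y_train)
        return utils.get_model_parameters(model), len(X_train), {}  # assumed helper

    def evaluate(self, parameters, config):
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        return loss, len(X_test), {"accuracy": model.score(X_test, y_test)}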
48 changes: 21 additions & 27 deletions e2e/e2e-scikit-learn/e2e_scikit_learn/utils.py
@@ -1,7 +1,8 @@
from typing import List, Tuple, Union

import numpy as np
import openml
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from sklearn.linear_model import LogisticRegression

XY = Tuple[np.ndarray, np.ndarray]
@@ -50,30 +51,23 @@ def set_initial_params(model: LogisticRegression):
    model.intercept_ = np.zeros((n_classes,))


def load_mnist() -> Dataset:
    """Loads the MNIST dataset using OpenML.

    OpenML dataset link: https://www.openml.org/d/554
    """
    mnist_openml = openml.datasets.get_dataset(554)
    Xy, _, _, _ = mnist_openml.get_data(dataset_format="array")
    X = Xy[:, :-1]  # the last column contains labels
    y = Xy[:, -1]
    # First 60000 samples consist of the train set
    x_train, y_train = X[:1000], y[:1000]
    x_test, y_test = X[60000:62000], y[60000:62000]
    return (x_train, y_train), (x_test, y_test)


def shuffle(X: np.ndarray, y: np.ndarray) -> XY:
    """Shuffle X and y."""
    rng = np.random.default_rng()
    idx = rng.permutation(len(X))
    return X[idx], y[idx]


def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList:
    """Split X and y into a number of partitions."""
    return list(
        zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions))
    )

fds = None  # Cache FederatedDataset


def load_data(partition_id: int, num_partitions: int):
    # Only initialize `FederatedDataset` once
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="ylecun/mnist",
            partitioners={"train": partitioner},
        )

    dataset = fds.load_partition(partition_id, "train").with_format("numpy")
    X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
    # Split the on edge data: 80% train, 20% test
    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]

    return X_train, X_test, y_train, y_test
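For reference, a quick sanity check of what load_data returns; the shapes below assume the full 60,000-sample MNIST train split divided into 10 IID partitions:

X_train, X_test, y_train, y_test = load_data(partition_id=0, num_partitions=10)
# 60,000 / 10 = 6,000 samples per partition; the 80/20 split leaves 4,800 for
# training and 1,200 for testing, each image flattened to 28 * 28 = 784 features.
print(X_train.shape, X_test.shape)  # expected: (4800, 784) (1200, 784)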
2 changes: 0 additions & 2 deletions e2e/e2e-scikit-learn/pyproject.toml
@@ -14,8 +14,6 @@ authors = [
dependencies = [
    "flwr[simulation,rest] @ {root:parent:parent:uri}",
    "scikit-learn>=1.1.1,<2.0.0",
    "openml>=0.14.0,<0.15.0",
    "numpy<2.0.0",
]

[tool.hatch.build.targets.wheel]
