-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of Flower Integration with exareme2: New Controller for Flower Execution: Introduced a controller with dedicated modules designed for managing the Flower workflow and algorithm execution. Added Flower-Compatible Algorithms: Logistic Regression with MIP Data: This algorithm integrates with Flower, utilizing MIP data and allowing parameter customization based on the flow's input data. Logistic Regression with MNIST Data: This version is tailored for Flower but specifically uses MNIST data for operations. Robust Process Management Module: Implemented a new module to enhance process control, including: Functions to send signals and check process status. Management of zombie processes. Process termination with retry capabilities. The controller utilizes this module to initiate, monitor, and terminate Flower execution processes safely, ensuring improved oversight of Flower's client and server components. Lower the local workers on productions tests because the flower made workers and controller heavier. Fixed a problem of version incompatibility of dependencies between pre-commit conf and poetry.
- Loading branch information
1 parent
bb78e1c
commit b75e75f
Showing
68 changed files
with
1,854 additions
and
254 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import json | ||
import os | ||
from pathlib import Path | ||
from typing import List | ||
from typing import Optional | ||
|
||
import pandas as pd | ||
import pymonetdb | ||
import requests | ||
from pydantic import BaseModel | ||
from sklearn import preprocessing | ||
from sklearn.impute import SimpleImputer | ||
|
||
# Constants for project directories and environment configurations | ||
PROJECT_ROOT = Path(__file__).resolve().parents[3] | ||
|
||
|
||
class Inputdata(BaseModel): | ||
data_model: str | ||
datasets: List[str] | ||
filters: Optional[dict] | ||
y: Optional[List[str]] | ||
x: Optional[List[str]] | ||
|
||
|
||
def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame: | ||
return ( | ||
_fetch_data_from_db(data_model, datasets) | ||
if from_db | ||
else _fetch_data_from_csv(data_model, datasets) | ||
) | ||
|
||
|
||
def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame: | ||
query = f'SELECT * FROM "{data_model}"."primary_data"' | ||
conn = pymonetdb.connect( | ||
hostname=os.getenv("MONETDB_IP"), | ||
port=int(os.getenv("MONETDB_PORT")), | ||
username=os.getenv("MONETDB_USERNAME"), | ||
password=os.getenv("MONETDB_PASSWORD"), | ||
database=os.getenv("MONETDB_DB"), | ||
) | ||
df = pd.read_sql(query, conn) | ||
conn.close() | ||
df = df[df["dataset"].isin(datasets)] | ||
return df | ||
|
||
|
||
def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame: | ||
data_folder = ( | ||
PROJECT_ROOT / "tests" / "test_data" / f"{data_model.split(':')[0]}_v_0_1" | ||
) | ||
dataframes = [ | ||
pd.read_csv(data_folder / f"{dataset}.csv") | ||
for dataset in datasets | ||
if (data_folder / f"{dataset}.csv").exists() | ||
] | ||
return pd.concat(dataframes, ignore_index=True) | ||
|
||
|
||
def preprocess_data(inputdata, full_data): | ||
# Ensure x and y are specified and correct | ||
if not inputdata.x or not inputdata.y: | ||
raise ValueError("Input features 'x' and labels 'y' must be specified") | ||
|
||
# Select features and target based on inputdata configuration | ||
features = full_data[inputdata.x] # This should be a DataFrame | ||
target = full_data[inputdata.y].values.ravel() # Flatten the array if it's 2D | ||
|
||
# Impute missing values for features | ||
imputer = SimpleImputer(strategy="most_frequent") | ||
features_imputed = imputer.fit_transform(features) | ||
|
||
# Encode target variable | ||
label_encoder = preprocessing.LabelEncoder() | ||
label_encoder.fit(get_enumerations(inputdata.data_model, inputdata.y[0])) | ||
y_train = label_encoder.transform(target) | ||
|
||
return features_imputed, y_train | ||
|
||
|
||
def post_result(result: dict) -> None: | ||
url = "http://127.0.0.1:5000/flower/result" | ||
headers = {"Content-type": "application/json", "Accept": "text/plain"} | ||
requests.post(url, data=json.dumps(result), headers=headers) | ||
|
||
|
||
def get_input() -> Inputdata: | ||
response = requests.get("http://127.0.0.1:5000/flower/input") | ||
return Inputdata.parse_raw(response.text) | ||
|
||
|
||
def get_enumerations(data_model, variable_name): | ||
response = requests.get("http://127.0.0.1:5000/cdes_metadata") | ||
cdes_metadata = json.loads(response.text) | ||
enumerations = cdes_metadata[data_model][variable_name]["enumerations"] | ||
return [code for code, label in enumerations.items()] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
{ | ||
"name": "logistic_regression", | ||
"desc": "Statistical method. that models the relationship between a dependent binary variable and one or more independent variables by fitting a binary logistic curve to the observed data.", | ||
"label": "Logistic Regression on Flower", | ||
"enabled": true, | ||
"type": "flower", | ||
"inputdata": { | ||
"y": { | ||
"label": "Variable (dependent)", | ||
"desc": "A unique nominal variable. The variable is converted to binary by assigning 1 to the positive class and 0 to all other classes. ", | ||
"types": [ | ||
"int", | ||
"text" | ||
], | ||
"stattypes": [ | ||
"nominal" | ||
], | ||
"notblank": true, | ||
"multiple": false | ||
}, | ||
"x": { | ||
"label": "Covariates (independent)", | ||
"desc": "One or more variables. Can be numerical or nominal. For nominal variables dummy encoding is used.", | ||
"types": [ | ||
"real", | ||
"int", | ||
"text" | ||
], | ||
"stattypes": [ | ||
"numerical", | ||
"nominal" | ||
], | ||
"notblank": true, | ||
"multiple": true | ||
} | ||
} | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
import warnings | ||
|
||
import flwr as fl | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.metrics import log_loss | ||
from utils import get_model_parameters | ||
from utils import set_initial_params | ||
from utils import set_model_params | ||
|
||
from exareme2.algorithms.flower.flower_data_processing import fetch_data | ||
from exareme2.algorithms.flower.flower_data_processing import get_input | ||
from exareme2.algorithms.flower.flower_data_processing import preprocess_data | ||
|
||
|
||
class LogisticRegressionClient(fl.client.NumPyClient): | ||
def __init__(self, model, X_train, y_train): | ||
self.model = model | ||
self.X_train = X_train | ||
self.y_train = y_train | ||
|
||
def get_parameters(self, **kwargs): # Now accepts any keyword arguments | ||
return get_model_parameters(self.model) | ||
|
||
def fit(self, parameters, config): | ||
set_model_params(self.model, parameters) | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore") | ||
self.model.fit(self.X_train, self.y_train) | ||
return get_model_parameters(self.model), len(self.X_train), {} | ||
|
||
def evaluate(self, parameters, config): | ||
set_model_params(self.model, parameters) | ||
loss = log_loss(self.y_train, self.model.predict_proba(self.X_train)) | ||
accuracy = self.model.score(self.X_train, self.y_train) | ||
return loss, len(self.X_train), {"accuracy": accuracy} | ||
|
||
|
||
if __name__ == "__main__": | ||
model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True) | ||
inputdata = get_input() | ||
full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True) | ||
X_train, y_train = preprocess_data(inputdata, full_data) | ||
set_initial_params(model, X_train, full_data, inputdata) | ||
|
||
client = LogisticRegressionClient(model, X_train, y_train) | ||
fl.client.start_client( | ||
server_address=os.environ["SERVER_ADDRESS"], client=client.to_client() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
|
||
import flwr as fl | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.metrics import log_loss | ||
from utils import set_initial_params | ||
from utils import set_model_params | ||
|
||
from exareme2.algorithms.flower.flower_data_processing import fetch_data | ||
from exareme2.algorithms.flower.flower_data_processing import get_input | ||
from exareme2.algorithms.flower.flower_data_processing import post_result | ||
from exareme2.algorithms.flower.flower_data_processing import preprocess_data | ||
|
||
# TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO | ||
NUM_OF_ROUNDS = 5 | ||
|
||
|
||
def fit_round(server_round: int): | ||
"""Configures the next round of training.""" | ||
return {"server_round": server_round} | ||
|
||
|
||
def get_evaluate_fn(model, X_test, y_test): | ||
def evaluate(server_round, parameters, config): | ||
set_model_params(model, parameters) | ||
loss = log_loss(y_test, model.predict_proba(X_test)) | ||
accuracy = model.score(X_test, y_test) | ||
if server_round == NUM_OF_ROUNDS: | ||
post_result({"accuracy": accuracy}) | ||
print({"accuracy": accuracy}) | ||
return loss, {"accuracy": accuracy} | ||
|
||
return evaluate | ||
|
||
|
||
if __name__ == "__main__": | ||
model = LogisticRegression() | ||
inputdata = get_input() | ||
full_data = fetch_data(inputdata.data_model, inputdata.datasets) | ||
X_train, y_train = preprocess_data(inputdata, full_data) | ||
set_initial_params(model, X_train, full_data, inputdata) | ||
strategy = fl.server.strategy.FedAvg( | ||
min_available_clients=int(os.environ["NUMBER_OF_CLIENTS"]), | ||
evaluate_fn=get_evaluate_fn(model, X_train, y_train), | ||
on_fit_config_fn=fit_round, | ||
) | ||
fl.server.start_server( | ||
server_address=os.environ["SERVER_ADDRESS"], | ||
strategy=strategy, | ||
config=fl.server.ServerConfig(num_rounds=NUM_OF_ROUNDS), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import numpy as np | ||
|
||
|
||
def get_model_parameters(model): | ||
params = [model.coef_] | ||
if model.fit_intercept: | ||
params.append(model.intercept_) | ||
return params | ||
|
||
|
||
def set_model_params(model, params): | ||
model.coef_ = params[0] | ||
if model.fit_intercept: | ||
model.intercept_ = params[1] | ||
|
||
|
||
def set_initial_params(model, X_train, full_data, flower_inputdata): | ||
model.classes_ = np.array( | ||
[i for i in range(len(np.unique(full_data[flower_inputdata.y])))] | ||
) | ||
model.coef_ = np.zeros((len(model.classes_), X_train.shape[1])) | ||
if model.fit_intercept: | ||
model.intercept_ = np.zeros((len(model.classes_),)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
{ | ||
"name": "mnist_logistic_regression", | ||
"desc": "Statistical method. that models the relationship between a dependent binary variable and one or more independent variables by fitting a binary logistic curve to the observed data.", | ||
"label": "Logistic Regression on Flower", | ||
"enabled": true, | ||
"type": "flower", | ||
"inputdata": { | ||
"y": { | ||
"label": "Variable (dependent)", | ||
"desc": "A unique nominal variable. The variable is converted to binary by assigning 1 to the positive class and 0 to all other classes. ", | ||
"types": [ | ||
"int", | ||
"text" | ||
], | ||
"stattypes": [ | ||
"nominal" | ||
], | ||
"notblank": true, | ||
"multiple": false | ||
}, | ||
"x": { | ||
"label": "Covariates (independent)", | ||
"desc": "One or more variables. Can be numerical or nominal. For nominal variables dummy encoding is used.", | ||
"types": [ | ||
"real", | ||
"int", | ||
"text" | ||
], | ||
"stattypes": [ | ||
"numerical", | ||
"nominal" | ||
], | ||
"notblank": true, | ||
"multiple": true | ||
} | ||
} | ||
} |
Binary file not shown.
Empty file.
Oops, something went wrong.