
New mipdb version.
Update data processing for the clients so they load data from CSV files instead of the database.
KFilippopolitis committed Jun 13, 2024
1 parent de748c5 commit 94bbbd0
Showing 25 changed files with 1,610 additions and 1,065 deletions.
40 changes: 15 additions & 25 deletions exareme2/algorithms/flower/flower_data_processing.py
@@ -5,7 +5,6 @@
 from typing import Optional
 
 import pandas as pd
-import pymonetdb
 import requests
 from pydantic import BaseModel
 from sklearn import preprocessing
@@ -28,38 +27,29 @@ class Inputdata(BaseModel):
     x: Optional[List[str]]
 
 
-def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
-    return (
-        _fetch_data_from_db(data_model, datasets)
-        if from_db
-        else _fetch_data_from_csv(data_model, datasets)
-    )
+def fetch_client_data(inputdata) -> pd.DataFrame:
+    dataframes = [
+        pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
+        for csv_path in os.getenv("CSV_PATHS").split(",")
+    ]
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]
 
 
-def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
-    query = f'SELECT * FROM "{data_model}"."primary_data"'
-    conn = pymonetdb.connect(
-        hostname=os.getenv("MONETDB_IP"),
-        port=int(os.getenv("MONETDB_PORT")),
-        username=os.getenv("MONETDB_USERNAME"),
-        password=os.getenv("MONETDB_PASSWORD"),
-        database=os.getenv("MONETDB_DB"),
+def fetch_server_data(inputdata) -> pd.DataFrame:
+    data_folder = Path(
+        f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
     )
-    df = pd.read_sql(query, conn)
-    conn.close()
-    df = df[df["dataset"].isin(datasets)]
-    return df
-
-
-def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
-    data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
     print(f"Loading data from folder: {data_folder}")
     dataframes = [
         pd.read_csv(data_folder / f"{dataset}.csv")
-        for dataset in datasets
+        for dataset in inputdata.datasets
         if (data_folder / f"{dataset}.csv").exists()
     ]
-    return pd.concat(dataframes, ignore_index=True)
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]
 
 
 def preprocess_data(inputdata, full_data):
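
For orientation, a minimal sketch of how the two new helpers are meant to be driven. The environment variables and Inputdata values below are invented for illustration; in practice they are injected by the worker at runtime.

import os

from exareme2.algorithms.flower.flower_data_processing import Inputdata
from exareme2.algorithms.flower.flower_data_processing import fetch_client_data
from exareme2.algorithms.flower.flower_data_processing import fetch_server_data

# Hypothetical example values.
os.environ["DATA_PATH"] = "/opt/data/"
os.environ["CSV_PATHS"] = "dementia_v_0_1/edsd0.csv,dementia_v_0_1/edsd1.csv"

inputdata = Inputdata(
    data_model="dementia:0.1",
    datasets=["edsd0", "edsd1"],
    y=["alzheimerbroadcategory"],
    x=["lefthippocampus"],
)

# Client side: concatenate the worker's CSVs, then keep only the requested
# datasets and the x/y columns.
client_df = fetch_client_data(inputdata)

# Server side: resolve <DATA_PATH>/dementia_v_0_1 and load <dataset>.csv
# for each requested dataset that exists in that folder.
server_df = fetch_server_data(inputdata)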
23 changes: 18 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/client.py
@@ -1,4 +1,5 @@
 import os
+import time
 import warnings
 
 import flwr as fl
@@ -8,7 +9,7 @@
 from utils import set_initial_params
 from utils import set_model_params
 
-from exareme2.algorithms.flower.flower_data_processing import fetch_data
+from exareme2.algorithms.flower.flower_data_processing import fetch_client_data
 from exareme2.algorithms.flower.flower_data_processing import get_input
 from exareme2.algorithms.flower.flower_data_processing import preprocess_data
 
@@ -39,11 +40,23 @@ def evaluate(self, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
+    full_data = fetch_client_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
 
     client = LogisticRegressionClient(model, X_train, y_train)
-    fl.client.start_client(
-        server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
-    )
+    max_retries = 6
+    for attempt in range(max_retries):
+        try:
+            fl.client.start_client(
+                server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
+            )
+            print("Connection successful on attempt", attempt + 1)
+            break
+        except Exception as e:
+            print(f"Connection attempt {attempt + 1} failed: {e}")
+            if attempt < max_retries - 1:
+                time.sleep(1)
+            else:
+                print("Max retries reached. Exiting.")
+                raise
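
The retry loop presumably covers the startup race where the client process comes up before the Flower server has bound its port. The same pattern as a reusable helper with exponential backoff might look like this (a sketch, not part of the commit):

import time

def retry_connect(connect, max_retries=6, base_delay=1.0):
    """Call `connect` until it succeeds, re-raising the final failure.

    The delay doubles after each failed attempt (1s, 2s, 4s, ...).
    """
    for attempt in range(1, max_retries + 1):
        try:
            return connect()
        except Exception as exc:
            print(f"Connection attempt {attempt} failed: {exc}")
            if attempt == max_retries:
                raise
            time.sleep(base_delay * 2 ** (attempt - 1))

Used as retry_connect(lambda: fl.client.start_client(server_address=os.environ["SERVER_ADDRESS"], client=client.to_client())), it behaves like the inline loop above but with a growing back-off between attempts.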
4 changes: 2 additions & 2 deletions exareme2/algorithms/flower/logistic_regression/server.py
@@ -6,7 +6,7 @@
 from utils import set_initial_params
 from utils import set_model_params
 
-from exareme2.algorithms.flower.flower_data_processing import fetch_data
+from exareme2.algorithms.flower.flower_data_processing import fetch_server_data
 from exareme2.algorithms.flower.flower_data_processing import get_input
 from exareme2.algorithms.flower.flower_data_processing import post_result
 from exareme2.algorithms.flower.flower_data_processing import preprocess_data
@@ -36,7 +36,7 @@ def evaluate(server_round, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression()
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets)
+    full_data = fetch_server_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
     strategy = fl.server.strategy.FedAvg(
3 changes: 2 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
@@ -298,13 +298,14 @@ def queue_healthcheck_task(
         )
 
     def start_flower_client(
-        self, request_id, algorithm_name, server_address
+        self, request_id, algorithm_name, server_address, csv_paths
    ) -> WorkerTaskResult:
         return self._queue_task(
             task_signature=TASK_SIGNATURES["start_flower_client"],
             request_id=request_id,
             algorithm_name=algorithm_name,
             server_address=server_address,
+            csv_paths=csv_paths,
         )
 
     def start_flower_server(
9 changes: 8 additions & 1 deletion exareme2/controller/quart/endpoints.py
@@ -32,7 +32,14 @@ async def get_datasets() -> dict:
 
 @algorithms.route("/datasets_locations", methods=["GET"])
 async def get_datasets_locations() -> dict:
-    return get_worker_landscape_aggregator().get_datasets_locations().datasets_locations
+    return {
+        data_model: {
+            dataset: info.worker_id for dataset, info in datasets_location.items()
+        }
+        for data_model, datasets_location in get_worker_landscape_aggregator()
+        .get_datasets_locations()
+        .datasets_locations.items()
+    }
 
 
 @algorithms.route("/cdes_metadata", methods=["GET"])
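
The rewritten endpoint flattens each location entry down to its worker_id, so the JSON response takes roughly this shape (data models, datasets, and worker ids invented for illustration):

{
    "dementia:0.1": {"edsd0": "localworker1", "edsd1": "localworker2"},
    "tbi:0.1": {"dummy_tbi0": "localworker1"}
}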
24 changes: 12 additions & 12 deletions exareme2/controller/services/flower/controller.py
@@ -1,4 +1,5 @@
 import asyncio
+from typing import Dict
 from typing import List
 
 from exareme2.controller import logger as ctrl_logger
@@ -52,10 +53,17 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         request_id = algorithm_request_dto.request_id
         context_id = UIDGenerator().get_a_uid()
         logger = ctrl_logger.get_request_logger(request_id)
-        workers_info = self._get_workers_info_by_dataset(
+        csv_paths_per_worker_id: Dict[
+            str, List[str]
+        ] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
             algorithm_request_dto.inputdata.data_model,
             algorithm_request_dto.inputdata.datasets,
         )
+
+        workers_info = [
+            self.worker_landscape_aggregator.get_worker_info(worker_id)
+            for worker_id in csv_paths_per_worker_id
+        ]
         task_handlers = [
             self._create_worker_tasks_handler(request_id, worker)
             for worker in workers_info
@@ -80,14 +88,15 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         server_pid = None
         clients_pids = {}
         server_address = f"{server_ip}:{FLOWER_SERVER_PORT}"
-
         try:
             server_pid = server_task_handler.start_flower_server(
                 algorithm_name, len(task_handlers), str(server_address)
             )
             clients_pids = {
                 handler.start_flower_client(
-                    algorithm_name, str(server_address)
+                    algorithm_name,
+                    str(server_address),
+                    csv_paths_per_worker_id[handler.worker_id],
                 ): handler
                 for handler in task_handlers
             }
@@ -130,12 +139,3 @@ async def _cleanup(
 
     def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
         """Retrieves worker information for those handling the specified datasets."""
-        worker_ids = (
-            self.worker_landscape_aggregator.get_worker_ids_with_any_of_datasets(
-                data_model, datasets
-            )
-        )
-        return [
-            self.worker_landscape_aggregator.get_worker_info(worker_id)
-            for worker_id in worker_ids
-        ]
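
Everything the controller needs now hangs off one mapping. Assuming get_csv_paths_per_worker_id returns a dict like the one below (worker ids and paths invented for illustration), iterating its keys yields the participating workers, and indexing it gives each client exactly the CSVs its worker hosts:

csv_paths_per_worker_id = {
    "localworker1": ["dementia_v_0_1/edsd0.csv"],
    "localworker2": ["dementia_v_0_1/edsd1.csv", "dementia_v_0_1/edsd2.csv"],
}

# One Flower client per worker, each started with only that worker's paths,
# mirroring handler.start_flower_client(...) in the try block above.
for worker_id, csv_paths in csv_paths_per_worker_id.items():
    print(worker_id, "->", csv_paths)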
4 changes: 2 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
@@ -29,9 +29,9 @@ def worker_id(self) -> str:
     def worker_data_address(self) -> str:
         return self._db_address
 
-    def start_flower_client(self, algorithm_name, server_address) -> int:
+    def start_flower_client(self, algorithm_name, server_address, csv_paths) -> int:
         return self._worker_tasks_handler.start_flower_client(
-            self._request_id, algorithm_name, server_address
+            self._request_id, algorithm_name, server_address, csv_paths
         ).get(timeout=self._tasks_timeout)
 
     def start_flower_server(
@@ -4,6 +4,8 @@
 from exareme2.controller.celery.tasks_handler import WorkerTasksHandler
 from exareme2.worker_communication import CommonDataElements
 from exareme2.worker_communication import DataModelAttributes
+from exareme2.worker_communication import DatasetInfo
+from exareme2.worker_communication import DatasetsInfoPerDataModel
 from exareme2.worker_communication import WorkerInfo
 
 
@@ -23,10 +25,11 @@ def get_worker_info_task(self) -> WorkerInfo:
         ).get(self._tasks_timeout)
         return WorkerInfo.parse_raw(result)
 
-    def get_worker_datasets_per_data_model_task(self) -> Dict[str, Dict[str, str]]:
-        return self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
+    def get_worker_datasets_per_data_model_task(self) -> DatasetsInfoPerDataModel:
+        result = self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
             self._request_id
         ).get(self._tasks_timeout)
+        return DatasetsInfoPerDataModel.parse_raw(result)
 
     def get_data_model_cdes_task(self, data_model: str) -> CommonDataElements:
         result = self._worker_tasks_handler.queue_data_model_cdes_task(
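
The handler now round-trips the datasets info through a pydantic model rather than a bare dict. A minimal sketch of that serialize/parse pattern, with the field layout of DatasetInfo and DatasetsInfoPerDataModel assumed for illustration (the real models live in exareme2.worker_communication):

from typing import Dict
from typing import List

from pydantic import BaseModel

class DatasetInfo(BaseModel):  # assumed shape, for illustration only
    code: str
    label: str

class DatasetsInfoPerDataModel(BaseModel):  # assumed shape, for illustration only
    datasets_info_per_data_model: Dict[str, List[DatasetInfo]]

payload = DatasetsInfoPerDataModel(
    datasets_info_per_data_model={
        "dementia:0.1": [DatasetInfo(code="edsd0", label="EDSD part 0")]
    }
)
raw = payload.json()                                # what the Celery task ships back
restored = DatasetsInfoPerDataModel.parse_raw(raw)  # what the handler does on receipt
assert restored == payload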