
Commit

Update data processing for clients, so they load data from CSV and not from the database
Kostas Filippopolitis committed Jul 9, 2024
1 parent 7b502bc commit 2db7dcf
Showing 28 changed files with 1,763 additions and 1,247 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/algorithm_validation_tests.yml
@@ -245,8 +245,8 @@ jobs:
         with:
           run: cat /tmp/exareme2/localworker1.out
 
-      - name: Run Flower algorithm validation tests
-        run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
-
       - name: Run Exareme2 algorithm validation tests
         run: poetry run pytest tests/algorithm_validation_tests/exareme2/ --verbosity=4 -n 16 -k "input1 and not input1-" # run tests 10-19
+
+      - name: Run Flower algorithm validation tests
+        run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
@@ -5,7 +5,6 @@
 from typing import Optional
 
 import pandas as pd
-import pymonetdb
 import requests
 from flwr.common.logger import FLOWER_LOGGER
 from pydantic import BaseModel
@@ -29,37 +28,30 @@ class Inputdata(BaseModel):
     x: Optional[List[str]]
 
 
-def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
-    return (
-        _fetch_data_from_db(data_model, datasets)
-        if from_db
-        else _fetch_data_from_csv(data_model, datasets)
-    )
+def fetch_client_data(inputdata) -> pd.DataFrame:
+    FLOWER_LOGGER.debug(f"CSV_PATHS: {os.getenv('CSV_PATHS')}")
+    dataframes = [
+        pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
+        for csv_path in os.getenv("CSV_PATHS").split(",")
+    ]
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]
 
 
-def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
-    query = f'SELECT * FROM "{data_model}"."primary_data"'
-    conn = pymonetdb.connect(
-        hostname=os.getenv("MONETDB_IP"),
-        port=int(os.getenv("MONETDB_PORT")),
-        username=os.getenv("MONETDB_USERNAME"),
-        password=os.getenv("MONETDB_PASSWORD"),
-        database=os.getenv("MONETDB_DB"),
-    )
-    df = pd.read_sql(query, conn)
-    conn.close()
-    df = df[df["dataset"].isin(datasets)]
-    return df
-
-
-def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
-    data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
+def fetch_server_data(inputdata) -> pd.DataFrame:
+    data_folder = Path(
+        f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
+    )
     print(f"Loading data from folder: {data_folder}")
     dataframes = [
         pd.read_csv(data_folder / f"{dataset}.csv")
-        for dataset in datasets
+        for dataset in inputdata.datasets
         if (data_folder / f"{dataset}.csv").exists()
     ]
-    return pd.concat(dataframes, ignore_index=True)
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]
 
 
 def preprocess_data(inputdata, full_data):
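For context, the old fetch_data(from_db=...) switch is gone: fetch_client_data reads only the CSV files the controller assigns to the local worker through the CSV_PATHS and DATA_PATH environment variables, while fetch_server_data scans the data-model folder as the CSV branch did before. A minimal sketch of the client-side contract, assuming the module path from the updated imports below; the paths, datasets, and column names are hypothetical:

import os

from exareme2.algorithms.flower.inputdata_preprocessing import Inputdata
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data

# Hypothetical worker environment: DATA_PATH is the worker's data root and
# CSV_PATHS lists the assigned CSV files relative to it, comma-separated.
os.environ["DATA_PATH"] = "/opt/exareme2/data/"
os.environ["CSV_PATHS"] = "dementia_v_0_1/edsd.csv,dementia_v_0_1/ppmi.csv"

# Hypothetical request: rows are filtered to the requested datasets and
# projected onto the x + y columns.
inputdata = Inputdata(
    data_model="dementia:0.1",
    datasets=["edsd", "ppmi"],
    y=["alzheimerbroadcategory"],
    x=["lefthippocampus"],
)
df = fetch_client_data(inputdata)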
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/client.py
@@ -11,9 +11,9 @@
 from utils import set_initial_params
 from utils import set_model_params
 
-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
 
 
 class LogisticRegressionClient(fl.client.NumPyClient):
@@ -42,7 +42,7 @@ def evaluate(self, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
+    full_data = fetch_client_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
 
@@ -59,7 +59,7 @@ def evaluate(self, parameters, config):
             break
         except Exception as e:
             FLOWER_LOGGER.warning(
-                f"Connection with the server failed. Attempt {attempts} failed: {e}"
+                f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
             )
             time.sleep(pow(2, attempts))
             attempts += 1
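The message fix above aligns the human-readable attempt count with the zero-based attempts counter. For reference, the loop backs off exponentially between tries; a standalone sketch of the same pattern, where connect() and the attempt bound are stand-ins rather than the real client code:

import time

def connect():
    # Hypothetical stand-in for opening the connection to the Flower server.
    raise ConnectionError("server not reachable yet")

attempts = 0
MAX_ATTEMPTS = 5  # assumed bound; the real limit lives in client.py

while attempts < MAX_ATTEMPTS:
    try:
        connect()
        break
    except Exception as e:
        print(f"Connection with the server failed. Attempt {attempts + 1} failed: {e}")
        time.sleep(pow(2, attempts))  # exponential backoff: 1, 2, 4, 8, 16 seconds
        attempts += 1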
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/server.py
@@ -6,10 +6,10 @@
 from utils import set_initial_params
 from utils import set_model_params
 
-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import post_result
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
 
 # TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO
 NUM_OF_ROUNDS = 5
@@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression()
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets)
+    full_data = fetch_server_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
     strategy = fl.server.strategy.FedAvg(
@@ -5,7 +5,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import log_loss
 
-from exareme2.algorithms.flower.flower_data_processing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
 from exareme2.algorithms.flower.mnist_logistic_regression import utils
 
 NUM_OF_ROUNDS = 5
3 changes: 2 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
@@ -298,13 +298,14 @@ def queue_healthcheck_task(
         )
 
     def start_flower_client(
-        self, request_id, algorithm_name, server_address, execution_timeout
+        self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
     ) -> WorkerTaskResult:
         return self._queue_task(
            task_signature=TASK_SIGNATURES["start_flower_client"],
            request_id=request_id,
            algorithm_name=algorithm_name,
            server_address=server_address,
+           csv_paths=csv_paths,
            execution_timeout=execution_timeout,
        )
 
9 changes: 8 additions & 1 deletion exareme2/controller/quart/endpoints.py
@@ -32,7 +32,14 @@ async def get_datasets() -> dict:

 @algorithms.route("/datasets_locations", methods=["GET"])
 async def get_datasets_locations() -> dict:
-    return get_worker_landscape_aggregator().get_datasets_locations().datasets_locations
+    return {
+        data_model: {
+            dataset: info.worker_id for dataset, info in datasets_location.items()
+        }
+        for data_model, datasets_location in get_worker_landscape_aggregator()
+        .get_datasets_locations()
+        .datasets_locations.items()
+    }
 
 
 @algorithms.route("/cdes_metadata", methods=["GET"])
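The rewritten endpoint flattens the aggregator's internal location objects so the response carries only each dataset's worker id. A hypothetical response body, with illustrative data-model, dataset, and worker names:

{
    "dementia:0.1": {
        "edsd": "localworker1",
        "ppmi": "localworker2"
    },
    "tbi:0.1": {
        "dummy_tbi": "localworker1"
    }
}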
23 changes: 10 additions & 13 deletions exareme2/controller/services/flower/controller.py
@@ -1,4 +1,6 @@
 import asyncio
+import warnings
+from typing import Dict
 from typing import List
 
 from exareme2.controller import config as ctrl_config
@@ -53,10 +55,16 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         request_id = algorithm_request_dto.request_id
         context_id = UIDGenerator().get_a_uid()
         logger = ctrl_logger.get_request_logger(request_id)
-        workers_info = self._get_workers_info_by_dataset(
+        csv_paths_per_worker_id: Dict[
+            str, List[str]
+        ] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
             algorithm_request_dto.inputdata.data_model,
             algorithm_request_dto.inputdata.datasets,
         )
+        workers_info = [
+            self.worker_landscape_aggregator.get_worker_info(worker_id)
+            for worker_id in csv_paths_per_worker_id
+        ]
         task_handlers = [
             self._create_worker_tasks_handler(request_id, worker)
             for worker in workers_info
@@ -90,6 +98,7 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
                 handler.start_flower_client(
                     algorithm_name,
                     str(server_address),
+                    csv_paths_per_worker_id[handler.worker_id],
                     ctrl_config.flower_execution_timeout,
                 ): handler
                 for handler in task_handlers
@@ -130,15 +139,3 @@ async def _cleanup(
         server_task_handler.stop_flower_server(server_pid, algorithm_name)
         for pid, handler in clients_pids.items():
             handler.stop_flower_client(pid, algorithm_name)
-
-    def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
-        """Retrieves worker information for those handling the specified datasets."""
-        worker_ids = (
-            self.worker_landscape_aggregator.get_worker_ids_with_any_of_datasets(
-                data_model, datasets
-            )
-        )
-        return [
-            self.worker_landscape_aggregator.get_worker_info(worker_id)
-            for worker_id in worker_ids
-        ]
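With the removal above, csv_paths_per_worker_id becomes the single source of truth: its keys decide which workers get a Flower client, and each value is forwarded to that worker's start_flower_client call, so worker selection and CSV assignment can no longer disagree. A hypothetical value of the mapping, with illustrative worker ids and relative paths:

# Hypothetical result of get_csv_paths_per_worker_id for a two-worker request.
csv_paths_per_worker_id = {
    "localworker1": ["dementia_v_0_1/edsd.csv"],
    "localworker2": ["dementia_v_0_1/ppmi.csv", "dementia_v_0_1/desd-synthdata.csv"],
}

# Keys select the participating workers; each worker's client later sees its
# own list, which ends up in the CSV_PATHS variable read by fetch_client_data.
participating_workers = list(csv_paths_per_worker_id)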
8 changes: 6 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
@@ -30,10 +30,14 @@ def worker_data_address(self) -> str:
         return self._db_address
 
     def start_flower_client(
-        self, algorithm_name, server_address, execution_timeout
+        self, algorithm_name, server_address, csv_paths, execution_timeout
     ) -> int:
         return self._worker_tasks_handler.start_flower_client(
-            self._request_id, algorithm_name, server_address, execution_timeout
+            self._request_id,
+            algorithm_name,
+            server_address,
+            csv_paths,
+            execution_timeout,
         ).get(timeout=self._tasks_timeout)
 
     def start_flower_server(
@@ -4,6 +4,7 @@
 from exareme2.controller.celery.tasks_handler import WorkerTasksHandler
 from exareme2.worker_communication import CommonDataElements
 from exareme2.worker_communication import DataModelAttributes
+from exareme2.worker_communication import DatasetsInfoPerDataModel
 from exareme2.worker_communication import WorkerInfo
 
 
@@ -23,10 +24,11 @@ def get_worker_info_task(self) -> WorkerInfo:
         ).get(self._tasks_timeout)
         return WorkerInfo.parse_raw(result)
 
-    def get_worker_datasets_per_data_model_task(self) -> Dict[str, Dict[str, str]]:
-        return self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
+    def get_worker_datasets_per_data_model_task(self) -> DatasetsInfoPerDataModel:
+        result = self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
             self._request_id
         ).get(self._tasks_timeout)
+        return DatasetsInfoPerDataModel.parse_raw(result)
 
     def get_data_model_cdes_task(self, data_model: str) -> CommonDataElements:
         result = self._worker_tasks_handler.queue_data_model_cdes_task(
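The handler now mirrors get_worker_info_task: the Celery result crosses process boundaries as a string, and the typed object is rebuilt with pydantic's parse_raw. A self-contained sketch of that round-trip, using a stand-in model because DatasetsInfoPerDataModel's fields are not shown in this diff:

from typing import Dict
from typing import List

from pydantic import BaseModel


class DatasetsInfoStandIn(BaseModel):
    # Stand-in for exareme2.worker_communication.DatasetsInfoPerDataModel,
    # whose actual fields are outside this diff.
    datasets_per_data_model: Dict[str, List[str]]


original = DatasetsInfoStandIn(
    datasets_per_data_model={"dementia:0.1": ["edsd", "ppmi"]}
)
wire = original.json()  # the string the Celery task returns
restored = DatasetsInfoStandIn.parse_raw(wire)  # what the handler reconstructs
assert restored == original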