Update data processing for clients, so they load data from CSV and not the database.
KFilippopolitis committed Jul 1, 2024
1 parent 9dde720 commit 97c0a64
Showing 22 changed files with 1,022 additions and 559 deletions.
exareme2/algorithms/flower/flower_data_processing.py → exareme2/algorithms/flower/inputdata_preprocessing.py (rename inferred from the import changes below; the file header was not captured)

@@ -5,7 +5,6 @@
 from typing import Optional
 
 import pandas as pd
-import pymonetdb
 import requests
 from flwr.common.logger import FLOWER_LOGGER
 from pydantic import BaseModel
@@ -29,37 +28,29 @@ class Inputdata(BaseModel):
     x: Optional[List[str]]
 
 
-def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
-    return (
-        _fetch_data_from_db(data_model, datasets)
-        if from_db
-        else _fetch_data_from_csv(data_model, datasets)
-    )
+def fetch_client_data(inputdata) -> pd.DataFrame:
+    dataframes = [
+        pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
+        for csv_path in os.getenv("CSV_PATHS").split(",")
+    ]
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]
 
 
-def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
-    query = f'SELECT * FROM "{data_model}"."primary_data"'
-    conn = pymonetdb.connect(
-        hostname=os.getenv("MONETDB_IP"),
-        port=int(os.getenv("MONETDB_PORT")),
-        username=os.getenv("MONETDB_USERNAME"),
-        password=os.getenv("MONETDB_PASSWORD"),
-        database=os.getenv("MONETDB_DB"),
+def fetch_server_data(inputdata) -> pd.DataFrame:
+    data_folder = Path(
+        f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
     )
-    df = pd.read_sql(query, conn)
-    conn.close()
-    df = df[df["dataset"].isin(datasets)]
-    return df
-
-
-def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
-    data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
     print(f"Loading data from folder: {data_folder}")
     dataframes = [
         pd.read_csv(data_folder / f"{dataset}.csv")
-        for dataset in datasets
+        for dataset in inputdata.datasets
        if (data_folder / f"{dataset}.csv").exists()
     ]
-    return pd.concat(dataframes, ignore_index=True)
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]
 
 
 def preprocess_data(inputdata, full_data):
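
The split into fetch_client_data and fetch_server_data means a client no longer needs a database driver: it reads only the worker-local CSV files listed in the CSV_PATHS environment variable, while the server still assembles the whole data-model folder. A minimal sketch of the client-side path, assuming DATA_PATH and CSV_PATHS are set by the worker process (all concrete values below are illustrative, not taken from the repository):

    import os

    import pandas as pd

    # Illustrative values; in exareme2 the worker sets these for the algorithm process.
    os.environ["DATA_PATH"] = "/opt/data/"
    os.environ["CSV_PATHS"] = "dementia_v_0_1/edsd0.csv,dementia_v_0_1/edsd1.csv"

    # Mirrors fetch_client_data: read each listed CSV, concatenate,
    # filter to the requested datasets, then project to the x + y columns.
    frames = [
        pd.read_csv(f"{os.getenv('DATA_PATH')}{path}")
        for path in os.getenv("CSV_PATHS").split(",")
    ]
    df = pd.concat(frames, ignore_index=True)
    df = df[df["dataset"].isin(["edsd0", "edsd1"])]  # inputdata.datasets
    xy = df[["lefthippocampus", "alzheimerbroadcategory"]]  # inputdata.x + inputdata.y

Note that fetch_client_data concatenates DATA_PATH and each relative path with no separator, so DATA_PATH must end in a slash (or the paths must start with one), whereas fetch_server_data joins with "/" and derives the folder from the data-model name.
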
8 changes: 4 additions & 4 deletions exareme2/algorithms/flower/logistic_regression/client.py
@@ -11,9 +11,9 @@
 from utils import set_initial_params
 from utils import set_model_params
 
-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
 
 
 class LogisticRegressionClient(fl.client.NumPyClient):
@@ -42,7 +42,7 @@ def evaluate(self, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
+    full_data = fetch_client_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
 
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/server.py
@@ -6,10 +6,10 @@
 from utils import set_initial_params
 from utils import set_model_params
 
-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import post_result
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
 
 # TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO
 NUM_OF_ROUNDS = 5
@@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression()
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets)
+    full_data = fetch_server_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
     strategy = fl.server.strategy.FedAvg(
exareme2/algorithms/flower/mnist_logistic_regression/… (file header not captured; package inferred from the imports)

@@ -5,7 +5,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import log_loss
 
-from exareme2.algorithms.flower.flower_data_processing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
 from exareme2.algorithms.flower.mnist_logistic_regression import utils
 
 NUM_OF_ROUNDS = 5
3 changes: 2 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
@@ -298,13 +298,14 @@ def queue_healthcheck_task(
         )
 
     def start_flower_client(
-        self, request_id, algorithm_name, server_address, execution_timeout
+        self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
     ) -> WorkerTaskResult:
         return self._queue_task(
             task_signature=TASK_SIGNATURES["start_flower_client"],
             request_id=request_id,
             algorithm_name=algorithm_name,
             server_address=server_address,
+            csv_paths=csv_paths,
             execution_timeout=execution_timeout,
         )
 
9 changes: 8 additions & 1 deletion exareme2/controller/quart/endpoints.py
@@ -32,7 +32,14 @@ async def get_datasets() -> dict:
 
 @algorithms.route("/datasets_locations", methods=["GET"])
 async def get_datasets_locations() -> dict:
-    return get_worker_landscape_aggregator().get_datasets_locations().datasets_locations
+    return {
+        data_model: {
+            dataset: info.worker_id for dataset, info in datasets_location.items()
+        }
+        for data_model, datasets_location in get_worker_landscape_aggregator()
+        .get_datasets_locations()
+        .datasets_locations.items()
+    }
 
 
 @algorithms.route("/cdes_metadata", methods=["GET"])
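
The endpoint previously returned the aggregator's datasets_locations mapping as-is, whose leaf values are per-dataset info objects; the rewrite reduces each leaf to its worker_id, so the response is a plain nested dictionary. A sketch of the resulting shape, with hypothetical data models, datasets, and worker ids:

    {
        "dementia:0.1": {
            "edsd": "localworker1",
            "ppmi": "localworker2",
        },
        "tbi:0.1": {
            "dummy_tbi": "localworker1",
        },
    }
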
24 changes: 10 additions & 14 deletions exareme2/controller/services/flower/controller.py
@@ -1,4 +1,5 @@
 import asyncio
+from typing import Dict
 from typing import List
 
 from exareme2.controller import config as ctrl_config
@@ -53,10 +54,17 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         request_id = algorithm_request_dto.request_id
         context_id = UIDGenerator().get_a_uid()
         logger = ctrl_logger.get_request_logger(request_id)
-        workers_info = self._get_workers_info_by_dataset(
+        csv_paths_per_worker_id: Dict[
+            str, List[str]
+        ] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
             algorithm_request_dto.inputdata.data_model,
             algorithm_request_dto.inputdata.datasets,
         )
+
+        workers_info = [
+            self.worker_landscape_aggregator.get_worker_info(worker_id)
+            for worker_id in csv_paths_per_worker_id
+        ]
         task_handlers = [
             self._create_worker_tasks_handler(request_id, worker)
             for worker in workers_info
@@ -81,7 +89,6 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         server_pid = None
         clients_pids = {}
         server_address = f"{server_ip}:{FLOWER_SERVER_PORT}"
-
         try:
             server_pid = server_task_handler.start_flower_server(
                 algorithm_name, len(task_handlers), str(server_address)
@@ -90,6 +97,7 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
                 handler.start_flower_client(
                     algorithm_name,
                     str(server_address),
+                    csv_paths_per_worker_id[handler.worker_id],
                     ctrl_config.flower_execution_timeout,
                 ): handler
                 for handler in task_handlers
@@ -130,15 +138,3 @@ async def _cleanup(
         server_task_handler.stop_flower_server(server_pid, algorithm_name)
         for pid, handler in clients_pids.items():
             handler.stop_flower_client(pid, algorithm_name)
-
-    def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
-        """Retrieves worker information for those handling the specified datasets."""
-        worker_ids = (
-            self.worker_landscape_aggregator.get_worker_ids_with_any_of_datasets(
-                data_model, datasets
-            )
-        )
-        return [
-            self.worker_landscape_aggregator.get_worker_info(worker_id)
-            for worker_id in worker_ids
-        ]
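
With the helper gone, the mapping returned by get_csv_paths_per_worker_id does double duty: its keys select the participating workers, and its values travel to each worker's client task, where they presumably surface as the CSV_PATHS variable that fetch_client_data reads. A minimal sketch of that flow, with hypothetical worker ids and paths:

    from typing import Dict, List

    # Hypothetical result of get_csv_paths_per_worker_id(data_model, datasets):
    csv_paths_per_worker_id: Dict[str, List[str]] = {
        "localworker1": ["dementia_v_0_1/edsd0.csv"],
        "localworker2": ["dementia_v_0_1/edsd1.csv", "dementia_v_0_1/edsd2.csv"],
    }

    # The keys pick the workers that will run a Flower client ...
    worker_ids = list(csv_paths_per_worker_id)  # ["localworker1", "localworker2"]

    # ... and each client is started with exactly its own worker's paths.
    for worker_id in worker_ids:
        csv_paths = csv_paths_per_worker_id[worker_id]
        print(f"start_flower_client(..., csv_paths={csv_paths!r})")

One consequence worth noting: a worker participates only if the aggregator maps at least one requested CSV to it, so worker selection and data location come from the same source and can no longer drift apart.
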
8 changes: 6 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
@@ -30,10 +30,14 @@ def worker_data_address(self) -> str:
         return self._db_address
 
     def start_flower_client(
-        self, algorithm_name, server_address, execution_timeout
+        self, algorithm_name, server_address, csv_paths, execution_timeout
     ) -> int:
         return self._worker_tasks_handler.start_flower_client(
-            self._request_id, algorithm_name, server_address, execution_timeout
+            self._request_id,
+            algorithm_name,
+            server_address,
+            csv_paths,
+            execution_timeout,
         ).get(timeout=self._tasks_timeout)
 
     def start_flower_server(
exareme2/controller/services/… (file header not captured; this is the worker-landscape-aggregator's tasks handler, judging from the imports)

@@ -4,6 +4,7 @@
 from exareme2.controller.celery.tasks_handler import WorkerTasksHandler
 from exareme2.worker_communication import CommonDataElements
 from exareme2.worker_communication import DataModelAttributes
+from exareme2.worker_communication import DatasetsInfoPerDataModel
 from exareme2.worker_communication import WorkerInfo
 
 
@@ -23,10 +24,11 @@ def get_worker_info_task(self) -> WorkerInfo:
         ).get(self._tasks_timeout)
         return WorkerInfo.parse_raw(result)
 
-    def get_worker_datasets_per_data_model_task(self) -> Dict[str, Dict[str, str]]:
-        return self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
+    def get_worker_datasets_per_data_model_task(self) -> DatasetsInfoPerDataModel:
+        result = self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
             self._request_id
         ).get(self._tasks_timeout)
+        return DatasetsInfoPerDataModel.parse_raw(result)
 
     def get_data_model_cdes_task(self, data_model: str) -> CommonDataElements:
         result = self._worker_tasks_handler.queue_data_model_cdes_task(
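
The datasets-per-data-model task now round-trips a typed DTO instead of a bare Dict[str, Dict[str, str]], matching the WorkerInfo.parse_raw pattern just above it. A sketch of the deserialization step, assuming DatasetsInfoPerDataModel is a pydantic model (its real fields live in exareme2.worker_communication and are not shown in this diff; the shape below is hypothetical):

    from typing import Dict

    from pydantic import BaseModel


    class DatasetInfo(BaseModel):  # hypothetical stand-in
        code: str
        label: str


    class DatasetsInfoPerDataModel(BaseModel):  # hypothetical stand-in
        datasets_info_per_data_model: Dict[str, Dict[str, DatasetInfo]]


    raw = '{"datasets_info_per_data_model": {"dementia:0.1": {"edsd": {"code": "edsd", "label": "EDSD"}}}}'
    dto = DatasetsInfoPerDataModel.parse_raw(raw)  # the same call the handler now makes
    assert dto.datasets_info_per_data_model["dementia:0.1"]["edsd"].label == "EDSD"

Parsing at the Celery boundary keeps validation in one place, so downstream code works with attribute access instead of raw dictionaries.
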