Skip to content

Commit

Permalink
temporary
Browse files Browse the repository at this point in the history
  • Loading branch information
KFilippopolitis committed Jul 2, 2024
1 parent 7d30ce8 commit 920e911
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 58 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/algorithm_validation_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ jobs:
with:
run: cat /tmp/exareme2/localworker1.out

- name: Run Flower algorithm validation tests
run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5

- name: Run Exareme2 algorithm validation tests
run: poetry run pytest tests/algorithm_validation_tests/exareme2/ --verbosity=4 -n 16 -k "input1 and not input1-" # run tests 10-19

- name: Run Flower algorithm validation tests
run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Optional

import pandas as pd
import pymonetdb
import requests
from flwr.common.logger import FLOWER_LOGGER
from pydantic import BaseModel
Expand All @@ -29,37 +28,29 @@ class Inputdata(BaseModel):
x: Optional[List[str]]


def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
    """Load the data of the requested datasets from the database or from csv files.

    Args:
        data_model: Data model identifier, forwarded unchanged to the loader.
        datasets: Datasets to load, forwarded unchanged to the loader.
        from_db: When true the data are loaded through ``_fetch_data_from_db``,
            otherwise through ``_fetch_data_from_csv``.

    Returns:
        The dataframe produced by the selected loader.
    """
    if from_db:
        return _fetch_data_from_db(data_model, datasets)
    return _fetch_data_from_csv(data_model, datasets)
def fetch_client_data(inputdata) -> pd.DataFrame:
    """Load the csv files assigned to this client and keep the requested data.

    The files to read are taken from the environment:

    * ``DATA_PATH``: prefix that every csv path is appended to. No separator
      is inserted between the prefix and the csv path.
    * ``CSV_PATHS``: comma separated list of csv paths.

    Args:
        inputdata: Object exposing ``datasets`` (values matched against the
            ``dataset`` column) and ``x`` / ``y`` (column names to return).

    Returns:
        The rows whose ``dataset`` value is in ``inputdata.datasets``,
        restricted to the ``x`` columns followed by the ``y`` columns.

    Raises:
        ValueError: If ``DATA_PATH`` or ``CSV_PATHS`` is not set, or if
            ``CSV_PATHS`` does not contain any path.
    """
    data_path = os.getenv("DATA_PATH")
    csv_paths_env = os.getenv("CSV_PATHS")
    if data_path is None or csv_paths_env is None:
        raise ValueError(
            "Both the DATA_PATH and CSV_PATHS environment variables must be set."
        )

    # An empty CSV_PATHS splits into [""]; drop empty entries so the data
    # folder itself is never handed to read_csv.
    csv_paths = [csv_path for csv_path in csv_paths_env.split(",") if csv_path]
    if not csv_paths:
        raise ValueError("CSV_PATHS does not contain any csv path.")

    dataframes = [pd.read_csv(f"{data_path}{csv_path}") for csv_path in csv_paths]
    df = pd.concat(dataframes, ignore_index=True)
    df = df[df["dataset"].isin(inputdata.datasets)]

    # `x` is declared Optional on the input model, so treat None as "no
    # covariates" instead of failing on `None + list`.
    columns = list(inputdata.x or []) + list(inputdata.y)
    return df[columns]


def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
query = f'SELECT * FROM "{data_model}"."primary_data"'
conn = pymonetdb.connect(
hostname=os.getenv("MONETDB_IP"),
port=int(os.getenv("MONETDB_PORT")),
username=os.getenv("MONETDB_USERNAME"),
password=os.getenv("MONETDB_PASSWORD"),
database=os.getenv("MONETDB_DB"),
def fetch_server_data(inputdata) -> pd.DataFrame:
data_folder = Path(
f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
)
df = pd.read_sql(query, conn)
conn.close()
df = df[df["dataset"].isin(datasets)]
return df


def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
print(f"Loading data from folder: {data_folder}")
dataframes = [
pd.read_csv(data_folder / f"{dataset}.csv")
for dataset in datasets
for dataset in inputdata.datasets
if (data_folder / f"{dataset}.csv").exists()
]
return pd.concat(dataframes, ignore_index=True)
df = pd.concat(dataframes, ignore_index=True)
df = df[df["dataset"].isin(inputdata.datasets)]
return df[inputdata.x + inputdata.y]


def preprocess_data(inputdata, full_data):
Expand Down
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from utils import set_initial_params
from utils import set_model_params

from exareme2.algorithms.flower.flower_data_processing import fetch_data
from exareme2.algorithms.flower.flower_data_processing import get_input
from exareme2.algorithms.flower.flower_data_processing import preprocess_data
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data


class LogisticRegressionClient(fl.client.NumPyClient):
Expand Down Expand Up @@ -42,7 +42,7 @@ def evaluate(self, parameters, config):
if __name__ == "__main__":
model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
inputdata = get_input()
full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
full_data = fetch_client_data(inputdata)
X_train, y_train = preprocess_data(inputdata, full_data)
set_initial_params(model, X_train, full_data, inputdata)

Expand All @@ -59,7 +59,7 @@ def evaluate(self, parameters, config):
break
except Exception as e:
FLOWER_LOGGER.warning(
f"Connection with the server failed. Attempt {attempts} failed: {e}"
f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
)
time.sleep(pow(2, attempts))
attempts += 1
Expand Down
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from utils import set_initial_params
from utils import set_model_params

from exareme2.algorithms.flower.flower_data_processing import fetch_data
from exareme2.algorithms.flower.flower_data_processing import get_input
from exareme2.algorithms.flower.flower_data_processing import post_result
from exareme2.algorithms.flower.flower_data_processing import preprocess_data
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
from exareme2.algorithms.flower.inputdata_preprocessing import post_result
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data

# TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO
NUM_OF_ROUNDS = 5
Expand All @@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
if __name__ == "__main__":
model = LogisticRegression()
inputdata = get_input()
full_data = fetch_data(inputdata.data_model, inputdata.datasets)
full_data = fetch_server_data(inputdata)
X_train, y_train = preprocess_data(inputdata, full_data)
set_initial_params(model, X_train, full_data, inputdata)
strategy = fl.server.strategy.FedAvg(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

from exareme2.algorithms.flower.flower_data_processing import post_result
from exareme2.algorithms.flower.inputdata_preprocessing import post_result
from exareme2.algorithms.flower.mnist_logistic_regression import utils

NUM_OF_ROUNDS = 5
Expand Down
3 changes: 2 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,14 @@ def queue_healthcheck_task(
)

def start_flower_client(
self, request_id, algorithm_name, server_address, execution_timeout
self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
) -> WorkerTaskResult:
return self._queue_task(
task_signature=TASK_SIGNATURES["start_flower_client"],
request_id=request_id,
algorithm_name=algorithm_name,
server_address=server_address,
csv_paths=csv_paths,
execution_timeout=execution_timeout,
)

Expand Down
23 changes: 10 additions & 13 deletions exareme2/controller/services/flower/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from typing import Dict
from typing import List

from exareme2.controller import config as ctrl_config
Expand Down Expand Up @@ -53,10 +54,17 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
request_id = algorithm_request_dto.request_id
context_id = UIDGenerator().get_a_uid()
logger = ctrl_logger.get_request_logger(request_id)
workers_info = self._get_workers_info_by_dataset(
csv_paths_per_worker_id: Dict[
str, List[str]
] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
algorithm_request_dto.inputdata.data_model,
algorithm_request_dto.inputdata.datasets,
)

workers_info = [
self.worker_landscape_aggregator.get_worker_info(worker_id)
for worker_id in csv_paths_per_worker_id
]
task_handlers = [
self._create_worker_tasks_handler(request_id, worker)
for worker in workers_info
Expand Down Expand Up @@ -90,6 +98,7 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
handler.start_flower_client(
algorithm_name,
str(server_address),
csv_paths_per_worker_id[handler.worker_id],
ctrl_config.flower_execution_timeout,
): handler
for handler in task_handlers
Expand Down Expand Up @@ -130,15 +139,3 @@ async def _cleanup(
server_task_handler.stop_flower_server(server_pid, algorithm_name)
for pid, handler in clients_pids.items():
handler.stop_flower_client(pid, algorithm_name)

def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
    """Retrieves worker information for those handling the specified datasets.

    Asks the worker landscape aggregator for the ids of the workers that have
    any of ``datasets`` under ``data_model`` and resolves each id to its
    worker info, preserving the order in which the ids are returned.
    """
    aggregator = self.worker_landscape_aggregator
    worker_ids = aggregator.get_worker_ids_with_any_of_datasets(data_model, datasets)
    return [aggregator.get_worker_info(worker_id) for worker_id in worker_ids]
8 changes: 6 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ def worker_data_address(self) -> str:
return self._db_address

def start_flower_client(
self, algorithm_name, server_address, execution_timeout
self, algorithm_name, server_address, csv_paths, execution_timeout
) -> int:
return self._worker_tasks_handler.start_flower_client(
self._request_id, algorithm_name, server_address, execution_timeout
self._request_id,
algorithm_name,
server_address,
csv_paths,
execution_timeout,
).get(timeout=self._tasks_timeout)

def start_flower_server(
Expand Down
4 changes: 2 additions & 2 deletions exareme2/worker/flower/starter/starter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

@shared_task
def start_flower_client(
request_id: str, algorithm_name, server_address, execution_timeout
request_id: str, algorithm_name, server_address, csv_paths, execution_timeout
) -> int:
return starter_service.start_flower_client(
request_id, algorithm_name, server_address, execution_timeout
request_id, algorithm_name, server_address, csv_paths, execution_timeout
)


Expand Down
3 changes: 2 additions & 1 deletion exareme2/worker/flower/starter/starter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@initialise_logger
def start_flower_client(
request_id: str, algorithm_name, server_address, execution_timeout
request_id: str, algorithm_name, server_address, csv_paths, execution_timeout
) -> int:
env_vars = {
"MONETDB_IP": worker_config.monetdb.ip,
Expand All @@ -22,6 +22,7 @@ def start_flower_client(
"CONTROLLER_IP": worker_config.controller.ip,
"CONTROLLER_PORT": worker_config.controller.port,
"DATA_PATH": worker_config.data_path,
"CSV_PATHS": ",".join(csv_paths),
"TIMEOUT": execution_timeout,
}
process = FlowerProcess(f"{algorithm_name}/client.py", env_vars=env_vars)
Expand Down

0 comments on commit 920e911

Please sign in to comment.