Client loads data directly from CSVs #484

Merged 2 commits on Jul 9, 2024
6 changes: 3 additions & 3 deletions .github/workflows/algorithm_validation_tests.yml
@@ -245,8 +245,8 @@ jobs:
      with:
        run: cat /tmp/exareme2/localworker1.out

-  - name: Run Flower algorithm validation tests
-    run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
-
   - name: Run Exareme2 algorithm validation tests
     run: poetry run pytest tests/algorithm_validation_tests/exareme2/ --verbosity=4 -n 16 -k "input1 and not input1-" # run tests 10-19
+
+  - name: Run Flower algorithm validation tests
+    run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
exareme2/algorithms/flower/inputdata_preprocessing.py (renamed from flower_data_processing.py)
@@ -5,7 +5,6 @@
from typing import Optional

import pandas as pd
-import pymonetdb
import requests
from flwr.common.logger import FLOWER_LOGGER
from pydantic import BaseModel
@@ -29,37 +28,30 @@ class Inputdata(BaseModel):
x: Optional[List[str]]


-def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
-    return (
-        _fetch_data_from_db(data_model, datasets)
-        if from_db
-        else _fetch_data_from_csv(data_model, datasets)
-    )
+def fetch_client_data(inputdata) -> pd.DataFrame:
+    FLOWER_LOGGER.debug(f"Loading client CSVs: {os.getenv('CSV_PATHS')}")
+    dataframes = [
+        pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
+        for csv_path in os.getenv("CSV_PATHS").split(",")
+    ]
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]


-def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
-    query = f'SELECT * FROM "{data_model}"."primary_data"'
-    conn = pymonetdb.connect(
-        hostname=os.getenv("MONETDB_IP"),
-        port=int(os.getenv("MONETDB_PORT")),
-        username=os.getenv("MONETDB_USERNAME"),
-        password=os.getenv("MONETDB_PASSWORD"),
-        database=os.getenv("MONETDB_DB"),
-    )
-    df = pd.read_sql(query, conn)
-    conn.close()
-    df = df[df["dataset"].isin(datasets)]
-    return df
-
-
-def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
-    data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
+def fetch_server_data(inputdata) -> pd.DataFrame:
+    data_folder = Path(
+        f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
+    )
    print(f"Loading data from folder: {data_folder}")
    dataframes = [
        pd.read_csv(data_folder / f"{dataset}.csv")
-        for dataset in datasets
+        for dataset in inputdata.datasets
        if (data_folder / f"{dataset}.csv").exists()
    ]
-    return pd.concat(dataframes, ignore_index=True)
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]


def preprocess_data(inputdata, full_data):
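For orientation, a minimal sketch of what the client-side loading now does at runtime. The environment values, dataset names, and column names below are hypothetical; only the DATA_PATH/CSV_PATHS convention and the filtering steps mirror fetch_client_data above:

import os

import pandas as pd

# Hypothetical worker-provided environment:
os.environ["DATA_PATH"] = "/opt/exareme2/csvs/"
os.environ["CSV_PATHS"] = "dementia_v_0_1/edsd0.csv,dementia_v_0_1/edsd1.csv"

# Read exactly the worker-assigned CSVs, then keep only the requested
# datasets and the x + y columns, as fetch_client_data does.
dataframes = [
    pd.read_csv(f"{os.environ['DATA_PATH']}{csv_path}")
    for csv_path in os.environ["CSV_PATHS"].split(",")
]
df = pd.concat(dataframes, ignore_index=True)
df = df[df["dataset"].isin(["edsd0", "edsd1"])]         # inputdata.datasets
df = df[["lefthippocampus", "alzheimerbroadcategory"]]  # inputdata.x + inputdata.y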
33 changes: 26 additions & 7 deletions exareme2/algorithms/flower/logistic_regression/client.py
@@ -1,16 +1,19 @@
import os
+import time
import warnings
+from math import log2

import flwr as fl
+from flwr.common.logger import FLOWER_LOGGER
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from utils import get_model_parameters
from utils import set_initial_params
from utils import set_model_params

-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data


class LogisticRegressionClient(fl.client.NumPyClient):
@@ -39,11 +42,27 @@ def evaluate(self, parameters, config):
if __name__ == "__main__":
    model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
    inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
+    full_data = fetch_client_data(inputdata)
    X_train, y_train = preprocess_data(inputdata, full_data)
    set_initial_params(model, X_train, full_data, inputdata)

    client = LogisticRegressionClient(model, X_train, y_train)
-    fl.client.start_client(
-        server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
-    )
+
+    attempts = 0
+    max_attempts = int(log2(int(os.environ["TIMEOUT"])))
+    while True:
+        try:
+            fl.client.start_client(
+                server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
+            )
+            FLOWER_LOGGER.debug(f"Connection successful on attempt {attempts + 1}")
+            break
+        except Exception as e:
+            FLOWER_LOGGER.warning(
+                f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
+            )
+            time.sleep(pow(2, attempts))
+            attempts += 1
+            if attempts >= max_attempts:
+                FLOWER_LOGGER.error("Could not establish connection to the server.")
+                raise e
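A quick sanity check on the retry budget introduced above: with sleeps of 2**attempts seconds and max_attempts = int(log2(TIMEOUT)), the total back-off time is 2**max_attempts - 1 seconds, which stays just below TIMEOUT. A sketch with a hypothetical TIMEOUT of 300 seconds:

from math import log2

timeout = 300  # hypothetical TIMEOUT env value, in seconds
max_attempts = int(log2(timeout))  # int(8.23) -> 8 attempts
total_sleep = sum(2**i for i in range(max_attempts))  # 1 + 2 + ... + 128 = 255
assert total_sleep == 2**max_attempts - 1
assert total_sleep < timeout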
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/server.py
@@ -6,10 +6,10 @@
from utils import set_initial_params
from utils import set_model_params

-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import post_result
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data

# TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO
NUM_OF_ROUNDS = 5
@@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
if __name__ == "__main__":
model = LogisticRegression()
inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets)
+    full_data = fetch_server_data(inputdata)
X_train, y_train = preprocess_data(inputdata, full_data)
set_initial_params(model, X_train, full_data, inputdata)
strategy = fl.server.strategy.FedAvg(
@@ -5,7 +5,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

-from exareme2.algorithms.flower.flower_data_processing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
from exareme2.algorithms.flower.mnist_logistic_regression import utils

NUM_OF_ROUNDS = 5
4 changes: 3 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
@@ -298,13 +298,15 @@ def queue_healthcheck_task(
)

    def start_flower_client(
-        self, request_id, algorithm_name, server_address
+        self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
    ) -> WorkerTaskResult:
        return self._queue_task(
            task_signature=TASK_SIGNATURES["start_flower_client"],
            request_id=request_id,
            algorithm_name=algorithm_name,
            server_address=server_address,
+            csv_paths=csv_paths,
+            execution_timeout=execution_timeout,
        )

def start_flower_server(
9 changes: 8 additions & 1 deletion exareme2/controller/quart/endpoints.py
@@ -32,7 +32,14 @@ async def get_datasets() -> dict:

@algorithms.route("/datasets_locations", methods=["GET"])
async def get_datasets_locations() -> dict:
-    return get_worker_landscape_aggregator().get_datasets_locations().datasets_locations
+    return {
+        data_model: {
+            dataset: info.worker_id for dataset, info in datasets_location.items()
+        }
+        for data_model, datasets_location in get_worker_landscape_aggregator()
+        .get_datasets_locations()
+        .datasets_locations.items()
+    }
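The reshaped endpoint flattens each dataset's info object down to its worker_id, so the payload is a plain nested mapping. A hypothetical response (data models, datasets, and worker ids are made up):

{
    "dementia:0.1": {"edsd0": "localworker1", "edsd1": "localworker2"},
    "tbi:0.1": {"dummy_tbi0": "localworker2"}
}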


@algorithms.route("/cdes_metadata", methods=["GET"])
28 changes: 14 additions & 14 deletions exareme2/controller/services/flower/controller.py
@@ -1,6 +1,9 @@
import asyncio
import warnings
+from typing import Dict
from typing import List

+from exareme2.controller import config as ctrl_config
from exareme2.controller import logger as ctrl_logger
from exareme2.controller.federation_info_logs import log_experiment_execution
from exareme2.controller.services.flower.tasks_handler import FlowerTasksHandler
@@ -52,10 +55,16 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
request_id = algorithm_request_dto.request_id
context_id = UIDGenerator().get_a_uid()
logger = ctrl_logger.get_request_logger(request_id)
-        workers_info = self._get_workers_info_by_dataset(
+        csv_paths_per_worker_id: Dict[
+            str, List[str]
+        ] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
            algorithm_request_dto.inputdata.data_model,
            algorithm_request_dto.inputdata.datasets,
        )
+        workers_info = [
+            self.worker_landscape_aggregator.get_worker_info(worker_id)
+            for worker_id in csv_paths_per_worker_id
+        ]
task_handlers = [
self._create_worker_tasks_handler(request_id, worker)
for worker in workers_info
@@ -87,7 +96,10 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
)
clients_pids = {
handler.start_flower_client(
-                algorithm_name, str(server_address)
+                algorithm_name,
+                str(server_address),
+                csv_paths_per_worker_id[handler.worker_id],
+                ctrl_config.flower_execution_timeout,
): handler
for handler in task_handlers
}
@@ -127,15 +139,3 @@ async def _cleanup(
server_task_handler.stop_flower_server(server_pid, algorithm_name)
for pid, handler in clients_pids.items():
handler.stop_flower_client(pid, algorithm_name)

-    def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
-        """Retrieves worker information for those handling the specified datasets."""
-        worker_ids = (
-            self.worker_landscape_aggregator.get_worker_ids_with_any_of_datasets(
-                data_model, datasets
-            )
-        )
-        return [
-            self.worker_landscape_aggregator.get_worker_info(worker_id)
-            for worker_id in worker_ids
-        ]
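Worker selection is now driven by a single aggregator call. Below is a sketch of the assumed shape of get_csv_paths_per_worker_id's result, inferred from the Dict[str, List[str]] annotation above (worker ids and paths are made up); iterating the dict's keys is what replaces the deleted _get_workers_info_by_dataset helper:

from typing import Dict, List

# Assumed result of get_csv_paths_per_worker_id(data_model, datasets):
csv_paths_per_worker_id: Dict[str, List[str]] = {
    "localworker1": ["dementia_v_0_1/edsd0.csv"],
    "localworker2": ["dementia_v_0_1/edsd1.csv", "dementia_v_0_1/edsd2.csv"],
}

# The keys select the participating workers; each handler later receives its
# own csv_paths_per_worker_id[handler.worker_id] list when its client starts.
worker_ids = list(csv_paths_per_worker_id)  # ["localworker1", "localworker2"]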
10 changes: 8 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
@@ -29,9 +29,15 @@ def worker_id(self) -> str:
def worker_data_address(self) -> str:
return self._db_address

-    def start_flower_client(self, algorithm_name, server_address) -> int:
+    def start_flower_client(
+        self, algorithm_name, server_address, csv_paths, execution_timeout
+    ) -> int:
        return self._worker_tasks_handler.start_flower_client(
-            self._request_id, algorithm_name, server_address
+            self._request_id,
+            algorithm_name,
+            server_address,
+            csv_paths,
+            execution_timeout,
        ).get(timeout=self._tasks_timeout)

def start_flower_server(
@@ -4,6 +4,7 @@
from exareme2.controller.celery.tasks_handler import WorkerTasksHandler
from exareme2.worker_communication import CommonDataElements
from exareme2.worker_communication import DataModelAttributes
+from exareme2.worker_communication import DatasetsInfoPerDataModel
from exareme2.worker_communication import WorkerInfo


@@ -23,10 +24,11 @@ def get_worker_info_task(self) -> WorkerInfo:
).get(self._tasks_timeout)
return WorkerInfo.parse_raw(result)

-    def get_worker_datasets_per_data_model_task(self) -> Dict[str, Dict[str, str]]:
-        return self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
+    def get_worker_datasets_per_data_model_task(self) -> DatasetsInfoPerDataModel:
+        result = self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
            self._request_id
        ).get(self._tasks_timeout)
+        return DatasetsInfoPerDataModel.parse_raw(result)

def get_data_model_cdes_task(self, data_model: str) -> CommonDataElements:
result = self._worker_tasks_handler.queue_data_model_cdes_task(
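The task result is now parsed into a typed model instead of an untyped nested dict. DatasetsInfoPerDataModel itself is not shown in this diff, so the sketch below uses a hypothetical stand-in with invented fields, in the pydantic v1 style that parse_raw implies:

from typing import Dict

from pydantic import BaseModel


class DatasetInfo(BaseModel):
    # Stand-in; the real model lives in exareme2.worker_communication.
    code: str
    label: str


class DatasetsInfoPerDataModel(BaseModel):
    # Stand-in shape: {data_model: {dataset: DatasetInfo}}.
    datasets_info_per_data_model: Dict[str, Dict[str, DatasetInfo]]


raw = '{"datasets_info_per_data_model": {"dementia:0.1": {"edsd0": {"code": "edsd0", "label": "EDSD"}}}}'
parsed = DatasetsInfoPerDataModel.parse_raw(raw)  # validated, typed result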