Skip to content

Commit

Permalink
Data are now loaded on the globalworker (global worker node) too.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kostas Filippopolitis committed Jul 12, 2024
1 parent bddfea1 commit af0de5e
Show file tree
Hide file tree
Showing 14 changed files with 245 additions and 173 deletions.
137 changes: 68 additions & 69 deletions .github/workflows/prod_env_tests.yml

Large diffs are not rendered by default.

17 changes: 1 addition & 16 deletions exareme2/algorithms/flower/inputdata_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def apply_inputdata(df: pd.DataFrame, inputdata: Inputdata) -> pd.DataFrame:
return df


def fetch_client_data(inputdata) -> pd.DataFrame:
FLOWER_LOGGER.error(f"BROOO {os.getenv('CSV_PATHS')}")
def fetch_data(inputdata) -> pd.DataFrame:
dataframes = [
pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
for csv_path in os.getenv("CSV_PATHS").split(",")
Expand All @@ -49,20 +48,6 @@ def fetch_client_data(inputdata) -> pd.DataFrame:
return apply_inputdata(df, inputdata)


def fetch_server_data(inputdata) -> pd.DataFrame:
data_folder = Path(
f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
)
print(f"Loading data from folder: {data_folder}")
dataframes = [
pd.read_csv(data_folder / f"{dataset}.csv")
for dataset in inputdata.datasets
if (data_folder / f"{dataset}.csv").exists()
]
df = pd.concat(dataframes, ignore_index=True)
return apply_inputdata(df, inputdata)


def preprocess_data(inputdata, full_data):
# Ensure x and y are specified and correct
if not inputdata.x or not inputdata.y:
Expand Down
4 changes: 2 additions & 2 deletions exareme2/algorithms/flower/logistic_regression/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from utils import set_initial_params
from utils import set_model_params

from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_data
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data

Expand Down Expand Up @@ -42,7 +42,7 @@ def evaluate(self, parameters, config):
if __name__ == "__main__":
model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
inputdata = get_input()
full_data = fetch_client_data(inputdata)
full_data = fetch_data(inputdata)
X_train, y_train = preprocess_data(inputdata, full_data)
set_initial_params(model, X_train, full_data, inputdata)

Expand Down
4 changes: 2 additions & 2 deletions exareme2/algorithms/flower/logistic_regression/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from utils import set_initial_params
from utils import set_model_params

from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_data
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
from exareme2.algorithms.flower.inputdata_preprocessing import post_result
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
Expand Down Expand Up @@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
if __name__ == "__main__":
model = LogisticRegression()
inputdata = get_input()
full_data = fetch_server_data(inputdata)
full_data = fetch_data(inputdata)
X_train, y_train = preprocess_data(inputdata, full_data)
set_initial_params(model, X_train, full_data, inputdata)
strategy = fl.server.strategy.FedAvg(
Expand Down
3 changes: 2 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,15 @@ def start_flower_client(
)

def start_flower_server(
self, request_id, algorithm_name, number_of_clients, server_address
self, request_id, algorithm_name, number_of_clients, server_address, csv_paths
) -> WorkerTaskResult:
return self._queue_task(
task_signature=TASK_SIGNATURES["start_flower_server"],
request_id=request_id,
algorithm_name=algorithm_name,
number_of_clients=number_of_clients,
server_address=server_address,
csv_paths=csv_paths,
)

def stop_flower_server(
Expand Down
7 changes: 5 additions & 2 deletions exareme2/controller/services/flower/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
request_id, global_worker
)
server_ip = global_worker.ip

server_id = global_worker.id
# Garbage Collect
server_task_handler.garbage_collect()
for handler in task_handlers:
Expand All @@ -92,7 +92,10 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):

try:
server_pid = server_task_handler.start_flower_server(
algorithm_name, len(task_handlers), str(server_address)
algorithm_name,
len(task_handlers),
str(server_address),
csv_paths_per_worker_id[server_id],
)
clients_pids = {
handler.start_flower_client(
Expand Down
8 changes: 6 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ def start_flower_client(
).get(timeout=self._tasks_timeout)

def start_flower_server(
self, algorithm_name: str, number_of_clients: int, server_address
self, algorithm_name: str, number_of_clients: int, server_address, csv_paths
) -> int:
return self._worker_tasks_handler.start_flower_server(
self._request_id, algorithm_name, number_of_clients, server_address
self._request_id,
algorithm_name,
number_of_clients,
server_address,
csv_paths,
).get(timeout=self._tasks_timeout)

def stop_flower_server(self, pid: int, algorithm_name: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -557,13 +557,8 @@ def _fetch_workers_metadata(
.socket_addresses
)
workers_info = self._get_workers_info(workers_addresses)
local_workers = [
worker_info
for worker_info in workers_info
if worker_info.role == WorkerRole.LOCALWORKER
]
data_models_metadata_per_worker = self._get_data_models_metadata_per_worker(
local_workers
workers_info
)
return workers_info, data_models_metadata_per_worker

Expand Down
8 changes: 6 additions & 2 deletions exareme2/worker/flower/starter/starter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ def start_flower_client(

@shared_task
def start_flower_server(
request_id: str, algorithm_name: str, number_of_clients: int, server_address
request_id: str,
algorithm_name: str,
number_of_clients: int,
server_address,
csv_paths,
) -> int:
return starter_service.start_flower_server(
request_id, algorithm_name, number_of_clients, server_address
request_id, algorithm_name, number_of_clients, server_address, csv_paths
)
7 changes: 6 additions & 1 deletion exareme2/worker/flower/starter/starter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def start_flower_client(

@initialise_logger
def start_flower_server(
request_id: str, algorithm_name: str, number_of_clients: int, server_address
request_id: str,
algorithm_name: str,
number_of_clients: int,
server_address,
csv_paths,
) -> int:
env_vars = {
"REQUEST_ID": request_id,
Expand All @@ -47,6 +51,7 @@ def start_flower_server(
"CONTROLLER_IP": worker_config.controller.ip,
"CONTROLLER_PORT": worker_config.controller.port,
"DATA_PATH": worker_config.data_path,
"CSV_PATHS": ",".join(csv_paths),
}
process = FlowerProcess(f"{algorithm_name}/server.py", env_vars=env_vars)
logger = get_logger()
Expand Down
2 changes: 0 additions & 2 deletions exareme2/worker/worker_info/worker_info_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import warnings
from typing import Dict
from typing import List

Expand Down Expand Up @@ -38,7 +37,6 @@ def get_data_models() -> List[str]:


def convert_absolute_dataset_path_to_relative(dataset_path: str) -> str:
warnings.warn(str(dataset_path))
return dataset_path.split(str(worker_config.data_path))[-1]


Expand Down
19 changes: 19 additions & 0 deletions kubernetes/templates/exareme2-globalnode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@ spec:
- "select 1;"
periodSeconds: 30

- name: db-importer
image: {{ .Values.exareme2_images.repository }}/exareme2_mipdb:{{ .Values.exareme2_images.version }}
env:
- name: DB_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: SQLITE_DB_NAME
valueFrom:
fieldRef:
fieldPath: spec.nodeName
- name: DB_PORT
value: "50000"
volumeMounts:
- mountPath: /opt/data
name: csv-data
- mountPath: /opt/credentials
name: credentials

- name: rabbitmq
image: {{ .Values.exareme2_images.repository }}/exareme2_rabbitmq:{{ .Values.exareme2_images.version }}
env:
Expand Down
Loading

0 comments on commit af0de5e

Please sign in to comment.