diff --git a/exareme2/algorithms/flower/logistic_regression/client.py b/exareme2/algorithms/flower/logistic_regression/client.py index 72620d9d4..292b73215 100644 --- a/exareme2/algorithms/flower/logistic_regression/client.py +++ b/exareme2/algorithms/flower/logistic_regression/client.py @@ -1,7 +1,10 @@ import os +import time import warnings +from math import log2 import flwr as fl +from flwr.common.logger import FLOWER_LOGGER from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss from utils import get_model_parameters @@ -44,6 +47,22 @@ def evaluate(self, parameters, config): set_initial_params(model, X_train, full_data, inputdata) client = LogisticRegressionClient(model, X_train, y_train) - fl.client.start_client( - server_address=os.environ["SERVER_ADDRESS"], client=client.to_client() - ) + + attempts = 0 + max_attempts = int(log2(int(os.environ["TIMEOUT"]))) + while True: + try: + fl.client.start_client( + server_address=os.environ["SERVER_ADDRESS"], client=client.to_client() + ) + FLOWER_LOGGER.debug("Connection successful on attempt", attempts + 1) + break + except Exception as e: + FLOWER_LOGGER.warning( + f"Connection with the server failed. Attempt {attempts} failed: {e}" + ) + time.sleep(pow(2, attempts)) + attempts += 1 + if attempts >= max_attempts: + FLOWER_LOGGER.error("Could not establish connection to the server.") + raise e diff --git a/exareme2/controller/celery/tasks_handler.py b/exareme2/controller/celery/tasks_handler.py index 3c4eda394..33e5bb0a4 100644 --- a/exareme2/controller/celery/tasks_handler.py +++ b/exareme2/controller/celery/tasks_handler.py @@ -298,13 +298,14 @@ def queue_healthcheck_task( ) def start_flower_client( - self, request_id, algorithm_name, server_address + self, request_id, algorithm_name, server_address, execution_timeout ) -> WorkerTaskResult: return self._queue_task( task_signature=TASK_SIGNATURES["start_flower_client"], request_id=request_id, algorithm_name=algorithm_name, server_address=server_address, + execution_timeout=execution_timeout, ) def start_flower_server( diff --git a/exareme2/controller/services/flower/controller.py b/exareme2/controller/services/flower/controller.py index fcfe9cb8f..d60a202e1 100644 --- a/exareme2/controller/services/flower/controller.py +++ b/exareme2/controller/services/flower/controller.py @@ -1,6 +1,7 @@ import asyncio from typing import List +from exareme2.controller import config as ctrl_config from exareme2.controller import logger as ctrl_logger from exareme2.controller.federation_info_logs import log_experiment_execution from exareme2.controller.services.flower.tasks_handler import FlowerTasksHandler @@ -87,7 +88,9 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto): ) clients_pids = { handler.start_flower_client( - algorithm_name, str(server_address) + algorithm_name, + str(server_address), + ctrl_config.flower_execution_timeout, ): handler for handler in task_handlers } diff --git a/exareme2/controller/services/flower/tasks_handler.py b/exareme2/controller/services/flower/tasks_handler.py index ac731dff2..75f053a85 100644 --- a/exareme2/controller/services/flower/tasks_handler.py +++ b/exareme2/controller/services/flower/tasks_handler.py @@ -29,9 +29,11 @@ def worker_id(self) -> str: def worker_data_address(self) -> str: return self._db_address - def start_flower_client(self, algorithm_name, server_address) -> int: + def start_flower_client( + self, algorithm_name, server_address, execution_timeout + ) -> int: return self._worker_tasks_handler.start_flower_client( - self._request_id, algorithm_name, server_address + self._request_id, algorithm_name, server_address, execution_timeout ).get(timeout=self._tasks_timeout) def start_flower_server( diff --git a/exareme2/worker/flower/starter/starter_api.py b/exareme2/worker/flower/starter/starter_api.py index 8a85ba561..fe710b6db 100644 --- a/exareme2/worker/flower/starter/starter_api.py +++ b/exareme2/worker/flower/starter/starter_api.py @@ -4,9 +4,11 @@ @shared_task -def start_flower_client(request_id: str, algorithm_name, server_address) -> int: +def start_flower_client( + request_id: str, algorithm_name, server_address, execution_timeout +) -> int: return starter_service.start_flower_client( - request_id, algorithm_name, server_address + request_id, algorithm_name, server_address, execution_timeout ) diff --git a/exareme2/worker/flower/starter/starter_service.py b/exareme2/worker/flower/starter/starter_service.py index 70cdbabff..387a94235 100644 --- a/exareme2/worker/flower/starter/starter_service.py +++ b/exareme2/worker/flower/starter/starter_service.py @@ -5,7 +5,9 @@ @initialise_logger -def start_flower_client(request_id: str, algorithm_name, server_address) -> int: +def start_flower_client( + request_id: str, algorithm_name, server_address, execution_timeout +) -> int: env_vars = { "MONETDB_IP": worker_config.monetdb.ip, "MONETDB_PORT": worker_config.monetdb.port, @@ -20,6 +22,7 @@ def start_flower_client(request_id: str, algorithm_name, server_address) -> int: "CONTROLLER_IP": worker_config.controller.ip, "CONTROLLER_PORT": worker_config.controller.port, "DATA_PATH": worker_config.data_path, + "TIMEOUT": execution_timeout, } process = FlowerProcess(f"{algorithm_name}/client.py", env_vars=env_vars) logger = get_logger()