Skip to content

Commit

Permalink
Add flower logger similar to the worker
Browse files Browse the repository at this point in the history
  • Loading branch information
ThanKarab committed Jun 16, 2024
1 parent 5a66f1c commit 8354186
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 14 deletions.
23 changes: 23 additions & 0 deletions exareme2/algorithms/flower/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import logging
import os

from flwr.common.logger import FLOWER_LOGGER

for handler in FLOWER_LOGGER.handlers:
FLOWER_LOGGER.removeHandler(handler)

FLOWER_LOGGER.setLevel(logging.DEBUG)

request_id = os.getenv("REQUEST_ID", "NO-REQUEST_ID")
worker_role = os.getenv("WORKER_ROLE", "NO-ROLE")
worker_identifier = os.getenv("WORKER_IDENTIFIER", "NO-IDENTIFIER")

flower_formatter = logging.Formatter(
f"%(asctime)s - %(levelname)s - FLOWER - {worker_role} - {worker_identifier} - %(module)s - %(funcName)s(%(lineno)d) - {request_id} - %(message)s"
)

# Configure console logger
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(flower_formatter)
FLOWER_LOGGER.addHandler(console_handler)
16 changes: 6 additions & 10 deletions exareme2/algorithms/flower/flower_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import pymonetdb
import requests
from flwr.common.logger import FLOWER_LOGGER
from pydantic import BaseModel
from sklearn import preprocessing
from sklearn.impute import SimpleImputer
Expand Down Expand Up @@ -53,7 +54,6 @@ def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:

def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
print(f"Loading data from folder: {data_folder}")
dataframes = [
pd.read_csv(data_folder / f"{dataset}.csv")
for dataset in datasets
Expand Down Expand Up @@ -85,21 +85,21 @@ def preprocess_data(inputdata, full_data):

def error_handling(error):
error_msg = {"error": str(error)}
print(
FLOWER_LOGGER.error(
f"Error will try to save error message: {error_msg}! Running: {RESULT_URL}..."
)
requests.post(RESULT_URL, data=json.dumps(error_msg), headers=HEADERS)


def post_result(result: dict) -> None:
print(f"Running: {RESULT_URL}...")
FLOWER_LOGGER.debug(f"Posting result at: {RESULT_URL} ...")
response = requests.post(RESULT_URL, data=json.dumps(result), headers=HEADERS)
if response.status_code != 200:
error_handling(response.text)


def get_input() -> Inputdata:
print(f"Running: {INPUT_URL}...")
FLOWER_LOGGER.debug(f"Getting inputdata from: {INPUT_URL} ...")
response = requests.get(INPUT_URL)
if response.status_code != 200:
error_handling(response.text)
Expand All @@ -109,7 +109,7 @@ def get_input() -> Inputdata:

def get_enumerations(data_model: str, variable_name: str) -> list:
try:
print(f"Running: {CDES_URL}...")
FLOWER_LOGGER.debug(f"Getting enumerations from: {CDES_URL} ...")
response = requests.get(CDES_URL)
if response.status_code != 200:
error_handling(response.text)
Expand All @@ -126,8 +126,4 @@ def get_enumerations(data_model: str, variable_name: str) -> list:
else:
raise KeyError(f"'enumerations' key not found in {variable_name}")
except (requests.RequestException, KeyError, json.JSONDecodeError) as e:
error_msg = {"error": str(e)}
print(
f"Error will try to save error message: {error_msg}! Running: {RESULT_URL}..."
)
requests.post(RESULT_URL, data=json.dumps(error_msg), headers=HEADERS)
error_handling(str(e))
1 change: 0 additions & 1 deletion exareme2/algorithms/flower/logistic_regression/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def evaluate(server_round, parameters, config):
accuracy = model.score(X_test, y_test)
if server_round == NUM_OF_ROUNDS:
post_result({"accuracy": accuracy})
print({"accuracy": accuracy})
return loss, {"accuracy": accuracy}

return evaluate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import flwr as fl
import numpy as np
from flwr.common.logger import FLOWER_LOGGER
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

Expand Down Expand Up @@ -49,14 +50,13 @@ def fit(self, parameters, config):
]
return_data = (params, len(X_train), {"accuracy": accuracy})
except Exception as e:
print(f"Error during model fitting: {e}")
FLOWER_LOGGER.error(f"Error during model fitting: {e}")
# On error, default to zero-initialized parameters, no training examples, and zero accuracy
zero_params = [
np.zeros_like(param) for param in utils.get_model_parameters(model)
]
return_data = (zero_params, 0, {"accuracy": 0.0})

print(f"Returning from fit: {return_data}")
return return_data

def evaluate(self, parameters, config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def evaluate(server_round, parameters: fl.common.NDArrays, config):
accuracy = model.score(X_test, y_test)
if server_round == NUM_OF_ROUNDS:
post_result({"accuracy": accuracy})
print({"accuracy": accuracy})
return loss, {"accuracy": accuracy}

return evaluate
Expand Down
6 changes: 6 additions & 0 deletions exareme2/worker/flower/starter/starter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def start_flower_client(request_id: str, algorithm_name, server_address) -> int:
"MONETDB_USERNAME": worker_config.monetdb.local_username,
"MONETDB_PASSWORD": worker_config.monetdb.local_password,
"MONETDB_DB": worker_config.monetdb.database,
"REQUEST_ID": request_id,
"WORKER_ROLE": worker_config.role,
"WORKER_IDENTIFIER": worker_config.identifier,
"SERVER_ADDRESS": server_address,
"NUMBER_OF_CLIENTS": worker_config.monetdb.database,
"CONTROLLER_IP": worker_config.controller.ip,
Expand All @@ -32,6 +35,9 @@ def start_flower_server(
request_id: str, algorithm_name: str, number_of_clients: int, server_address
) -> int:
env_vars = {
"REQUEST_ID": request_id,
"WORKER_ROLE": worker_config.role,
"WORKER_IDENTIFIER": worker_config.identifier,
"SERVER_ADDRESS": server_address,
"NUMBER_OF_CLIENTS": number_of_clients,
"CONTROLLER_IP": worker_config.controller.ip,
Expand Down

0 comments on commit 8354186

Please sign in to comment.