diff --git a/pipelines/rj_smtr/utils.py b/pipelines/rj_smtr/utils.py
index 89bd2a3c7..16ed538d3 100644
--- a/pipelines/rj_smtr/utils.py
+++ b/pipelines/rj_smtr/utils.py
@@ -7,8 +7,8 @@
 from ftplib import FTP
 from pathlib import Path
-from datetime import timedelta, datetime
-from typing import List, Union
+from datetime import timedelta, datetime, date
+from typing import List, Union, Any
 import traceback
 import io
 import json
@@ -462,17 +462,40 @@ def dict_contains_keys(input_dict: dict, keys: list[str]) -> bool:
     return all(x in input_dict.keys() for x in keys)
 
 
+def custom_serialization(obj: Any) -> Any:
+    """
+    Function to serialize not JSON serializable objects
+
+    Args:
+        obj (Any): Object to serialize
+
+    Returns:
+        Any: Serialized object
+    """
+    if isinstance(obj, (datetime, date, pd.Timestamp)):
+        if obj.tzinfo is None:
+            obj = obj.tz_localize(emd_constants.DEFAULT_TIMEZONE.value)
+        else:
+            obj = obj.tz_convert(emd_constants.DEFAULT_TIMEZONE.value)
+
+        return obj.isoformat()
+
+    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+
 def save_raw_local_func(
-    data: Union[dict, str], filepath: str, mode: str = "raw", filetype: str = "json"
+    data: Union[dict, str],
+    filepath: str,
+    mode: str = "raw",
+    filetype: str = "json",
 ) -> str:
     """
     Saves json response from API to .json file.
     Args:
+        data (Union[dict, str]): Raw data to save
         filepath (str): Path which to save raw file
-        status (dict): Must contain keys
-            * data: json returned from API
-            * error: error catched from API request
         mode (str, optional): Folder to save locally, later folder which to upload to GCS.
+        filetype (str, optional): The file format
     Returns:
         str: Path to the saved file
     """
@@ -485,10 +508,8 @@ def save_raw_local_func(
         if isinstance(data, str):
             data = json.loads(data)
         with Path(_filepath).open("w", encoding="utf-8") as fi:
-            json.dump(data, fi)
+            json.dump(data, fi, default=custom_serialization)
 
-    # if filetype == "csv":
-    #     pass
     if filetype in ("txt", "csv"):
         with open(_filepath, "w", encoding="utf-8") as file:
             file.write(data)
@@ -630,17 +651,9 @@ def get_raw_data_db(
     Returns:
         tuple[str, str, str]: Error, data and filetype
     """
-    connection_mapping = {
-        "postgresql": {
-            "connector": psycopg2.connect,
-            "port": "5432",
-            "cursor": {"cursor_factory": psycopg2.extras.RealDictCursor},
-        },
-        "mysql": {
-            "connector": pymysql.connect,
-            "port": "3306",
-            "cursor": {"cursor": pymysql.cursors.DictCursor},
-        },
+    connector_mapping = {
+        "postgresql": psycopg2.connect,
+        "mysql": pymysql.connect,
     }
 
     data = None
 
     try:
         credentials = get_vault_secret(secret_path)["data"]
-        connection = connection_mapping[engine]["connector"](
+        with connector_mapping[engine](
             host=host,
             user=credentials["user"],
             password=credentials["password"],
             database=database,
-        )
-
-        with connection:
-            with connection.cursor(**connection_mapping[engine]["cursor"]) as cursor:
-                cursor.execute(query)
-                data = cursor.fetchall()
+        ) as connection:
+            data = pd.read_sql(sql=query, con=connection).to_dict(orient="records")
 
-        data = [dict(d) for d in data]
     except Exception:
         error = traceback.format_exc()
         log(f"[CATCHED] Task failed with error: \n{error}", level="error")
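
For context on the new `custom_serialization` helper: `json.dump`'s `default=` hook is called only for values the stock encoder cannot handle, and whatever it returns is serialized in their place. Below is a minimal, self-contained sketch of that pattern; it is not repo code. `"America/Sao_Paulo"` stands in for `emd_constants.DEFAULT_TIMEZONE.value`, the sample record is invented, and the sketch normalizes inputs to `pd.Timestamp` so the pandas `tz_localize`/`tz_convert` accessors it relies on are available even for plain `datetime`/`date` values.

```python
# Illustrative sketch (not repo code): how json.dump's default= hook renders
# datetime-like values as timezone-aware ISO-8601 strings.
import json
from datetime import datetime, date
from typing import Any

import pandas as pd

# Assumption: mirrors emd_constants.DEFAULT_TIMEZONE.value in the real pipeline.
DEFAULT_TIMEZONE = "America/Sao_Paulo"


def custom_serialization(obj: Any) -> Any:
    """Serialize datetime-like objects; reject everything else the encoder cannot handle."""
    if isinstance(obj, (datetime, date, pd.Timestamp)):
        # Normalize to pd.Timestamp so tz_localize/tz_convert exist for plain
        # datetime/date inputs as well (a deviation from the PR body, which
        # calls those accessors on the object directly).
        ts = pd.Timestamp(obj)
        if ts.tzinfo is None:
            ts = ts.tz_localize(DEFAULT_TIMEZONE)
        else:
            ts = ts.tz_convert(DEFAULT_TIMEZONE)
        return ts.isoformat()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


if __name__ == "__main__":
    record = {
        "captured_at": datetime(2023, 11, 1, 12, 30),               # naive datetime
        "updated_at": pd.Timestamp("2023-11-01 15:00", tz="UTC"),   # tz-aware Timestamp
        "vehicle_id": "A41001",
    }
    # Without default=, json.dumps would raise TypeError on both timestamp values.
    print(json.dumps(record, default=custom_serialization))
```

Run as a script, this prints the record with both timestamps rendered as ISO-8601 strings in the default timezone, which is the behavior `save_raw_local_func` now gets for free via `default=custom_serialization`.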
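
The `get_raw_data_db` change replaces manual cursor handling with `pd.read_sql(...).to_dict(orient="records")`. Here is a runnable sketch of that pattern, using an in-memory SQLite database instead of the PostgreSQL/MySQL connectors so it needs no server; the table, columns, and rows are invented for illustration.

```python
# Illustrative sketch (not repo code): the pd.read_sql(...).to_dict(orient="records")
# pattern the refactor adopts, demonstrated against in-memory SQLite.
import sqlite3

import pandas as pd

with sqlite3.connect(":memory:") as connection:
    # Fabricated table standing in for a real GPS capture source.
    connection.execute("CREATE TABLE gps (vehicle_id TEXT, speed REAL)")
    connection.executemany(
        "INSERT INTO gps VALUES (?, ?)", [("A41001", 32.5), ("B27002", 0.0)]
    )

    # Same shape as the refactored get_raw_data_db: one dict per row.
    data = pd.read_sql(sql="SELECT * FROM gps", con=connection).to_dict(orient="records")

print(data)
# [{'vehicle_id': 'A41001', 'speed': 32.5}, {'vehicle_id': 'B27002', 'speed': 0.0}]
```

The point of the swap is that pandas handles cursor creation, row fetching, and column-name mapping, so both engines share one code path and the result arrives already in the list-of-dicts shape the previous cursor-based code built by hand with `data = [dict(d) for d in data]`.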