diff --git a/dags/modules/convert/ingest_csv_to_iceberg.py b/dags/modules/convert/ingest_csv_to_iceberg.py
index e38d25ce..60938609 100644
--- a/dags/modules/convert/ingest_csv_to_iceberg.py
+++ b/dags/modules/convert/ingest_csv_to_iceberg.py
@@ -3,7 +3,7 @@
 from dags.modules.databases.duckdb import file_csv_to_parquet
 from dags.modules.databases.trino import create_schema, drop_table, get_trino_conn_details, get_trino_engine, hive_create_table_from_parquet, iceberg_create_table_from_hive, validate_identifier, validate_s3_key, get_table_schema_and_max_values
 from dags.modules.utils.rabbit import send_message_to_rabbitmq
-from dags.modules.utils.s3 import s3_create_bucket, s3_delete, s3_download_minio, s3_upload
+from dags.modules.utils.s3 import s3_create_bucket, s3_delete, s3_download_minio, s3_upload, detect_if_secure_endpoint, get_conn_details
 from dags.modules.utils.tracking_timer import tracking_timer, tracking_data, tracking_data_str
 
 import constants
@@ -297,13 +297,13 @@ def ingest_csv_to_iceberg(dataset, tablename, version, label, etag, ingest_bucke
     x=tracking_timer(p_conn, etag, "s_schema")
 
     logging.info("Getting schema from the new PAR file")
-    s3_conn = json.loads(BaseHook.get_connection("s3_conn").get_extra())
+    s3_conn = get_conn_details("s3_conn")
 
     fs = s3fs.S3FileSystem(
         endpoint_url=s3_conn["endpoint_url"],
         key=s3_conn["aws_access_key_id"],
         secret=s3_conn["aws_secret_access_key"],
-        use_ssl=True
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"])
     )
 
     with fs.open(F"s3://{hive_bucket}/{hive_key}", "rb") as fp:
diff --git a/dags/modules/databases/duckdb.py b/dags/modules/databases/duckdb.py
index a4a40798..efd5629f 100644
--- a/dags/modules/databases/duckdb.py
+++ b/dags/modules/databases/duckdb.py
@@ -4,7 +4,7 @@
 import duckdb
 
 from airflow.hooks.base import BaseHook
 
-from ..utils.s3 import validate_s3_key, get_conn_details
+from ..utils.s3 import validate_s3_key, get_conn_details, detect_if_secure_endpoint
 
 logger = logging.getLogger(__name__)
@@ -20,6 +20,7 @@ def s3_csv_to_parquet(conn_id: str, src_bucket: str, dst_bucket: str, src_key: s
     access_key_id = s3_conn['aws_access_key_id']
     secret_access_key = s3_conn['aws_secret_access_key']
     endpoint = s3_conn["endpoint_url"].replace("http://", "").replace("https://", "")
+    is_secure = detect_if_secure_endpoint(s3_conn["endpoint_url"])
 
     # original duckdb
     # con = duckdb.connect(database=':memory:')
@@ -33,7 +34,7 @@ def s3_csv_to_parquet(conn_id: str, src_bucket: str, dst_bucket: str, src_key: s
             f"SET s3_endpoint='{endpoint}';" \
             f"SET s3_access_key_id='{access_key_id}';" \
             f"SET s3_secret_access_key='{secret_access_key}';" \
-            f"SET s3_use_ssl=True;" \
+            f"SET s3_use_ssl={is_secure};" \
            f"SET preserve_insertion_order=False;" \
             f"SET s3_url_style='path';" \
             f"SET memory_limit='{memory}GB'"
diff --git a/dags/modules/databases/trino.py b/dags/modules/databases/trino.py
index dca4b6ba..352735a0 100644
--- a/dags/modules/databases/trino.py
+++ b/dags/modules/databases/trino.py
@@ -50,6 +50,10 @@ def get_trino_engine(trino_conn_details: dict) -> sqlalchemy.engine.Engine:
     host = trino_conn_details['host']
     port = trino_conn_details['port']
     database = trino_conn_details['database']
+    connect_protocol = "http"
+    if port == 443:
+        connect_protocol = "https"
+
 
     logger.info(f"username={username}")
     logger.info(f"host={host}")
@@ -59,7 +63,7 @@ def get_trino_engine(trino_conn_details: dict) -> sqlalchemy.engine.Engine:
     engine = create_engine(
         f"trino://{username}@{host}:{port}/{database}",
         connect_args={
-            "http_scheme": "http",
"http_scheme": connect_protocol, # TODO This needs to be set to true when deploying to anything thats not dev "verify": False }, @@ -165,31 +169,30 @@ def drop_table(trino: sqlalchemy.engine.Engine, table: str): trino.execute(query) def get_schema(engine, table_name): - query = f"SHOW COLUMNS FROM {table_name}" + query = f"SHOW COLUMNS FROM {table_name}" + try: + result = engine.execute(text(query)) + schema = {row['Column']: row['Type'] for row in result} + return schema + except SQLAlchemyError as e: + print(f"Error executing query: {e}") + return None + +def get_max_values(engine, table_name, schema): + max_values = {} + for column in schema.keys(): + query = f"SELECT MAX({column}) as max_value FROM {table_name}" + if schema[column] == "varchar": + query = f"SELECT MAX(length({column})) as max_length FROM {table_name}" try: - result = engine.execute(text(query)) - schema = {row['Column']: row['Type'] for row in result} - return schema + result = engine.execute(text(query)).scalar() + max_values[column] = result except SQLAlchemyError as e: print(f"Error executing query: {e}") - return None - -def get_max_values(engine, table_name, schema): - max_values = {} - for column in schema.keys(): - query = f"SELECT MAX({column}) as max_value FROM {table_name}" - if schema[column] == "varchar": - query = f"SELECT MAX(length({column})) as max_length FROM {table_name}" - try: - result = engine.execute(text(query)).scalar() - max_values[column] = result - except SQLAlchemyError as e: - print(f"Error executing query: {e}") - max_values[column] = None - return max_values + max_values[column] = None + return max_values def get_table_schema_and_max_values(trino: sqlalchemy.engine.Engine, table_name ): - # Reflect the table from the database schema = get_schema(trino,table_name) diff --git a/dags/modules/utils/s3.py b/dags/modules/utils/s3.py index 434d67b5..093db33b 100644 --- a/dags/modules/utils/s3.py +++ b/dags/modules/utils/s3.py @@ -3,13 +3,23 @@ import logging import boto3 from minio import Minio -#from minio.error import ResponseError +# from minio.error import ResponseError import s3fs from airflow.hooks.base import BaseHook from botocore.config import Config logger = logging.getLogger(__name__) + +def detect_if_secure_endpoint(endpoint_url: str) -> bool: + https_re = r'^https://' + + if re.match(https_re, endpoint_url): + return True + + return False + + def get_conn_details(conn_id) -> dict: """Gets S3 connection info from Airflow connection with given connection id @@ -23,9 +33,11 @@ def get_conn_details(conn_id) -> dict: logger.critical("There is no Airflow connection configured with the name {0}.".format(conn_id)) raise e # Transform connection object into dictionary to make it easier to share between modules - s3_conn = {"endpoint_url": conn_extra["endpoint_url"], "aws_access_key_id": conn.login, "aws_secret_access_key": conn.password} + s3_conn = {"endpoint_url": conn_extra["endpoint_url"], "aws_access_key_id": conn.login, + "aws_secret_access_key": conn.password} return s3_conn + def validate_s3_key(key): # Validate the s3 key is strictly one or more slash separated keys logging.info(f"Validate key: {key}") @@ -51,19 +63,17 @@ def s3_get_resource(conn_id: str): def s3_get_fs(conn_id): - s3_conn = get_conn_details(conn_id) return s3fs.S3FileSystem( endpoint_url=s3_conn["endpoint_url"], key=s3_conn["aws_access_key_id"], secret=s3_conn["aws_secret_access_key"], - use_ssl=True + use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"]) ) def s3_copy(conn_id: str, src_bucket, dst_bucket, 
-
     s3 = s3_get_resource(conn_id)
 
     s3.Bucket(dst_bucket).copy({'Bucket': src_bucket, 'Key': src_key}, dst_key)
@@ -73,7 +83,6 @@ def s3_copy(conn_id: str, src_bucket, dst_bucket, src_key, dst_key, move=False):
 
 
 def s3_delete(conn_id: str, bucket, key):
-
     s3_get_resource(conn_id).Object(bucket, key).delete()
 
 
@@ -84,22 +93,19 @@ def s3_create_bucket(conn_id: str, bucket):
         print(f"An error occurred while creating the S3 bucket: {e}")
 
 
-
 def s3_download_minio(conn_id, bucket_name, object_name, local_file_path):
-
     s3_conn = get_conn_details(conn_id)
 
-    url = str(s3_conn["endpoint_url"]).replace('http://','').replace('https://','')
+    url = str(s3_conn["endpoint_url"]).replace('http://', '').replace('https://', '')
 
     client = Minio(url,
-        access_key=s3_conn["aws_access_key_id"],
-        secret_key=s3_conn["aws_secret_access_key"],
-        secure=True)
-
-
-    #client.fget_object(bucket_name, object_name, local_file_path)
-
-    #try:
+                   access_key=s3_conn["aws_access_key_id"],
+                   secret_key=s3_conn["aws_secret_access_key"],
+                   secure=detect_if_secure_endpoint(s3_conn["endpoint_url"]))
+
+    # client.fget_object(bucket_name, object_name, local_file_path)
+
+    # try:
     # Start a multipart download
     response = client.get_object(
         bucket_name,
@@ -115,12 +121,11 @@ def s3_download_minio(conn_id, bucket_name, object_name, local_file_path):
 
     print("Download successful!")
 
-    #except ResponseError as err:
+    # except ResponseError as err:
     #     print(err)
 
 
 def s3_download(conn_id, bucket_name, object_name, local_file_path):
-
     s3_conn = get_conn_details(conn_id)
 
     client = boto3.client(
@@ -128,14 +133,13 @@ def s3_download(conn_id, bucket_name, object_name, local_file_path):
         endpoint_url=s3_conn["endpoint_url"],
         aws_access_key_id=s3_conn["aws_access_key_id"],
         aws_secret_access_key=s3_conn["aws_secret_access_key"],
-        use_ssl=True )
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"]))
 
     with open(local_file_path, 'wb') as f:
         client.download_fileobj(bucket_name, object_name, f)
 
 
 def s3_upload(conn_id, src_file, bucket, object_name):
-
     s3_conn = get_conn_details(conn_id)
 
     client = boto3.client(
@@ -143,7 +147,6 @@ def s3_upload(conn_id, src_file, bucket, object_name):
         endpoint_url=s3_conn["endpoint_url"],
         aws_access_key_id=s3_conn["aws_access_key_id"],
         aws_secret_access_key=s3_conn["aws_secret_access_key"],
-        use_ssl=True )
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"]))
 
-    client.upload_file(src_file,bucket,object_name)
-    
\ No newline at end of file
+    client.upload_file(src_file, bucket, object_name)
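
Note for reviewers: each hard-coded `use_ssl=True` / `secure=True` / `SET s3_use_ssl=True` above is replaced by the scheme check in the new `detect_if_secure_endpoint` helper, so SSL is only used when the configured endpoint is actually served over HTTPS. A minimal standalone sketch of the intended behaviour (a condensed equivalent of the helper, with made-up endpoint URLs used purely for illustration):

```python
import re


def detect_if_secure_endpoint(endpoint_url: str) -> bool:
    # Condensed equivalent of the helper added in dags/modules/utils/s3.py:
    # the endpoint is treated as secure only when its scheme is https.
    return re.match(r'^https://', endpoint_url) is not None


# Hypothetical endpoints, shown only to illustrate the expected results.
assert detect_if_secure_endpoint("https://minio.example.internal:9000") is True
assert detect_if_secure_endpoint("http://minio.example.internal:9000") is False
```

The Trino change follows the same idea but keys off the port rather than the URL: `http_scheme` defaults to "http" and switches to "https" only when the connection port is 443.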