diff --git a/dags/ingest_csv_to_iceberg.py b/dags/ingest_csv_to_iceberg.py
new file mode 100644
index 00000000..e7fcd23c
--- /dev/null
+++ b/dags/ingest_csv_to_iceberg.py
@@ -0,0 +1,194 @@
+import os.path
+import logging
+import pendulum
+from airflow import DAG
+from airflow.operators.python import get_current_context, task
+
+from modules.databases.duckdb import s3_csv_to_parquet
+from modules.utils.s3 import s3_delete
+from modules.utils.sql import escape_dataset
+from modules.utils.sha1 import sha1
+from modules.databases.trino import (
+    create_schema,
+    drop_table,
+    get_trino_conn_details,
+    get_trino_engine,
+    hive_create_table_from_parquet,
+    iceberg_create_table_from_hive,
+    validate_identifier,
+    validate_s3_key
+)
+
+
+with DAG(
+    dag_id="ingest_csv_to_iceberg",
+    schedule=None,
+    start_date=pendulum.datetime(1900, 1, 1, tz="UTC"),
+    catchup=True,
+    max_active_runs=1,
+    concurrency=1,
+    tags=["ingest", "csv", "iceberg", "s3"],
+) as dag:
+
+    # Makes this logging namespace appear immediately in airflow
+    logging.info("DAG parsing...")
+
+    @task
+    def ingest_csv_to_iceberg():
+
+        ########################################################################
+        logging.info("Starting task ingest_csv_to_iceberg...")
+
+        # Extract the Airflow context object for this run of the DAG
+        context = get_current_context()
+        logging.info(f"context={context}")
+
+        # Extract the task instance object for handling XCOM variables
+        ti = context['ti']
+        logging.info(f"ti={ti}")
+
+        # Extract the JSON dict of params for this run of the DAG
+        conf = context['dag_run'].conf
+        logging.info(f"conf={conf}")
+
+        # Unique hashed name for this run of the DAG
+        dag_hash = sha1(
+            f"dag_id={ti.dag_id}/"
+            f"run_id={ti.run_id}/"
+            f"task_id={ti.task_id}"
+        )
+        logging.info(f"dag_hash={dag_hash}")
+
+        dag_id = f"{dag_hash}_{ti.try_number}"
+        logging.info(f"dag_id={dag_id}")
+
+        ########################################################################
+        logging.info("Validate inputs...")
+
+        debug = conf.get("debug", False)
+        logging.info(f"debug={debug}")
+        assert isinstance(debug, bool)
+
+        # Path to the data file within the ingest bucket excluding the bucket name
+        ingest_key = conf.get("ingest_key", None)
+        logging.info(f"ingest_key={ingest_key}")
+        assert (ingest_key is not None) and \
+            isinstance(ingest_key, str) and \
+            ingest_key.endswith(".csv")
+
+        ingest_path = os.path.dirname(ingest_key)
+        ingest_file = os.path.basename(ingest_key)
+        ingest_bucket = "ingest"
+
+        ingest_delete = conf.get("ingest_delete", False)
+        logging.info(f"ingest_delete={ingest_delete}")
+        assert isinstance(ingest_delete, bool)
+
+        # Base name of the dataset to provision, defaults to an escaped version of the path
+        dataset = escape_dataset(conf.get("dataset", ingest_path))
+
+        ingest = {
+            "bucket": ingest_bucket,
+            "key": ingest_key,
+            "path": ingest_path,
+            "file": ingest_file,
+            "dataset": dataset,
+            "delete": ingest_delete,
+        }
+        logging.info(f"ingest={ingest}")
+        ti.xcom_push("ingest", ingest)
+
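+        # Staging area for the CSV to Parquet conversion: DuckDB writes the
+        # Parquet file to hive_key in the loading bucket, and the Hive external
+        # table is created over hive_path so Trino can read it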
ti.xcom_push("hive", hive) + + iceberg_schema = "iceberg.ingest" + iceberg_table = validate_identifier(f"{iceberg_schema}.{dataset}_{dag_id}") + iceberg_bucket = "working" + iceberg_dir = validate_s3_key(f"ingest/{dataset}/{dag_id}") + iceberg_path = validate_s3_key(f"{iceberg_bucket}/{iceberg_dir}") + iceberg = { + "schema": iceberg_schema, + "table": iceberg_table, + "bucket": iceberg_bucket, + "dir": iceberg_dir, + "path": iceberg_path, + } + logging.info(f"iceberg={iceberg}") + ti.xcom_push("iceberg", iceberg) + + ######################################################################## + logging.info("Convert from ingest bucket CSV to loading bucket Parquet using DuckDB...") + s3_csv_to_parquet( + conn_id="s3_conn", + src_bucket=ingest_bucket, + src_key=ingest_key, + dst_bucket=hive_bucket, + dst_key=hive_key + ) + + if ingest_delete: + s3_delete(conn_id="s3_conn", bucket=ingest_bucket, key=ingest_key) + + ######################################################################## + logging.info("Mounting CSV on s3 into Hive connector and copy to Iceberg...") + + # Create a connection to Trino + trino_conn = get_trino_conn_details() + trino = get_trino_engine(trino_conn) + + logging.info("Create schema in Hive connector...") + create_schema(trino, schema=hive_schema, location=hive_bucket) + + logging.info("Create schema in Iceberg connector...") + create_schema(trino, schema=iceberg_schema, location=iceberg_bucket) + + try: + logging.info("Create table in Hive connector...") + hive_create_table_from_parquet( + trino, + table=hive_table, + location=hive_path + ) + + logging.info("Create table in Iceberg connector...") + iceberg_create_table_from_hive( + trino, + table=iceberg_table, + hive_table=hive_table, + location=iceberg_path + ) + + finally: + if debug: + logging.info("Debug mode, not cleaning up table in Hive connector...") + else: + logging.info("Cleanup table in Hive connector...") + drop_table(trino, table=hive_table) + + logging.info("Cleanup data from Hive connector in s3...") + # External location data is not cascade deleted on drop table + s3_delete( + conn_id="s3_conn", + bucket=hive_bucket, + key=hive_key + ) + + ######################################################################## + + ingest_csv_to_iceberg() diff --git a/dags/ingest_csv_to_parquet.py b/dags/ingest_csv_to_parquet.py index 94074ae9..5a173955 100644 --- a/dags/ingest_csv_to_parquet.py +++ b/dags/ingest_csv_to_parquet.py @@ -38,10 +38,11 @@ def unpack_minio_event(message): bucket = records["s3"]["bucket"]["name"] etag = s3_object["eTag"] - src_file_path: str = s3_object["key"] + src_file_path: str = s3_object["key"].replace('%2F', '/') assert src_file_path.endswith(".csv") - file_name = src_file_path.replace(".csv", "").replace('%2F', '_') + file_name = src_file_path.replace(".csv", "") + dir_name = src_file_path.split("/")[0] full_file_path = message_json['Key'] head_path = '/'.join(full_file_path.split('/')[:-1]) @@ -52,6 +53,7 @@ def unpack_minio_event(message): src_file_path=src_file_path, etag=etag, file_name=file_name, + dir_name=dir_name, full_file_path=full_file_path, head_path=head_path ) @@ -73,6 +75,9 @@ def sha1(value): tags=["ingest", "csv", "parquet", "s3"], ) as dag: + # Makes this logging namespace appear immediately in airflow + logger.info("DAG parsing...") + def process_event(message): logger.info("Processing message!") logger.info(f"message={message}") @@ -125,7 +130,7 @@ def process_event(message): ) logger.info(f"hive table schema={hive_schema_str}") - table_name = 
re.sub(r"[^a-zA-Z0-9]", '_', event['file_name']).strip().strip('_').strip() + table_name = re.sub(r"[^a-zA-Z0-9]", '_', event['dir_name']).strip().strip('_').strip() logger.info(f"table name={table_name}") hive_table_name = f"{table_name}_{dag_hash}_{ti.try_number}" diff --git a/dags/modules/databases/duckdb.py b/dags/modules/databases/duckdb.py new file mode 100644 index 00000000..e57c29c6 --- /dev/null +++ b/dags/modules/databases/duckdb.py @@ -0,0 +1,42 @@ +import os +import json +import logging +import duckdb +from airflow.hooks.base import BaseHook + +from ..utils.s3 import validate_s3_key + +logger = logging.getLogger(__name__) + + +def s3_csv_to_parquet(conn_id: str, src_bucket: str, dst_bucket: str, src_key: str, dst_key: str, memory: int = 500): + + assert src_key.lower().endswith(".csv") + assert dst_key.lower().endswith(".parquet") + assert validate_s3_key(os.path.dirname(src_key)) + assert validate_s3_key(os.path.dirname(dst_key)) + + s3_conn = json.loads(BaseHook.get_connection(conn_id).get_extra()) + access_key_id = s3_conn['aws_access_key_id'] + secret_access_key = s3_conn['aws_secret_access_key'] + endpoint = s3_conn["endpoint_url"]\ + .replace("http://", "").replace("https://", "") + + con = duckdb.connect(database=':memory:') + + query = f"INSTALL '/opt/duckdb/httpfs.duckdb_extension';" \ + f"LOAD httpfs;" \ + f"SET s3_endpoint='{endpoint}';" \ + f"SET s3_access_key_id='{access_key_id}';" \ + f"SET s3_secret_access_key='{secret_access_key}';" \ + f"SET s3_use_ssl=False;" \ + f"SET s3_url_style='path';" \ + f"SET memory_limit='{memory}MB'" + logger.info(f"query={query}") + con.execute(query) + + query = f"COPY (SELECT * FROM 's3://{src_bucket}/{src_key}')" \ + f"TO 's3://{dst_bucket}/{dst_key}'" \ + f"(FORMAT PARQUET, CODEC 'SNAPPY', ROW_GROUP_SIZE 100000);" + logger.info(f"query={query}") + con.execute(query) diff --git a/dags/modules/databases/trino.py b/dags/modules/databases/trino.py index a1d24b84..7669e063 100644 --- a/dags/modules/databases/trino.py +++ b/dags/modules/databases/trino.py @@ -1,21 +1,9 @@ -"""Module controlling different methods to access and use trino. - -This module contains the methods to that will allow us to preform different actions on different databases that -are being used in production. 
+
+    assert src_key.lower().endswith(".csv")
+    assert dst_key.lower().endswith(".parquet")
+    assert validate_s3_key(os.path.dirname(src_key))
+    assert validate_s3_key(os.path.dirname(dst_key))
+
+    s3_conn = json.loads(BaseHook.get_connection(conn_id).get_extra())
+    access_key_id = s3_conn['aws_access_key_id']
+    secret_access_key = s3_conn['aws_secret_access_key']
+    endpoint = s3_conn["endpoint_url"]\
+        .replace("http://", "").replace("https://", "")
+
+    con = duckdb.connect(database=':memory:')
+
+    query = f"INSTALL '/opt/duckdb/httpfs.duckdb_extension';" \
+            f"LOAD httpfs;" \
+            f"SET s3_endpoint='{endpoint}';" \
+            f"SET s3_access_key_id='{access_key_id}';" \
+            f"SET s3_secret_access_key='{secret_access_key}';" \
+            f"SET s3_use_ssl=False;" \
+            f"SET s3_url_style='path';" \
+            f"SET memory_limit='{memory}MB'"
+    logger.info(f"query={query}")
+    con.execute(query)
+
+    query = f"COPY (SELECT * FROM 's3://{src_bucket}/{src_key}')" \
+            f"TO 's3://{dst_bucket}/{dst_key}'" \
+            f"(FORMAT PARQUET, CODEC 'SNAPPY', ROW_GROUP_SIZE 100000);"
+    logger.info(f"query={query}")
+    con.execute(query)
diff --git a/dags/modules/databases/trino.py b/dags/modules/databases/trino.py
index a1d24b84..7669e063 100644
--- a/dags/modules/databases/trino.py
+++ b/dags/modules/databases/trino.py
@@ -1,21 +1,9 @@
-"""Module controlling different methods to access and use trino.
-
-This module contains the methods to that will allow us to preform different actions on different databases that
-are being used in production.
-
-This file can be imported as a module and contains the following functions:
-
-    * get_trino_conn_details
-    * get_trino_engine
-    * trino_execute_query
-    * trino_copy_table_to_iceberg
-    * trino_create_schema
-    * trino_create_table_from_external_parquet_file
-    * trino_insert_values
-"""
 import logging
 
 import sqlalchemy.engine
 
+from ..utils.s3 import validate_s3_key
+from ..utils.sql import validate_column, validate_identifier
+
 logger = logging.getLogger(__name__)
 
@@ -61,10 +49,10 @@ def get_trino_engine(trino_conn_details: dict) -> sqlalchemy.engine.Engine:
     port = trino_conn_details['port']
     database = trino_conn_details['database']
 
-    logger.debug(f"username={username}")
-    logger.debug(f"host={host}")
-    logger.debug(f"port={port}")
-    logger.debug(f"database={database}")
+    logger.info(f"username={username}")
+    logger.info(f"host={host}")
+    logger.info(f"port={port}")
+    logger.info(f"database={database}")
 
     engine = create_engine(
         f"trino://{username}@{host}:{port}/{database}",
@@ -72,16 +60,79 @@ def get_trino_engine(trino_conn_details: dict) -> sqlalchemy.engine.Engine:
             "http_scheme": "https",
             # TODO This needs to be set to true when deploying to anything thats not dev
             "verify": False
-        }
+        },
+        echo=True
     )
 
     return engine
 
 
-def trino_execute_query(engine: sqlalchemy.engine.Engine, query: str) -> None:
+def trino_execute_query(engine: sqlalchemy.engine.Engine, query: str, **kwargs) -> None:
     """Executes SQL based on a provided query.
 
     :param query: String for the SQL query that will be executed
     :param engine: SqlAlchemy engine object for communicating to a database
     """
-    engine.execute(query)
-    logger.info("Trino query successfully executed")
+    try:
+        logger.info("Trino query executing...")
+        logger.info(f"query={query}")
+        engine.execute(query, **kwargs)
+        logger.info("Trino query success!")
+
+    except Exception as ex:
+        logger.exception("Trino query encountered an error!", exc_info=ex)
+        raise ex
+
+
+def create_schema(trino: sqlalchemy.engine.Engine, schema: str, location: str):
+    query = f"CREATE SCHEMA IF NOT EXISTS " \
+            f"{validate_identifier(schema)} " \
+            f"WITH (" \
+            f"location='s3a://{validate_s3_key(location)}/'" \
+            f")"
+    trino.execute(query)
+
+
+def drop_schema(trino: sqlalchemy.engine.Engine, schema: str):
+    query = f"DROP SCHEMA " \
+            f"{validate_identifier(schema)}"
+    trino.execute(query)
+
+
+def hive_create_table_from_csv(trino: sqlalchemy.engine.Engine, table: str, columns: list, location: str):
+    schema = ", ".join(map(lambda col: f"{validate_column(col)} VARCHAR", columns))
+    query = f"CREATE TABLE " \
+            f"{validate_identifier(table)} ({schema}) " \
+            f"WITH (" \
+            f"external_location='s3a://{validate_s3_key(location)}', " \
+            f"skip_header_line_count=1, " \
+            f"format='CSV'" \
+            f")"
+    trino.execute(query)
+
+
+def hive_create_table_from_parquet(trino: sqlalchemy.engine.Engine, table: str, location: str):
+    query = f"CREATE TABLE " \
+            f"{validate_identifier(table)} " \
+            f"WITH (" \
+            f"external_location='s3a://{validate_s3_key(location)}', " \
+            f"format='PARQUET'" \
+            f")"
+    trino.execute(query)
+
+
+def iceberg_create_table_from_hive(trino: sqlalchemy.engine.Engine, table: str, hive_table: str, location: str):
+
+    query = f"CREATE TABLE " \
+            f"{validate_identifier(table)} " \
+            f"WITH (" \
+            f"location='s3a://{validate_s3_key(location)}/', " \
+            f"format='PARQUET'" \
+            f") " \
+            f"AS SELECT * FROM {validate_identifier(hive_table)}"
+    trino.execute(query)
+
+
+def drop_table(trino: sqlalchemy.engine.Engine, table: str):
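+    # Dropping a Hive external table does not delete the data at its
+    # external_location; callers are responsible for cleaning up the
+    # underlying S3 objects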
f"{validate_identifier(table)}" + trino.execute(query) diff --git a/dags/modules/providers/operators/rabbitmq.py b/dags/modules/providers/operators/rabbitmq.py index 24807200..492a4240 100644 --- a/dags/modules/providers/operators/rabbitmq.py +++ b/dags/modules/providers/operators/rabbitmq.py @@ -56,7 +56,7 @@ def execute(self, context, event: dict[str, Any] | None = None): if event["status"] == "success": logger.info("Consumed message!") - logger.debug(f"message={event['message']}") + logger.info(f"message={event['message']}") # Process the deferred message! last_message_utc = timezone.utcnow() @@ -76,7 +76,7 @@ def execute(self, context, event: dict[str, Any] | None = None): message = self.hook.pull(self.queue_name) if message is not None: logger.info("Consumed message!") - logger.debug(f"message={message}") + logger.info(f"message={message}") # Process the message! last_message_utc = timezone.utcnow() diff --git a/dags/modules/utils/__init__.py b/dags/modules/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dags/modules/utils/s3.py b/dags/modules/utils/s3.py new file mode 100644 index 00000000..5a0568b4 --- /dev/null +++ b/dags/modules/utils/s3.py @@ -0,0 +1,59 @@ +import re +import json +import logging +import boto3 +import s3fs +from airflow.hooks.base import BaseHook +from botocore.config import Config + +logger = logging.getLogger(__name__) + + +def validate_s3_key(key): + # Validate the s3 key is strictly one or more slash separated keys + assert re.match( + r'^(?:[a-z0-9\-_]+)(?:/(?:[a-z0-9\-_]+))*$', + key, + flags=re.IGNORECASE + ) + + return key + + +def s3_get_resource(conn_id: str): + s3_conn = json.loads(BaseHook.get_connection(conn_id).get_extra()) + + return boto3.resource( + 's3', + aws_access_key_id=s3_conn["aws_access_key_id"], + aws_secret_access_key=s3_conn["aws_secret_access_key"], + endpoint_url=s3_conn["endpoint_url"], + config=Config(signature_version='s3v4') + ) + + +def s3_get_fs(conn_id): + + s3_conn = json.loads(BaseHook.get_connection(conn_id).get_extra()) + + return s3fs.S3FileSystem( + endpoint_url=s3_conn["endpoint_url"], + key=s3_conn["aws_access_key_id"], + secret=s3_conn["aws_secret_access_key"], + use_ssl=False + ) + + +def s3_copy(conn_id: str, src_bucket, dst_bucket, src_key, dst_key, move=False): + + s3 = s3_get_resource(conn_id) + + s3.Bucket(dst_bucket).copy({'Bucket': src_bucket, 'Key': src_key}, dst_key) + + if move: + s3.Object(src_bucket, src_key).delete() + + +def s3_delete(conn_id: str, bucket, key): + + s3_get_resource(conn_id).Object(bucket, key).delete() diff --git a/dags/modules/utils/sha1.py b/dags/modules/utils/sha1.py new file mode 100644 index 00000000..01d406cd --- /dev/null +++ b/dags/modules/utils/sha1.py @@ -0,0 +1,7 @@ +import hashlib + + +def sha1(value): + sha_1 = hashlib.sha1() + sha_1.update(str(value).encode('utf-8')) + return sha_1.hexdigest() diff --git a/dags/modules/utils/sql.py b/dags/modules/utils/sql.py new file mode 100644 index 00000000..88c5e3e6 --- /dev/null +++ b/dags/modules/utils/sql.py @@ -0,0 +1,27 @@ +import re + + +def escape_column(column): + return re.sub(r'[^a-zA-Z0-9]+', '_', column).strip("_") + + +def validate_column(column): + assert column == escape_column(column) + return column + + +def escape_dataset(dataset): + # TODO make this more sensible + return escape_column(dataset) + + +def validate_identifier(identifier): + # Validate the identifier is strictly one or more dot separated identifiers + assert re.match( + r'^(?:[a-z](?:[_a-z0-9]*[a-z0-9]|[a-z0-9]*)' + 
+
+    s3 = s3_get_resource(conn_id)
+
+    s3.Bucket(dst_bucket).copy({'Bucket': src_bucket, 'Key': src_key}, dst_key)
+
+    if move:
+        s3.Object(src_bucket, src_key).delete()
+
+
+def s3_delete(conn_id: str, bucket, key):
+
+    s3_get_resource(conn_id).Object(bucket, key).delete()
diff --git a/dags/modules/utils/sha1.py b/dags/modules/utils/sha1.py
new file mode 100644
index 00000000..01d406cd
--- /dev/null
+++ b/dags/modules/utils/sha1.py
@@ -0,0 +1,7 @@
+import hashlib
+
+
+def sha1(value):
+    sha_1 = hashlib.sha1()
+    sha_1.update(str(value).encode('utf-8'))
+    return sha_1.hexdigest()
diff --git a/dags/modules/utils/sql.py b/dags/modules/utils/sql.py
new file mode 100644
index 00000000..88c5e3e6
--- /dev/null
+++ b/dags/modules/utils/sql.py
@@ -0,0 +1,27 @@
+import re
+
+
+def escape_column(column):
+    return re.sub(r'[^a-zA-Z0-9]+', '_', column).strip("_")
+
+
+def validate_column(column):
+    assert column == escape_column(column)
+    return column
+
+
+def escape_dataset(dataset):
+    # TODO make this more sensible
+    return escape_column(dataset)
+
+
+def validate_identifier(identifier):
+    # Validate the identifier is strictly one or more dot separated identifiers
+    assert re.match(
+        r'^[a-z](?:[_a-z0-9]*[a-z0-9]|[a-z0-9]*)'
+        r'(?:\.[a-z](?:[_a-z0-9]*[a-z0-9]|[a-z0-9]*))*$',
+        identifier,
+        flags=re.IGNORECASE
+    )
+
+    return identifier