
Commit c108861
fix: correctly detect the connection protocol for minio and trino based on the connection details
alee-x committed Jun 14, 2024
1 parent 7586eeb commit c108861
Showing 4 changed files with 57 additions and 50 deletions.
6 changes: 3 additions & 3 deletions dags/modules/convert/ingest_csv_to_iceberg.py
@@ -3,7 +3,7 @@
from dags.modules.databases.duckdb import file_csv_to_parquet
from dags.modules.databases.trino import create_schema, drop_table, get_trino_conn_details, get_trino_engine, hive_create_table_from_parquet, iceberg_create_table_from_hive, validate_identifier, validate_s3_key, get_table_schema_and_max_values
from dags.modules.utils.rabbit import send_message_to_rabbitmq
-from dags.modules.utils.s3 import s3_create_bucket, s3_delete, s3_download_minio, s3_upload
+from dags.modules.utils.s3 import s3_create_bucket, s3_delete, s3_download_minio, s3_upload, detect_if_secure_endpoint, get_conn_details
from dags.modules.utils.tracking_timer import tracking_timer, tracking_data, tracking_data_str

import constants
@@ -297,13 +297,13 @@ def ingest_csv_to_iceberg(dataset, tablename, version, label, etag, ingest_bucke
    x=tracking_timer(p_conn, etag, "s_schema")

    logging.info("Getting schema from the new PAR file")
-    s3_conn = json.loads(BaseHook.get_connection("s3_conn").get_extra())
+    s3_conn = get_conn_details("s3_conn")

    fs = s3fs.S3FileSystem(
        endpoint_url=s3_conn["endpoint_url"],
        key=s3_conn["aws_access_key_id"],
        secret=s3_conn["aws_secret_access_key"],
-        use_ssl=True
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"])
    )

    with fs.open(F"s3://{hive_bucket}/{hive_key}", "rb") as fp:
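For context, a minimal runnable sketch of the changed s3fs construction, with made-up connection values standing in for the Airflow "s3_conn" connection. The real helper lives in dags/modules/utils/s3.py below; an equivalent check is inlined here so the snippet is self-contained:

# Sketch only: endpoint and credentials are hypothetical.
import s3fs

def detect_if_secure_endpoint(endpoint_url: str) -> bool:
    # Equivalent to the commit's re.match(r'^https://', ...) check
    return endpoint_url.startswith("https://")

s3_conn = {
    "endpoint_url": "http://minio:9000",
    "aws_access_key_id": "minio",
    "aws_secret_access_key": "minio-secret",
}

fs = s3fs.S3FileSystem(
    endpoint_url=s3_conn["endpoint_url"],
    key=s3_conn["aws_access_key_id"],
    secret=s3_conn["aws_secret_access_key"],
    use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"]),  # False for this endpoint
)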
5 changes: 3 additions & 2 deletions dags/modules/databases/duckdb.py
@@ -4,7 +4,7 @@
import duckdb
from airflow.hooks.base import BaseHook

-from ..utils.s3 import validate_s3_key, get_conn_details
+from ..utils.s3 import validate_s3_key, get_conn_details, detect_if_secure_endpoint

logger = logging.getLogger(__name__)

@@ -20,6 +20,7 @@ def s3_csv_to_parquet(conn_id: str, src_bucket: str, dst_bucket: str, src_key: s
    access_key_id = s3_conn['aws_access_key_id']
    secret_access_key = s3_conn['aws_secret_access_key']
    endpoint = s3_conn["endpoint_url"].replace("http://", "").replace("https://", "")
+    is_secure = detect_if_secure_endpoint(s3_conn["endpoint_url"])

    # original duckdb
    # con = duckdb.connect(database=':memory:')
@@ -33,7 +34,7 @@ def s3_csv_to_parquet(conn_id: str, src_bucket: str, dst_bucket: str, src_key: s
f"SET s3_endpoint='{endpoint}';" \
f"SET s3_access_key_id='{access_key_id}';" \
f"SET s3_secret_access_key='{secret_access_key}';" \
f"SET s3_use_ssl=True;" \
f"SET s3_use_ssl={is_secure};" \
f"SET preserve_insertion_order=False;" \
f"SET s3_url_style='path';" \
f"SET memory_limit='{memory}GB'"
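One detail worth noting: {is_secure} interpolates a Python bool into the SQL string, so the statement is rendered as SET s3_use_ssl=True; or SET s3_use_ssl=False;. DuckDB accepts true/false as case-insensitive boolean literals, so both spellings parse. A small sketch, assuming the httpfs extension can be installed locally:

# Sketch only: shows the rendered SET statements DuckDB receives.
import duckdb

is_secure = False  # e.g. detect_if_secure_endpoint("http://minio:9000")
con = duckdb.connect(database=":memory:")
con.execute("INSTALL httpfs;")  # downloads the extension on first use
con.execute("LOAD httpfs;")
con.execute(f"SET s3_use_ssl={is_secure};")  # becomes SET s3_use_ssl=False;
con.execute("SET s3_url_style='path';")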
45 changes: 24 additions & 21 deletions dags/modules/databases/trino.py
@@ -50,6 +50,10 @@ def get_trino_engine(trino_conn_details: dict) -> sqlalchemy.engine.Engine:
    host = trino_conn_details['host']
    port = trino_conn_details['port']
    database = trino_conn_details['database']
+    connect_protocol = "http"
+    if port == 443:
+        connect_protocol = "https"


    logger.info(f"username={username}")
    logger.info(f"host={host}")
@@ -59,7 +63,7 @@
    engine = create_engine(
        f"trino://{username}@{host}:{port}/{database}",
        connect_args={
-            "http_scheme": "http",
+            "http_scheme": connect_protocol,
            # TODO This needs to be set to true when deploying to anything thats not dev
            "verify": False
        },
@@ -165,31 +169,30 @@ def drop_table(trino: sqlalchemy.engine.Engine, table: str):
    trino.execute(query)

def get_schema(engine, table_name):
    query = f"SHOW COLUMNS FROM {table_name}"
    try:
        result = engine.execute(text(query))
        schema = {row['Column']: row['Type'] for row in result}
        return schema
    except SQLAlchemyError as e:
        print(f"Error executing query: {e}")
        return None

def get_max_values(engine, table_name, schema):
    max_values = {}
    for column in schema.keys():
        query = f"SELECT MAX({column}) as max_value FROM {table_name}"
        if schema[column] == "varchar":
            query = f"SELECT MAX(length({column})) as max_length FROM {table_name}"
        try:
-            result = engine.execute(text(query))
-            schema = {row['Column']: row['Type'] for row in result}
-            return schema
+            result = engine.execute(text(query)).scalar()
+            max_values[column] = result
        except SQLAlchemyError as e:
            print(f"Error executing query: {e}")
-            return None
+            max_values[column] = None
+    return max_values

def get_table_schema_and_max_values(trino: sqlalchemy.engine.Engine, table_name ):

    # Reflect the table from the database
    schema = get_schema(trino,table_name)

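The scheme detection above keys off the port rather than the URL: 443 implies HTTPS, everything else falls back to HTTP, so a TLS-enabled Trino listening on a non-standard port would still be given "http". The heuristic, extracted as a standalone sketch:

# Sketch only: the port-based protocol heuristic used in get_trino_engine.
def scheme_for_port(port: int) -> str:
    return "https" if port == 443 else "http"

assert scheme_for_port(443) == "https"
assert scheme_for_port(8080) == "http"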
51 changes: 27 additions & 24 deletions dags/modules/utils/s3.py
@@ -3,13 +3,23 @@
import logging
import boto3
from minio import Minio
-#from minio.error import ResponseError
+# from minio.error import ResponseError
import s3fs
from airflow.hooks.base import BaseHook
from botocore.config import Config

logger = logging.getLogger(__name__)


+def detect_if_secure_endpoint(endpoint_url: str) -> bool:
+    https_re = r'^https://'
+
+    if re.match(https_re, endpoint_url):
+        return True
+
+    return False


def get_conn_details(conn_id) -> dict:
"""Gets S3 connection info from Airflow connection with given connection id
@@ -23,9 +33,11 @@ def get_conn_details(conn_id) -> dict:
        logger.critical("There is no Airflow connection configured with the name {0}.".format(conn_id))
        raise e
    # Transform connection object into dictionary to make it easier to share between modules
-    s3_conn = {"endpoint_url": conn_extra["endpoint_url"], "aws_access_key_id": conn.login, "aws_secret_access_key": conn.password}
+    s3_conn = {"endpoint_url": conn_extra["endpoint_url"], "aws_access_key_id": conn.login,
+               "aws_secret_access_key": conn.password}
    return s3_conn


def validate_s3_key(key):
    # Validate the s3 key is strictly one or more slash separated keys
    logging.info(f"Validate key: {key}")
@@ -51,19 +63,17 @@ def s3_get_resource(conn_id: str):


def s3_get_fs(conn_id):
-
    s3_conn = get_conn_details(conn_id)

    return s3fs.S3FileSystem(
        endpoint_url=s3_conn["endpoint_url"],
        key=s3_conn["aws_access_key_id"],
        secret=s3_conn["aws_secret_access_key"],
-        use_ssl=True
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"])
    )


def s3_copy(conn_id: str, src_bucket, dst_bucket, src_key, dst_key, move=False):
-
    s3 = s3_get_resource(conn_id)

    s3.Bucket(dst_bucket).copy({'Bucket': src_bucket, 'Key': src_key}, dst_key)
@@ -73,7 +83,6 @@ def s3_copy(conn_id: str, src_bucket, dst_bucket, src_key, dst_key, move=False):


def s3_delete(conn_id: str, bucket, key):
-
    s3_get_resource(conn_id).Object(bucket, key).delete()


@@ -84,22 +93,19 @@ def s3_create_bucket(conn_id: str, bucket):
        print(f"An error occurred while creating the S3 bucket: {e}")



def s3_download_minio(conn_id, bucket_name, object_name, local_file_path):
-
    s3_conn = get_conn_details(conn_id)

-    url = str(s3_conn["endpoint_url"]).replace('http://','').replace('https://','')
+    url = str(s3_conn["endpoint_url"]).replace('http://', '').replace('https://', '')

    client = Minio(url,
-                  access_key=s3_conn["aws_access_key_id"],
-                  secret_key=s3_conn["aws_secret_access_key"],
-                  secure=True)
+                  access_key=s3_conn["aws_access_key_id"],
+                  secret_key=s3_conn["aws_secret_access_key"],
+                  secure=detect_if_secure_endpoint(s3_conn["endpoint_url"]))

-    #client.fget_object(bucket_name, object_name, local_file_path)
+    # client.fget_object(bucket_name, object_name, local_file_path)

-    #try:
+    # try:
    # Start a multipart download
    response = client.get_object(
        bucket_name,
@@ -115,35 +121,32 @@ def s3_download_minio(conn_id, bucket_name, object_name, local_file_path):

print("Download successful!")

#except ResponseError as err:
# except ResponseError as err:
# print(err)


def s3_download(conn_id, bucket_name, object_name, local_file_path):
-
    s3_conn = get_conn_details(conn_id)

    client = boto3.client(
        's3',
        endpoint_url=s3_conn["endpoint_url"],
        aws_access_key_id=s3_conn["aws_access_key_id"],
        aws_secret_access_key=s3_conn["aws_secret_access_key"],
-        use_ssl=True )
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"]))

    with open(local_file_path, 'wb') as f:
        client.download_fileobj(bucket_name, object_name, f)


def s3_upload(conn_id, src_file, bucket, object_name):
-
    s3_conn = get_conn_details(conn_id)

    client = boto3.client(
        's3',
        endpoint_url=s3_conn["endpoint_url"],
        aws_access_key_id=s3_conn["aws_access_key_id"],
        aws_secret_access_key=s3_conn["aws_secret_access_key"],
-        use_ssl=True )
+        use_ssl=detect_if_secure_endpoint(s3_conn["endpoint_url"]))

-    client.upload_file(src_file,bucket,object_name)
+    client.upload_file(src_file, bucket, object_name)
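
A quick usage sketch of the new helper (endpoints are made up). Note the regex is case-sensitive, so an uppercase "HTTPS://" endpoint would be reported as insecure:

from dags.modules.utils.s3 import detect_if_secure_endpoint

assert detect_if_secure_endpoint("https://minio.example.com") is True
assert detect_if_secure_endpoint("http://minio:9000") is False
assert detect_if_secure_endpoint("HTTPS://minio.example.com") is False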
