From 63e5accc0094ea05b4dfcc4a0e41896fa2c43c02 Mon Sep 17 00:00:00 2001
From: Timon Viola
Date: Wed, 23 Oct 2024 14:30:39 +0200
Subject: [PATCH] fix(operators): fix import error

---
 src/dagcellent/data_utils/__init__.py       |   4 +-
 src/dagcellent/data_utils/sql_reflection.py | 100 ++++++++++++--------
 src/dagcellent/operators/sql_s3.py          |  17 ++--
 3 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/src/dagcellent/data_utils/__init__.py b/src/dagcellent/data_utils/__init__.py
index 0d5dd3b..5761958 100644
--- a/src/dagcellent/data_utils/__init__.py
+++ b/src/dagcellent/data_utils/__init__.py
@@ -3,9 +3,9 @@
 # pyright: reportUnknownVariableType=false
 
 from dagcellent.data_utils.sql_reflection import (
+    Pyarrow2redshift,
     Query,
     UnsupportedType,
-    pyarrow2redshift,
     reflect_meta_data,
     reflect_select_query,
 )
@@ -13,7 +13,7 @@
 __all__ = [
     "Query",
     "UnsupportedType",
-    "pyarrow2redshift",
+    "Pyarrow2redshift",
     "reflect_meta_data",
     "reflect_select_query",
 ]
diff --git a/src/dagcellent/data_utils/sql_reflection.py b/src/dagcellent/data_utils/sql_reflection.py
index babf051..7be4240 100644
--- a/src/dagcellent/data_utils/sql_reflection.py
+++ b/src/dagcellent/data_utils/sql_reflection.py
@@ -3,7 +3,7 @@
 import logging
 import warnings
-from typing import TYPE_CHECKING, TypeAlias
+from typing import TYPE_CHECKING, Any, Protocol, TypeAlias
 
 import pyarrow as pa
 from sqlalchemy import (
@@ -26,43 +26,67 @@ class UnsupportedType(Exception):
 Query: TypeAlias = str
 
 
-def pyarrow2redshift(dtype: pa.DataType, string_type: str) -> str:
-    """Pyarrow to Redshift data types conversion."""
-    if pa.types.is_int8(dtype):
-        return "SMALLINT"
-    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
-        return "SMALLINT"
-    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
-        return "INTEGER"
-    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
-        return "BIGINT"
-    if pa.types.is_uint64(dtype):
-        raise UnsupportedType(
-            "There is no support for uint64, please consider int64 or uint32."
-        )
-    if pa.types.is_float32(dtype):
-        return "FLOAT4"
-    if pa.types.is_float64(dtype):
-        return "FLOAT8"
-    if pa.types.is_boolean(dtype):
-        return "BOOL"
-    if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
-        return string_type
-    if pa.types.is_timestamp(dtype):
-        return "TIMESTAMP"
-    if pa.types.is_date(dtype):
-        return "DATE"
-    if pa.types.is_time(dtype):
-        return "TIME"
-    if pa.types.is_binary(dtype):
-        return "VARBYTE"
-    if pa.types.is_decimal(dtype):
-        return f"DECIMAL({dtype.precision},{dtype.scale})"  # type: ignore[reportAttributeAccessIssue]
-    if pa.types.is_dictionary(dtype):
-        return pyarrow2redshift(dtype=dtype.value_type, string_type=string_type)  # type: ignore[reportUnknownMember]
-    if pa.types.is_list(dtype) or pa.types.is_struct(dtype) or pa.types.is_map(dtype):
-        return "SUPER"
-    raise UnsupportedType(f"Unsupported Redshift type: {dtype}")
+class PyarrowMapping(Protocol):
+    """Map pyarrow types to a target system's types."""
+
+    def map(self: PyarrowMapping, dtype: pa.DataType, *args: Any, **kwargs: Any) -> str:
+        """Map an Arrow data type to the target system's type name.
+
+        Args:
+            dtype: Arrow data type to convert.
+            *args: implementation-specific positional arguments.
+            **kwargs: implementation-specific keyword arguments.
+
+        Returns:
+            Name of the corresponding type in the target system.
+        """
+        ...
+
+
+class Pyarrow2redshift:
+    """Map Apache Arrow types to Redshift types."""
+
+    @classmethod
+    def map(cls: type[Pyarrow2redshift], dtype: pa.DataType, string_type: str) -> str:
+        """Convert a pyarrow data type to its Redshift equivalent."""
+        if pa.types.is_int8(dtype):
+            return "SMALLINT"
+        if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
+            return "SMALLINT"
+        if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
+            return "INTEGER"
+        if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
+            return "BIGINT"
+        if pa.types.is_uint64(dtype):
+            raise UnsupportedType(
+                "There is no support for uint64, please consider int64 or uint32."
+            )
+        if pa.types.is_float32(dtype):
+            return "FLOAT4"
+        if pa.types.is_float64(dtype):
+            return "FLOAT8"
+        if pa.types.is_boolean(dtype):
+            return "BOOL"
+        if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
+            return string_type
+        if pa.types.is_timestamp(dtype):
+            return "TIMESTAMP"
+        if pa.types.is_date(dtype):
+            return "DATE"
+        if pa.types.is_time(dtype):
+            return "TIME"
+        if pa.types.is_binary(dtype):
+            return "VARBYTE"
+        if pa.types.is_decimal(dtype):
+            return f"DECIMAL({dtype.precision},{dtype.scale})"  # type: ignore[reportAttributeAccessIssue]
+        if pa.types.is_dictionary(dtype):
+            return cls.map(dtype=dtype.value_type, string_type=string_type)  # type: ignore[reportUnknownMember]
+        if (
+            pa.types.is_list(dtype)
+            or pa.types.is_struct(dtype)
+            or pa.types.is_map(dtype)
+        ):
+            return "SUPER"
+        if pa.types.is_null(dtype):
+            warnings.warn("Experimental: null type mapped to NULL.", stacklevel=2)
+            return "NULL"
+        raise UnsupportedType(f"Unsupported Redshift type: {dtype}")
 
 
 def drop_unsupported_dtypes(table: Table) -> Table:
diff --git a/src/dagcellent/operators/sql_s3.py b/src/dagcellent/operators/sql_s3.py
index 599e104..40c910c 100644
--- a/src/dagcellent/operators/sql_s3.py
+++ b/src/dagcellent/operators/sql_s3.py
@@ -24,8 +24,9 @@
 from sqlalchemy import create_engine, text
 
 from dagcellent.data_utils.sql_reflection import (
+    Pyarrow2redshift,
+    PyarrowMapping,
     Query,
-    pyarrow2redshift,
     reflect_meta_data,
     reflect_select_query,
 )
@@ -113,6 +114,7 @@ def __init__(
         fix_dtypes: bool = True,
         where_clause: str | None = None,
         join_clause: str | None = None,
+        type_mapping: PyarrowMapping = Pyarrow2redshift,
         **kwargs: Any,
     ) -> None:  # type: ignore
         """Override constructor with extra chunksize argument."""
@@ -126,6 +128,7 @@ def __init__(
         self.table_name = table_name
         self.where_clause = where_clause
         self.join_clause = join_clause
+        self.type_mapping = type_mapping
         self.log.setLevel("NOTSET")
 
     def _supported_source_connections(self) -> list[str]:
@@ -164,7 +167,6 @@ def _create_select_query(self) -> Query:
         return select_ddl
 
     def _get_pandas_data(self, sql: str) -> Iterable[pd.DataFrame]:
-
         import pandas.io.sql as psql
 
         sql_hook = self._get_hook()
 
@@ -172,7 +174,10 @@ def _get_pandas_data(self, sql: str) -> Iterable[pd.DataFrame]:
         # NOTE pd type annotations are not strict enough
         with closing(sql_hook.get_conn()) as conn:  # type: ignore
             yield from psql.read_sql(  # type: ignore
-                sql, con=conn, params=self.params, chunksize=self.chunksize  # type: ignore
+                sql,
+                con=conn,
+                params=self.params,
+                chunksize=self.chunksize,  # type: ignore
             )
 
     def _clean_s3_folder(self, s3_hook: S3Hook, path: str) -> bool:
@@ -321,11 +326,11 @@ def execute(self, context: Context) -> None:
             # Get schema and xcom only once
             # NOTE pyarrow did not type hint the read_schema method
             s = pq.read_schema(tmp_file.name)  # type: ignore[UnknownTypeMember]
-            redshift_map = {  # type: ignore[UnknownTypeMember]
-                k: pyarrow2redshift(v, "VARCHAR")  # type: ignore[no-untyped-call]
+            _mapped_type = {  # type: ignore[UnknownTypeMember]
+                k: self.type_mapping.map(v, "VARCHAR")  # type: ignore[no-untyped-call]
                 for k, v in zip(s.names, s.types)  # type: ignore[no-untyped-call]
             }
-            self.xcom_push(context, key="type_map", value=redshift_map)
+            self.xcom_push(context, key="type_map", value=_mapped_type)
             first_run = False
 
         self.log.info("Uploading data to S3")
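
With this change the Arrow-to-warehouse type mapping is pluggable: any class that
satisfies the PyarrowMapping protocol can be passed to the operator's new
type_mapping argument, with Pyarrow2redshift as the default. A minimal sketch of a
custom mapping is below; the Pyarrow2duckdb name and its type choices are
hypothetical, only the map() interface comes from this patch:

    import pyarrow as pa


    class Pyarrow2duckdb:
        """Hypothetical Arrow-to-DuckDB mapping satisfying PyarrowMapping."""

        @classmethod
        def map(cls, dtype: pa.DataType, string_type: str = "VARCHAR") -> str:
            # Illustrative subset only; extend with further pa.types checks.
            if pa.types.is_integer(dtype):
                return "BIGINT"
            if pa.types.is_floating(dtype):
                return "DOUBLE"
            if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
                return string_type
            if pa.types.is_timestamp(dtype):
                return "TIMESTAMP"
            raise ValueError(f"Unmapped Arrow type: {dtype}")

Because map() is a classmethod, the class itself is passed (type_mapping=Pyarrow2duckdb),
mirroring how Pyarrow2redshift is supplied as the default; execute() then calls
self.type_mapping.map(v, "VARCHAR") on each column's Arrow type before pushing the
result to XCom under the "type_map" key.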