fix(operators): fix import error
timonviola committed Oct 23, 2024
1 parent 5e801ae commit 63e5acc
Showing 3 changed files with 75 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/dagcellent/data_utils/__init__.py
@@ -3,17 +3,17 @@

# pyright: reportUnknownVariableType=false
from dagcellent.data_utils.sql_reflection import (
    Pyarrow2redshift,
    Query,
    UnsupportedType,
    pyarrow2redshift,
    reflect_meta_data,
    reflect_select_query,
)

__all__ = [
"Query",
"UnsupportedType",
"pyarrow2redshift",
"Pyarrow2redshift",
"reflect_meta_data",
"reflect_select_query",
]
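
The renamed export keeps call sites a one-liner. A minimal sketch of using the class-based mapper, assuming only what this diff shows (a classmethod `map(dtype, string_type)` re-exported from `dagcellent.data_utils`); the sample dtypes are illustrative:

import pyarrow as pa

from dagcellent.data_utils import Pyarrow2redshift

# Map a few Arrow dtypes to Redshift type names; "VARCHAR" is the
# name used for string columns.
print(Pyarrow2redshift.map(pa.int32(), "VARCHAR"))            # INTEGER
print(Pyarrow2redshift.map(pa.string(), "VARCHAR"))           # VARCHAR
print(Pyarrow2redshift.map(pa.decimal128(10, 2), "VARCHAR"))  # DECIMAL(10,2)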
100 changes: 62 additions & 38 deletions src/dagcellent/data_utils/sql_reflection.py
@@ -3,7 +3,7 @@

import logging
import warnings
from typing import TYPE_CHECKING, TypeAlias
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias

import pyarrow as pa
from sqlalchemy import (
@@ -26,43 +26,67 @@ class UnsupportedType(Exception):
Query: TypeAlias = str


def pyarrow2redshift(dtype: pa.DataType, string_type: str) -> str:
"""Pyarrow to Redshift data types conversion."""
if pa.types.is_int8(dtype):
return "SMALLINT"
if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
return "SMALLINT"
if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
return "INTEGER"
if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
return "BIGINT"
if pa.types.is_uint64(dtype):
raise UnsupportedType(
"There is no support for uint64, please consider int64 or uint32."
)
if pa.types.is_float32(dtype):
return "FLOAT4"
if pa.types.is_float64(dtype):
return "FLOAT8"
if pa.types.is_boolean(dtype):
return "BOOL"
if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
return string_type
if pa.types.is_timestamp(dtype):
return "TIMESTAMP"
if pa.types.is_date(dtype):
return "DATE"
if pa.types.is_time(dtype):
return "TIME"
if pa.types.is_binary(dtype):
return "VARBYTE"
if pa.types.is_decimal(dtype):
return f"DECIMAL({dtype.precision},{dtype.scale})" # type: ignore[reportAttributeAccessIssue]
if pa.types.is_dictionary(dtype):
return pyarrow2redshift(dtype=dtype.value_type, string_type=string_type) # type: ignore[reportUnknownMember]
if pa.types.is_list(dtype) or pa.types.is_struct(dtype) or pa.types.is_map(dtype):
return "SUPER"
raise UnsupportedType(f"Unsupported Redshift type: {dtype}")
class PyarrowMapping(Protocol):
    """Map Pyarrow types to target system's types."""

    def map(self, dtype: pa.DataType, *args: Any, **kwargs: Any) -> str:
        """Map a Pyarrow ``dtype`` to the target system's type name.

        Args:
            dtype: Pyarrow data type to convert.
            *args: Extra positional arguments for the concrete mapping.
            **kwargs: Extra keyword arguments for the concrete mapping.
        """
        ...


class Pyarrow2redshift:
    """Map Apache Arrow to Redshift."""

    @classmethod
    def map(cls, dtype: pa.DataType, string_type: str) -> str:
        """Pyarrow to Redshift data types conversion."""
        if pa.types.is_int8(dtype):
            return "SMALLINT"
        if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
            return "SMALLINT"
        if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
            return "INTEGER"
        if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
            return "BIGINT"
        if pa.types.is_uint64(dtype):
            raise UnsupportedType(
                "There is no support for uint64, please consider int64 or uint32."
            )
        if pa.types.is_float32(dtype):
            return "FLOAT4"
        if pa.types.is_float64(dtype):
            return "FLOAT8"
        if pa.types.is_boolean(dtype):
            return "BOOL"
        if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
            return string_type
        if pa.types.is_timestamp(dtype):
            return "TIMESTAMP"
        if pa.types.is_date(dtype):
            return "DATE"
        if pa.types.is_time(dtype):
            return "TIME"
        if pa.types.is_binary(dtype):
            return "VARBYTE"
        if pa.types.is_decimal(dtype):
            return f"DECIMAL({dtype.precision},{dtype.scale})"  # type: ignore[reportAttributeAccessIssue]
        if pa.types.is_dictionary(dtype):
            return cls.map(dtype=dtype.value_type, string_type=string_type)  # type: ignore[reportUnknownMember]
        if (
            pa.types.is_list(dtype)
            or pa.types.is_struct(dtype)
            or pa.types.is_map(dtype)
        ):
            return "SUPER"
        if pa.types.is_null(dtype):
            warnings.warn("Experimental NULL value derived.", stacklevel=2)
            return "NULL"
        raise UnsupportedType(f"Unsupported Redshift type: {dtype}")


def drop_unsupported_dtypes(table: Table) -> Table:
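Any object exposing a compatible `map` can stand in for `Pyarrow2redshift`. A hedged sketch of an alternative mapper satisfying the `PyarrowMapping` protocol; `Pyarrow2postgres` and its type choices are illustrative, not part of this commit:

import pyarrow as pa

from dagcellent.data_utils.sql_reflection import UnsupportedType


class Pyarrow2postgres:
    """Hypothetical mapper: Apache Arrow -> PostgreSQL type names."""

    @classmethod
    def map(cls, dtype: pa.DataType, string_type: str = "TEXT") -> str:
        if pa.types.is_integer(dtype):
            return "BIGINT"
        if pa.types.is_floating(dtype):
            return "DOUBLE PRECISION"
        if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
            return string_type
        if pa.types.is_timestamp(dtype):
            return "TIMESTAMP"
        raise UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")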
17 changes: 11 additions & 6 deletions src/dagcellent/operators/sql_s3.py
@@ -24,8 +24,9 @@
from sqlalchemy import create_engine, text

from dagcellent.data_utils.sql_reflection import (
    Pyarrow2redshift,
    PyarrowMapping,
    Query,
    pyarrow2redshift,
    reflect_meta_data,
    reflect_select_query,
)
@@ -113,6 +114,7 @@ def __init__(
        fix_dtypes: bool = True,
        where_clause: str | None = None,
        join_clause: str | None = None,
        type_mapping: PyarrowMapping = Pyarrow2redshift,
        **kwargs: Any,
    ) -> None:  # type: ignore
        """Override constructor with extra chunksize argument."""
@@ -126,6 +128,7 @@ def __init__(
        self.table_name = table_name
        self.where_clause = where_clause
        self.join_clause = join_clause
        self.type_mapping = type_mapping
        self.log.setLevel("NOTSET")

    def _supported_source_connections(self) -> list[str]:
@@ -164,15 +167,17 @@ def _create_select_query(self) -> Query:
        return select_ddl

    def _get_pandas_data(self, sql: str) -> Iterable[pd.DataFrame]:

        import pandas.io.sql as psql

        sql_hook = self._get_hook()

        # NOTE pd type annotations are not strict enough
        with closing(sql_hook.get_conn()) as conn:  # type: ignore
            yield from psql.read_sql(  # type: ignore
                sql, con=conn, params=self.params, chunksize=self.chunksize  # type: ignore
                sql,
                con=conn,
                params=self.params,
                chunksize=self.chunksize,  # type: ignore
            )

def _clean_s3_folder(self, s3_hook: S3Hook, path: str) -> bool:
@@ -321,11 +326,11 @@ def execute(self, context: Context) -> None:
        # Get schema and xcom only once
        # NOTE pyarrow did not type hint the read_schema method
        s = pq.read_schema(tmp_file.name)  # type: ignore[UnknownTypeMember]
        redshift_map = {  # type: ignore[UnknownTypeMember]
            k: pyarrow2redshift(v, "VARCHAR")  # type: ignore[no-untyped-call]
        _mapped_type = {  # type: ignore[UnknownTypeMember]
            k: self.type_mapping.map(v, "VARCHAR")  # type: ignore[no-untyped-call]
            for k, v in zip(s.names, s.types)  # type: ignore[no-untyped-call]
        }
        self.xcom_push(context, key="type_map", value=redshift_map)
        self.xcom_push(context, key="type_map", value=_mapped_type)
        first_run = False

        self.log.info("Uploading data to S3")
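The new `type_mapping` argument makes the target-dialect type mapping pluggable at task definition time. A hypothetical usage sketch; the operator's public class name is not visible in this diff, so `SqlToS3Operator` and the task arguments below are assumptions:

from dagcellent.data_utils.sql_reflection import Pyarrow2redshift
from dagcellent.operators.sql_s3 import SqlToS3Operator  # assumed name

dump_events = SqlToS3Operator(
    task_id="dump_events",
    table_name="events",                         # assumed table
    where_clause="created_at >= '2024-01-01'",   # optional, assumed filter
    type_mapping=Pyarrow2redshift,               # the default; any PyarrowMapping works
)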
