Skip to content

Commit

Permalink
refactor: Use functools.reduct approach to select best SQL type
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jan 19, 2024
1 parent 634a2cc commit e3f426e
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from functools import lru_cache, partial, reduce

import simplejson
import sqlalchemy as sa
Expand Down Expand Up @@ -867,30 +867,36 @@ def merge_sql_types(
Raises:
ValueError: If sql_types argument has zero members.
"""
if not sql_types:
msg = "Expected at least one member in `sql_types` argument."
raise ValueError(msg)
sorted_types = self._sort_types(sql_types)
result = reduce(partial(self.select_type, sql_types[0]), sorted_types)
if result is not None:
return result

msg = f"Unable to merge sql types: {', '.join([str(t) for t in sorted_types])}"
raise ValueError(msg)

if len(sql_types) == 1:
return sql_types[0]
def select_type(
self,
current_type: sa.types.TypeEngine,
type1: sa.types.TypeEngine,
type2: sa.types.TypeEngine,
) -> sa.types.TypeEngine | None:
"""Return the best conversion class for the given types.
Args:
current_type: The current type.
type1: The first type to compare.
type2: The second type to compare.
Returns:
The best conversion class for the given types.
"""
# Gathering Type to match variables
# sent in _adapt_column_type
current_type = sql_types[0]
cur_len: int = getattr(current_type, "length", 0)

# Convert the two types given into a sorted list
# containing the best conversion classes
sql_types = self._sort_types(sql_types)

# If greater than two evaluate the first pair then on down the line
if len(sql_types) > 2: # noqa: PLR2004
return self.merge_sql_types(
[self.merge_sql_types([sql_types[0], sql_types[1]]), *sql_types[2:]],
)
cur_len: int = getattr(current_type, "length", 0)

# Get the generic type class
for opt in sql_types:
for opt in (type1, type2):
# Get the length
opt_len: int = getattr(opt, "length", 0)
generic_type = type(opt.as_generic())
Expand All @@ -912,8 +918,7 @@ def merge_sql_types(
elif str(opt) == str(current_type):
return opt

msg = f"Unable to merge sql types: {', '.join([str(t) for t in sql_types])}"
raise ValueError(msg)
return None

def _sort_types(
self,
Expand Down Expand Up @@ -1131,7 +1136,7 @@ def _adapt_column_type(
return

# Not the same type, generic type or compatible types
# calling merge_sql_types for assistnace
# calling merge_sql_types for assistance
compatible_sql_type = self.merge_sql_types([current_type, sql_type])

if str(compatible_sql_type) == str(current_type):
Expand Down

0 comments on commit e3f426e

Please sign in to comment.