From e3f426ec3181f57336b5d105e184063f70494d9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= Date: Wed, 20 Dec 2023 19:02:57 -0600 Subject: [PATCH] refactor: Use `functools.reduct` approach to select best SQL type --- singer_sdk/connectors/sql.py | 49 ++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index 2263516bce..b2a3f57081 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -9,7 +9,7 @@ import warnings from contextlib import contextmanager from datetime import datetime -from functools import lru_cache +from functools import lru_cache, partial, reduce import simplejson import sqlalchemy as sa @@ -867,30 +867,36 @@ def merge_sql_types( Raises: ValueError: If sql_types argument has zero members. """ - if not sql_types: - msg = "Expected at least one member in `sql_types` argument." - raise ValueError(msg) + sorted_types = self._sort_types(sql_types) + result = reduce(partial(self.select_type, sql_types[0]), sorted_types) + if result is not None: + return result + + msg = f"Unable to merge sql types: {', '.join([str(t) for t in sorted_types])}" + raise ValueError(msg) - if len(sql_types) == 1: - return sql_types[0] + def select_type( + self, + current_type: sa.types.TypeEngine, + type1: sa.types.TypeEngine, + type2: sa.types.TypeEngine, + ) -> sa.types.TypeEngine | None: + """Return the best conversion class for the given types. + Args: + current_type: The current type. + type1: The first type to compare. + type2: The second type to compare. + + Returns: + The best conversion class for the given types. + """ # Gathering Type to match variables # sent in _adapt_column_type - current_type = sql_types[0] - cur_len: int = getattr(current_type, "length", 0) - # Convert the two types given into a sorted list - # containing the best conversion classes - sql_types = self._sort_types(sql_types) - - # If greater than two evaluate the first pair then on down the line - if len(sql_types) > 2: # noqa: PLR2004 - return self.merge_sql_types( - [self.merge_sql_types([sql_types[0], sql_types[1]]), *sql_types[2:]], - ) + cur_len: int = getattr(current_type, "length", 0) - # Get the generic type class - for opt in sql_types: + for opt in (type1, type2): # Get the length opt_len: int = getattr(opt, "length", 0) generic_type = type(opt.as_generic()) @@ -912,8 +918,7 @@ def merge_sql_types( elif str(opt) == str(current_type): return opt - msg = f"Unable to merge sql types: {', '.join([str(t) for t in sql_types])}" - raise ValueError(msg) + return None def _sort_types( self, @@ -1131,7 +1136,7 @@ def _adapt_column_type( return # Not the same type, generic type or compatible types - # calling merge_sql_types for assistnace + # calling merge_sql_types for assistance compatible_sql_type = self.merge_sql_types([current_type, sql_type]) if str(compatible_sql_type) == str(current_type):