From 80425d9e81980a35fc02cf8a46db9a0ed3fa6d39 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 13 Nov 2024 18:10:20 +0300 Subject: [PATCH] Attempt to support sqlalchemy 1.4+ --- test/test_core.py | 36 +- test/test_suite.py | 2 + tox.ini | 1 + ydb_sqlalchemy/sqlalchemy/__init__.py | 566 ++---------------- .../sqlalchemy/compiler/__init__.py | 16 + ydb_sqlalchemy/sqlalchemy/compiler/base.py | 320 ++++++++++ ydb_sqlalchemy/sqlalchemy/compiler/sa14.py | 208 +++++++ ydb_sqlalchemy/sqlalchemy/compiler/sa20.py | 252 ++++++++ ydb_sqlalchemy/sqlalchemy/datetime_types.py | 6 +- ydb_sqlalchemy/sqlalchemy/types.py | 9 +- 10 files changed, 861 insertions(+), 555 deletions(-) create mode 100644 ydb_sqlalchemy/sqlalchemy/compiler/__init__.py create mode 100644 ydb_sqlalchemy/sqlalchemy/compiler/base.py create mode 100644 ydb_sqlalchemy/sqlalchemy/compiler/sa14.py create mode 100644 ydb_sqlalchemy/sqlalchemy/compiler/sa20.py diff --git a/test/test_core.py b/test/test_core.py index 61129f8..661d3bd 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -94,7 +94,7 @@ def test_sa_crud(self, connection): (5, "c"), ] - def test_cached_query(self, connection_no_trans: sa.Connection, connection: sa.Connection): + def test_cached_query(self, connection_no_trans, connection): table = self.tables.test with connection_no_trans.begin() as transaction: @@ -263,7 +263,7 @@ def test_integer_types(self, connection): result = connection.execute(stmt).fetchone() assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64") - def test_datetime_types(self, connection: sa.Connection): + def test_datetime_types(self, connection): stmt = sa.Select( sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))), sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))), @@ -273,7 +273,7 @@ def test_datetime_types(self, connection: sa.Connection): result = connection.execute(stmt).fetchone() assert result == (b"Timestamp", b"Datetime", b"Timestamp") - def test_datetime_types_timezone(self, connection: sa.Connection): + def test_datetime_types_timezone(self, connection): table = self.tables.test_datetime_types tzinfo = datetime.timezone(datetime.timedelta(hours=3, minutes=42)) @@ -476,7 +476,7 @@ def define_tables(cls, metadata: sa.MetaData): Column("id", Integer, primary_key=True), ) - def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection): + def test_rollback(self, connection_no_trans, connection): table = self.tables.test connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE) @@ -491,7 +491,7 @@ def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Conne result = cursor.fetchall() assert result == [] - def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection): + def test_commit(self, connection_no_trans, connection): table = self.tables.test connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE) @@ -506,9 +506,7 @@ def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connect assert set(result) == {(3,), (4,)} @pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY)) - def test_interactive_transaction( - self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level - ): + def test_interactive_transaction(self, connection_no_trans, connection, isolation_level): table = 
self.tables.test dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection @@ -535,9 +533,7 @@ def test_interactive_transaction( IsolationLevel.AUTOCOMMIT, ), ) - def test_not_interactive_transaction( - self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level - ): + def test_not_interactive_transaction(self, connection_no_trans, connection, isolation_level): table = self.tables.test dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection @@ -573,7 +569,7 @@ class IsolationSettings(NamedTuple): IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True), } - def test_connection_set(self, connection_no_trans: sa.Connection): + def test_connection_set(self, connection_no_trans): dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items(): @@ -861,7 +857,7 @@ def test_insert_in_name_and_field(self, connection): class TestSecondaryIndex(TestBase): __backend__ = True - def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData): + def test_column_indexes(self, connection, metadata: sa.MetaData): table = Table( "test_column_indexes/table", metadata, @@ -884,7 +880,7 @@ def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData): index1 = indexes_map["ix_test_column_indexes_table_index_col2"] assert index1.index_columns == ["index_col2"] - def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData): + def test_async_index(self, connection, metadata: sa.MetaData): table = Table( "test_async_index/table", metadata, @@ -903,7 +899,7 @@ def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData): assert set(index.index_columns) == {"index_col1", "index_col2"} # TODO: Check type after https://github.com/ydb-platform/ydb-python-sdk/issues/351 - def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData): + def test_cover_index(self, connection, metadata: sa.MetaData): table = Table( "test_cover_index/table", metadata, @@ -922,7 +918,7 @@ def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData): assert set(index.index_columns) == {"index_col1"} # TODO: Check covered columns after https://github.com/ydb-platform/ydb-python-sdk/issues/409 - def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaData): + def test_indexes_reflection(self, connection, metadata: sa.MetaData): table = Table( "test_indexes_reflection/table", metadata, @@ -948,7 +944,7 @@ def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaDa "test_async_cover_index": {"index_col1"}, } - def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaData): + def test_index_simple_usage(self, connection, metadata: sa.MetaData): persons = Table( "test_index_simple_usage/persons", metadata, @@ -979,7 +975,7 @@ def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaDa cursor = connection.execute(select_stmt) assert cursor.scalar_one() == "Sarah Connor" - def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.MetaData): + def test_index_with_join_usage(self, connection, metadata: sa.MetaData): persons = Table( "test_index_with_join_usage/persons", metadata, @@ -1033,7 +1029,7 @@ def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.Met cursor = connection.execute(select_stmt) assert 
cursor.one() == ("Sarah Connor", "wanted") - def test_index_deletion(self, connection: sa.Connection, metadata: sa.MetaData): + def test_index_deletion(self, connection, metadata: sa.MetaData): persons = Table( "test_index_deletion/persons", metadata, @@ -1062,7 +1058,7 @@ def define_tables(cls, metadata: sa.MetaData): Table("table", metadata, sa.Column("id", sa.Integer, primary_key=True)) @classmethod - def insert_data(cls, connection: sa.Connection): + def insert_data(cls, connection): table = cls.tables["some_dir/nested_dir/table"] root_table = cls.tables["table"] diff --git a/test/test_suite.py b/test/test_suite.py index bf0bbad..300d2b6 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -68,6 +68,7 @@ from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest + from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest from sqlalchemy.testing.suite.test_types import StringTest as _StringTest @@ -78,6 +79,7 @@ TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest + from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types diff --git a/tox.ini b/tox.ini index cd0ee2b..494bf2b 100644 --- a/tox.ini +++ b/tox.ini @@ -68,4 +68,5 @@ max-line-length = 120 ignore=E203,W503 per-file-ignores = ydb_sqlalchemy/__init__.py: F401 + ydb_sqlalchemy/sqlalchemy/compiler/__init__.py: F401 exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,docs/* diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index d3ef664..d64e09d 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -5,231 +5,25 @@ import collections import collections.abc -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Mapping, Optional, Sequence, Tuple, Union import sqlalchemy as sa import ydb from sqlalchemy import util from sqlalchemy.engine import characteristics, reflection from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect -from sqlalchemy.exc import CompileError, NoSuchTableError -from sqlalchemy.sql import ddl, functions, literal_column -from sqlalchemy.sql.compiler import ( - DDLCompiler, - IdentifierPreparer, - StrSQLCompiler, - StrSQLTypeCompiler, - selectable, -) +from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.sql import functions + from sqlalchemy.sql.elements import ClauseList -from sqlalchemy.util.compat import inspect_getfullargspec -import ydb_dbapi as dbapi +import ydb_dbapi from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncConnection from ydb_sqlalchemy.sqlalchemy.dml import Upsert -from . 
import types - -STR_QUOTE_MAP = { - "'": "\\'", - "\\": "\\\\", - "\0": "\\0", - "\b": "\\b", - "\f": "\\f", - "\r": "\\r", - "\n": "\\n", - "\t": "\\t", - "%": "%%", -} - -COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: "UNION ALL", - selectable.CompoundSelect.UNION_ALL: "UNION ALL", - selectable.CompoundSelect.EXCEPT: "EXCEPT", - selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", - selectable.CompoundSelect.INTERSECT: "INTERSECT", - selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", -} - - -class YqlIdentifierPreparer(IdentifierPreparer): - reserved_words = IdentifierPreparer.reserved_words - reserved_words.update(dbapi.YDB_KEYWORDS) - - def __init__(self, dialect): - super(YqlIdentifierPreparer, self).__init__( - dialect, - initial_quote="`", - final_quote="`", - ) - - def format_index(self, index: sa.Index) -> str: - return super().format_index(index).replace("/", "_") - - -class YqlTypeCompiler(StrSQLTypeCompiler): - def visit_JSON(self, type_: Union[sa.JSON, types.YqlJSON], **kw): - return "JSON" - - def visit_CHAR(self, type_: sa.CHAR, **kw): - return "UTF8" - - def visit_VARCHAR(self, type_: sa.VARCHAR, **kw): - return "UTF8" - - def visit_unicode(self, type_: sa.Unicode, **kw): - return "UTF8" - - def visit_uuid(self, type_: sa.Uuid, **kw): - return "UTF8" - - def visit_NVARCHAR(self, type_: sa.NVARCHAR, **kw): - return "UTF8" - - def visit_TEXT(self, type_: sa.TEXT, **kw): - return "UTF8" +from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler - def visit_FLOAT(self, type_: sa.FLOAT, **kw): - return "FLOAT" - - def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): - return "BOOL" - - def visit_uint64(self, type_: types.UInt64, **kw): - return "UInt64" - - def visit_uint32(self, type_: types.UInt32, **kw): - return "UInt32" - - def visit_uint16(self, type_: types.UInt16, **kw): - return "UInt16" - - def visit_uint8(self, type_: types.UInt8, **kw): - return "UInt8" - - def visit_int64(self, type_: types.Int64, **kw): - return "Int64" - - def visit_int32(self, type_: types.Int32, **kw): - return "Int32" - - def visit_int16(self, type_: types.Int16, **kw): - return "Int16" - - def visit_int8(self, type_: types.Int8, **kw): - return "Int8" - - def visit_INTEGER(self, type_: sa.INTEGER, **kw): - return "Int64" - - def visit_NUMERIC(self, type_: sa.Numeric, **kw): - """Only Decimal(22,9) is supported for table columns""" - return f"Decimal({type_.precision}, {type_.scale})" - - def visit_BINARY(self, type_: sa.BINARY, **kw): - return "String" - - def visit_BLOB(self, type_: sa.BLOB, **kw): - return "String" - - def visit_datetime(self, type_: sa.TIMESTAMP, **kw): - return self.visit_TIMESTAMP(type_, **kw) - - def visit_DATETIME(self, type_: sa.DATETIME, **kw): - return "DateTime" - - def visit_TIMESTAMP(self, type_: sa.TIMESTAMP, **kw): - return "Timestamp" - - def visit_list_type(self, type_: types.ListType, **kw): - inner = self.process(type_.item_type, **kw) - return f"List<{inner}>" - - def visit_ARRAY(self, type_: sa.ARRAY, **kw): - inner = self.process(type_.item_type, **kw) - return f"List<{inner}>" - - def visit_struct_type(self, type_: types.StructType, **kw): - text = "Struct<" - for field, field_type in type_.fields_types: - text += f"{field}:{self.process(field_type, **kw)}" - return text + ">" - - def get_ydb_type( - self, type_: sa.types.TypeEngine, is_optional: bool - ) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]: - if isinstance(type_, sa.TypeDecorator): - type_ = type_.impl - - if 
isinstance(type_, (sa.Text, sa.String, sa.Uuid)): - ydb_type = ydb.PrimitiveType.Utf8 - - # Integers - elif isinstance(type_, types.UInt64): - ydb_type = ydb.PrimitiveType.Uint64 - elif isinstance(type_, types.UInt32): - ydb_type = ydb.PrimitiveType.Uint32 - elif isinstance(type_, types.UInt16): - ydb_type = ydb.PrimitiveType.Uint16 - elif isinstance(type_, types.UInt8): - ydb_type = ydb.PrimitiveType.Uint8 - elif isinstance(type_, types.Int64): - ydb_type = ydb.PrimitiveType.Int64 - elif isinstance(type_, types.Int32): - ydb_type = ydb.PrimitiveType.Int32 - elif isinstance(type_, types.Int16): - ydb_type = ydb.PrimitiveType.Int16 - elif isinstance(type_, types.Int8): - ydb_type = ydb.PrimitiveType.Int8 - elif isinstance(type_, sa.Integer): - ydb_type = ydb.PrimitiveType.Int64 - # Integers - - # Json - elif isinstance(type_, sa.JSON): - ydb_type = ydb.PrimitiveType.Json - elif isinstance(type_, sa.JSON.JSONStrIndexType): - ydb_type = ydb.PrimitiveType.Utf8 - elif isinstance(type_, sa.JSON.JSONIntIndexType): - ydb_type = ydb.PrimitiveType.Int64 - elif isinstance(type_, sa.JSON.JSONPathType): - ydb_type = ydb.PrimitiveType.Utf8 - elif isinstance(type_, types.YqlJSON): - ydb_type = ydb.PrimitiveType.Json - elif isinstance(type_, types.YqlJSON.YqlJSONPathType): - ydb_type = ydb.PrimitiveType.Utf8 - # Json - elif isinstance(type_, sa.DATETIME): - ydb_type = ydb.PrimitiveType.Datetime - elif isinstance(type_, sa.TIMESTAMP): - ydb_type = ydb.PrimitiveType.Timestamp - elif isinstance(type_, sa.DateTime): - ydb_type = ydb.PrimitiveType.Timestamp - elif isinstance(type_, sa.Date): - ydb_type = ydb.PrimitiveType.Date - elif isinstance(type_, sa.BINARY): - ydb_type = ydb.PrimitiveType.String - elif isinstance(type_, sa.Float): - ydb_type = ydb.PrimitiveType.Float - elif isinstance(type_, sa.Double): - ydb_type = ydb.PrimitiveType.Double - elif isinstance(type_, sa.Boolean): - ydb_type = ydb.PrimitiveType.Bool - elif isinstance(type_, sa.Numeric): - ydb_type = ydb.DecimalType(type_.precision, type_.scale) - elif isinstance(type_, (types.ListType, sa.ARRAY)): - ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False)) - elif isinstance(type_, types.StructType): - ydb_type = ydb.StructType() - for field, field_type in type_.fields_types.items(): - ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False)) - else: - raise dbapi.NotSupportedError(f"{type_} bind variables not supported") - - if is_optional: - return ydb.OptionalType(ydb_type) - - return ydb_type +from . import types class ParametrizedFunction(functions.Function): @@ -242,300 +36,6 @@ def __init__(self, name, params, *args, **kwargs): self.params_expr = ClauseList(operator=functions.operators.comma_op, group_contents=True, *params).self_group() -class YqlCompiler(StrSQLCompiler): - compound_keywords = COMPOUND_KEYWORDS - - def get_from_hint_text(self, table, text): - return text - - def render_bind_cast(self, type_, dbapi_type, sqltext): - pass - - def group_by_clause(self, select, **kw): - # Hack to ensure it is possible to define labels in groupby. 
- kw.update(within_columns_clause=True) - return super(YqlCompiler, self).group_by_clause(select, **kw) - - def limit_clause(self, select, **kw): - text = "" - if select._limit_clause is not None: - limit_clause = self._maybe_cast( - select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) - ) - text += "\n LIMIT " + self.process(limit_clause, **kw) - if select._offset_clause is not None: - offset_clause = self._maybe_cast( - select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) - ) - if select._limit_clause is None: - text += "\n LIMIT 1000" # For some reason, YDB do not support LIMIT NULL OFFSET - text += " OFFSET " + self.process(offset_clause, **kw) - return text - - def _maybe_cast( - self, - element: Any, - cast_to: Type[sa.types.TypeEngine], - skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, - ) -> Any: - if not skip_types: - skip_types = (cast_to,) - if cast_to not in skip_types: - skip_types = (*skip_types, cast_to) - if not hasattr(element, "type") or not isinstance(element.type, skip_types): - return sa.Cast(element, cast_to) - return element - - def render_literal_value(self, value, type_): - if isinstance(value, str): - value = "".join(STR_QUOTE_MAP.get(x, x) for x in value) - return f"'{value}'" - return super().render_literal_value(value, type_) - - def visit_lambda(self, lambda_, **kw): - func = lambda_.func - spec = inspect_getfullargspec(func) - - if spec.varargs: - raise CompileError("Lambdas with *args are not supported") - if spec.varkw: - raise CompileError("Lambdas with **kwargs are not supported") - - args = [literal_column("$" + arg) for arg in spec.args] - text = f'({", ".join("$" + arg for arg in spec.args)}) -> ' f"{{ RETURN {self.process(func(*args), **kw)} ;}}" - - return text - - def visit_parametrized_function(self, func, **kwargs): - name = func.name - name_parts = [] - for name in name.split("::"): - fname = ( - self.preparer.quote(name) - if self.preparer._requires_quotes_illegal_chars(name) or isinstance(name, sa.sql.elements.quoted_name) - else name - ) - - name_parts.append(fname) - - name = "::".join(name_parts) - params = func.params_expr._compiler_dispatch(self, **kwargs) - args = self.function_argspec(func, **kwargs) - return "%(name)s%(params)s%(args)s" % dict(name=name, params=params, args=args) - - def visit_function(self, func, add_to_result_map=None, **kwargs): - # Copypaste of `sa.sql.compiler.SQLCompiler.visit_function` with - # `::` as namespace separator instead of `.` - if add_to_result_map: - add_to_result_map(func.name, func.name, (), func.type) - - disp = getattr(self, f"visit_{func.name.lower()}_func", None) - if disp: - return disp(func, **kwargs) - - name = sa.sql.compiler.FUNCTIONS.get(func.__class__) - if name: - if func._has_args: - name += "%(expr)s" - else: - name = func.name - name = ( - self.preparer.quote(name) - if self.preparer._requires_quotes_illegal_chars(name) or isinstance(name, sa.sql.elements.quoted_name) - else name - ) - name += "%(expr)s" - - return "::".join( - [ - ( - self.preparer.quote(tok) - if self.preparer._requires_quotes_illegal_chars(tok) - or isinstance(name, sa.sql.elements.quoted_name) - else tok - ) - for tok in func.packagenames - ] - + [name] - ) % {"expr": self.function_argspec(func, **kwargs)} - - def _yson_convert_to(self, statement: str, target_type: sa.types.TypeEngine) -> str: - type_name = target_type.compile(self.dialect) - if isinstance(target_type, sa.Numeric) and not 
isinstance(target_type, (sa.Float, sa.Double)): - # Since Decimal is stored in JSON either as String or as Float - string_value = f"Yson::ConvertTo({statement}, Optional, Yson::Options(true AS AutoConvert))" - return f"CAST({string_value} AS Optional<{type_name}>)" - return f"Yson::ConvertTo({statement}, Optional<{type_name}>)" - - def visit_json_getitem_op_binary(self, binary: sa.BinaryExpression, operator, **kw) -> str: - json_field = self.process(binary.left, **kw) - index = self.process(binary.right, **kw) - return self._yson_convert_to(f"{json_field}[{index}]", binary.type) - - def visit_json_path_getitem_op_binary(self, binary: sa.BinaryExpression, operator, **kw) -> str: - json_field = self.process(binary.left, **kw) - path = self.process(binary.right, **kw) - return self._yson_convert_to(f"Yson::YPath({json_field}, {path})", binary.type) - - def visit_regexp_match_op_binary(self, binary, operator, **kw): - return self._generate_generic_binary(binary, " REGEXP ", **kw) - - def visit_not_regexp_match_op_binary(self, binary, operator, **kw): - return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) - - def _is_bound_to_nullable_column(self, bind_name: str) -> bool: - if bind_name in self.column_keys and hasattr(self.compile_state, "dml_table"): - if bind_name in self.compile_state.dml_table.c: - column = self.compile_state.dml_table.c[bind_name] - return column.nullable and not column.primary_key - return False - - def _guess_bound_variable_type_by_parameters( - self, bind: sa.BindParameter, post_compile_bind_values: list - ) -> Optional[sa.types.TypeEngine]: - bind_type = bind.type - if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): - not_null_values = [v for v in post_compile_bind_values if v is not None] - if not_null_values: - bind_type = sa.BindParameter("", not_null_values[0]).type - - if isinstance(bind_type, sa.types.NullType): - return None - - return bind_type - - def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[str, List[Any]]) -> List[Any]: - expanding_bind_names = [] - for parameter_name in parameters_values: - parameter_bind_name = "_".join(parameter_name.split("_")[:-1]) - if parameter_bind_name == bind_name: - expanding_bind_names.append(parameter_name) - return expanding_bind_names - - def get_bind_types( - self, post_compile_parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] - ) -> Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]]: - """ - This method extracts information about bound variables from the table definition and parameters. 
- """ - if isinstance(post_compile_parameters, collections.abc.Mapping): - post_compile_parameters = [post_compile_parameters] - - parameters_values = collections.defaultdict(list) - for parameters_entry in post_compile_parameters: - for parameter_name, parameter_value in parameters_entry.items(): - parameters_values[parameter_name].append(parameter_value) - - parameter_types = {} - for bind_name in self.bind_names.values(): - bind = self.binds[bind_name] - - if bind.literal_execute: - continue - - if not bind.expanding: - post_compile_bind_names = [bind_name] - post_compile_bind_values = parameters_values[bind_name] - else: - post_compile_bind_names = self._get_expanding_bind_names(bind_name, parameters_values) - post_compile_bind_values = [] - for parameter_name, parameter_values in parameters_values.items(): - if parameter_name in post_compile_bind_names: - post_compile_bind_values.extend(parameter_values) - - is_optional = self._is_bound_to_nullable_column(bind_name) - if not post_compile_bind_values or None in post_compile_bind_values: - is_optional = True - - bind_type = self._guess_bound_variable_type_by_parameters(bind, post_compile_bind_values) - - if bind_type: - for post_compile_bind_name in post_compile_bind_names: - parameter_types[post_compile_bind_name] = YqlTypeCompiler(self.dialect).get_ydb_type( - bind_type, is_optional - ) - - return parameter_types - - def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw): - return self.visit_insert(insert_stmt, visited_bindparam, **kw).replace("INSERT", "UPSERT", 1) - - -class YqlDDLCompiler(DDLCompiler): - def visit_create_index(self, create: ddl.CreateIndex, **kw) -> str: - index: sa.Index = create.element - ydb_opts = index.dialect_options.get("ydb", {}) - - self._verify_index_table(index) - - if index.name is None: - raise CompileError("ADD INDEX requires that the index has a name") - - table_name = self.preparer.format_table(index.table) - index_name = self._prepared_index_name(index) - - text = f"ALTER TABLE {table_name} ADD INDEX {index_name} GLOBAL" - - text += " SYNC" if not ydb_opts.get("async", False) else " ASYNC" - - columns = [self.preparer.format_column(col) for col in index.columns.values()] - cover_columns = [ - col if isinstance(col, str) else self.preparer.format_column(col) for col in ydb_opts.get("cover", []) - ] - cover_columns = list(dict.fromkeys(cover_columns)) # dict preserves order - - text += " ON (" + ", ".join(columns) + ")" - - if cover_columns: - text += " COVER (" + ", ".join(cover_columns) + ")" - - return text - - def visit_drop_index(self, drop: ddl.DropIndex, **kw) -> str: - index: sa.Index = drop.element - - self._verify_index_table(index) - - table_name = self.preparer.format_table(index.table) - index_name = self._prepared_index_name(index) - - return f"ALTER TABLE {table_name} DROP INDEX {index_name}" - - def post_create_table(self, table: sa.Table) -> str: - ydb_opts = table.dialect_options["ydb"] - with_clause_list = self._render_table_partitioning_settings(ydb_opts) - if with_clause_list: - with_clause_text = ",\n".join(with_clause_list) - return f"\nWITH (\n\t{with_clause_text}\n)" - return "" - - def _render_table_partitioning_settings(self, ydb_opts: Dict[str, Any]) -> List[str]: - table_partitioning_settings = [] - if ydb_opts["auto_partitioning_by_size"] is not None: - auto_partitioning_by_size = "ENABLED" if ydb_opts["auto_partitioning_by_size"] else "DISABLED" - table_partitioning_settings.append(f"AUTO_PARTITIONING_BY_SIZE = {auto_partitioning_by_size}") - if 
ydb_opts["auto_partitioning_by_load"] is not None: - auto_partitioning_by_load = "ENABLED" if ydb_opts["auto_partitioning_by_load"] else "DISABLED" - table_partitioning_settings.append(f"AUTO_PARTITIONING_BY_LOAD = {auto_partitioning_by_load}") - if ydb_opts["auto_partitioning_partition_size_mb"] is not None: - table_partitioning_settings.append( - f"AUTO_PARTITIONING_PARTITION_SIZE_MB = {ydb_opts['auto_partitioning_partition_size_mb']}" - ) - if ydb_opts["auto_partitioning_min_partitions_count"] is not None: - table_partitioning_settings.append( - f"AUTO_PARTITIONING_MIN_PARTITIONS_COUNT = {ydb_opts['auto_partitioning_min_partitions_count']}" - ) - if ydb_opts["auto_partitioning_max_partitions_count"] is not None: - table_partitioning_settings.append( - f"AUTO_PARTITIONING_MAX_PARTITIONS_COUNT = {ydb_opts['auto_partitioning_max_partitions_count']}" - ) - if ydb_opts["uniform_partitions"] is not None: - table_partitioning_settings.append(f"UNIFORM_PARTITIONS = {ydb_opts['uniform_partitions']}") - if ydb_opts["partition_at_keys"] is not None: - table_partitioning_settings.append(f"PARTITION_AT_KEYS = {ydb_opts['partition_at_keys']}") - return table_partitioning_settings - - def upsert(table): return Upsert(table) @@ -579,15 +79,17 @@ def _get_column_info(t): class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic): - def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> None: + def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection) -> None: dialect.reset_ydb_request_settings(dbapi_connection) def set_characteristic( - self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection, value: ydb.BaseRequestSettings + self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection, value: ydb.BaseRequestSettings ) -> None: dialect.set_ydb_request_settings(dbapi_connection, value) - def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> ydb.BaseRequestSettings: + def get_characteristic( + self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection + ) -> ydb.BaseRequestSettings: return dialect.get_ydb_request_settings(dbapi_connection) @@ -667,7 +169,11 @@ class YqlDialect(StrCompileDialect): @classmethod def import_dbapi(cls: Any): - return dbapi + return ydb_dbapi + + @classmethod + def dbapi(cls): + return cls.import_dbapi() def __init__( self, @@ -686,13 +192,13 @@ def __init__( def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription: if schema is not None: - raise dbapi.NotSupportedError("unsupported on non empty schema") + raise ydb_dbapi.NotSupportedError("unsupported on non empty schema") qt = table_name if isinstance(table_name, str) else table_name.name raw_conn = connection.connection try: return raw_conn.describe(qt) - except dbapi.DatabaseError as e: + except ydb_dbapi.DatabaseError as e: raise NoSuchTableError(qt) from e def get_view_names(self, connection, schema=None, **kw: Any): @@ -716,9 +222,9 @@ def get_columns(self, connection, table_name, schema=None, **kw): return as_compatible @reflection.cache - def get_table_names(self, connection, schema=None, **kw) -> List[str]: + def get_table_names(self, connection, schema=None, **kw): if schema: - raise dbapi.NotSupportedError("unsupported on non empty schema") + raise ydb_dbapi.NotSupportedError("unsupported on non empty schema") raw_conn = connection.connection return raw_conn.get_table_names() @@ -760,38 +266,38 @@ def get_indexes(self, connection, 
table_name, schema=None, **kwargs): ) return sa_indexes - def set_isolation_level(self, dbapi_connection: dbapi.Connection, level: str) -> None: + def set_isolation_level(self, dbapi_connection: ydb_dbapi.Connection, level: str) -> None: dbapi_connection.set_isolation_level(level) - def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str: - return dbapi.IsolationLevel.AUTOCOMMIT + def get_default_isolation_level(self, dbapi_conn: ydb_dbapi.Connection) -> str: + return ydb_dbapi.IsolationLevel.AUTOCOMMIT - def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str: + def get_isolation_level(self, dbapi_connection: ydb_dbapi.Connection) -> str: return dbapi_connection.get_isolation_level() def set_ydb_request_settings( self, - dbapi_connection: dbapi.Connection, + dbapi_connection: ydb_dbapi.Connection, value: ydb.BaseRequestSettings, ) -> None: dbapi_connection.set_ydb_request_settings(value) - def reset_ydb_request_settings(self, dbapi_connection: dbapi.Connection): + def reset_ydb_request_settings(self, dbapi_connection: ydb_dbapi.Connection): self.set_ydb_request_settings(dbapi_connection, ydb.BaseRequestSettings()) - def get_ydb_request_settings(self, dbapi_connection: dbapi.Connection) -> ydb.BaseRequestSettings: + def get_ydb_request_settings(self, dbapi_connection: ydb_dbapi.Connection) -> ydb.BaseRequestSettings: return dbapi_connection.get_ydb_request_settings() def connect(self, *cargs, **cparams): - return self.loaded_dbapi.connect(*cargs, **cparams) + return self.dbapi.connect(*cargs, **cparams) - def do_begin(self, dbapi_connection: dbapi.Connection) -> None: + def do_begin(self, dbapi_connection: ydb_dbapi.Connection) -> None: dbapi_connection.begin() - def do_rollback(self, dbapi_connection: dbapi.Connection) -> None: + def do_rollback(self, dbapi_connection: ydb_dbapi.Connection) -> None: dbapi_connection.rollback() - def do_commit(self, dbapi_connection: dbapi.Connection) -> None: + def do_commit(self, dbapi_connection: ydb_dbapi.Connection) -> None: dbapi_connection.commit() def _handle_column_name(self, variable): @@ -873,7 +379,7 @@ def _prepare_ydb_query( statement, parameters = self._format_variables(statement, parameters, execute_many) return statement, parameters - def do_ping(self, dbapi_connection: dbapi.Connection) -> bool: + def do_ping(self, dbapi_connection: ydb_dbapi.Connection) -> bool: cursor = dbapi_connection.cursor() statement, _ = self._prepare_ydb_query(self._dialect_specific_select_one) try: @@ -884,7 +390,7 @@ def do_ping(self, dbapi_connection: dbapi.Connection) -> bool: def do_executemany( self, - cursor: dbapi.Cursor, + cursor: ydb_dbapi.Cursor, statement: str, parameters: Optional[Sequence[Mapping[str, Any]]], context: Optional[DefaultExecutionContext] = None, @@ -894,7 +400,7 @@ def do_executemany( def do_execute( self, - cursor: dbapi.Cursor, + cursor: ydb_dbapi.Cursor, statement: str, parameters: Optional[Mapping[str, Any]] = None, context: Optional[DefaultExecutionContext] = None, @@ -913,4 +419,4 @@ class AsyncYqlDialect(YqlDialect): supports_statement_cache = True def connect(self, *cargs, **cparams): - return AdaptedAsyncConnection(util.await_only(self.loaded_dbapi.async_connect(*cargs, **cparams))) + return AdaptedAsyncConnection(util.await_only(self.dbapi.async_connect(*cargs, **cparams))) diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py b/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py new file mode 100644 index 0000000..31affdd --- /dev/null +++ b/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py 
@@ -0,0 +1,16 @@ +import sqlalchemy as sa + +sa_version = sa.__version__ + +if sa_version.startswith("2."): + from .sa20 import YqlCompiler + from .sa20 import YqlDDLCompiler + from .sa20 import YqlTypeCompiler + from .sa20 import YqlIdentifierPreparer +elif sa_version.startswith("1.4."): + from .sa14 import YqlCompiler + from .sa14 import YqlDDLCompiler + from .sa14 import YqlTypeCompiler + from .sa14 import YqlIdentifierPreparer +else: + raise RuntimeError("Unsupported SQLAlchemy version.") diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/base.py b/ydb_sqlalchemy/sqlalchemy/compiler/base.py new file mode 100644 index 0000000..9582a06 --- /dev/null +++ b/ydb_sqlalchemy/sqlalchemy/compiler/base.py @@ -0,0 +1,320 @@ +import sqlalchemy as sa +import ydb + +from sqlalchemy.exc import CompileError +from sqlalchemy.sql import ddl +from sqlalchemy.sql.compiler import ( + DDLCompiler, + IdentifierPreparer, + StrSQLCompiler, + StrSQLTypeCompiler, + selectable, +) +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) + +from .. import types + + +STR_QUOTE_MAP = { + "'": "\\'", + "\\": "\\\\", + "\0": "\\0", + "\b": "\\b", + "\f": "\\f", + "\r": "\\r", + "\n": "\\n", + "\t": "\\t", + "%": "%%", +} + + +COMPOUND_KEYWORDS = { + selectable.CompoundSelect.UNION: "UNION ALL", + selectable.CompoundSelect.UNION_ALL: "UNION ALL", + selectable.CompoundSelect.EXCEPT: "EXCEPT", + selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", + selectable.CompoundSelect.INTERSECT: "INTERSECT", + selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", +} + + +class BaseYqlTypeCompiler(StrSQLTypeCompiler): + def visit_JSON(self, type_: Union[sa.JSON, types.YqlJSON], **kw): + return "JSON" + + def visit_CHAR(self, type_: sa.CHAR, **kw): + return "UTF8" + + def visit_VARCHAR(self, type_: sa.VARCHAR, **kw): + return "UTF8" + + def visit_unicode(self, type_: sa.Unicode, **kw): + return "UTF8" + + def visit_NVARCHAR(self, type_: sa.NVARCHAR, **kw): + return "UTF8" + + def visit_TEXT(self, type_: sa.TEXT, **kw): + return "UTF8" + + def visit_FLOAT(self, type_: sa.FLOAT, **kw): + return "FLOAT" + + def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): + return "BOOL" + + def visit_uint64(self, type_: types.UInt64, **kw): + return "UInt64" + + def visit_uint32(self, type_: types.UInt32, **kw): + return "UInt32" + + def visit_uint16(self, type_: types.UInt16, **kw): + return "UInt16" + + def visit_uint8(self, type_: types.UInt8, **kw): + return "UInt8" + + def visit_int64(self, type_: types.Int64, **kw): + return "Int64" + + def visit_int32(self, type_: types.Int32, **kw): + return "Int32" + + def visit_int16(self, type_: types.Int16, **kw): + return "Int16" + + def visit_int8(self, type_: types.Int8, **kw): + return "Int8" + + def visit_INTEGER(self, type_: sa.INTEGER, **kw): + return "Int64" + + def visit_NUMERIC(self, type_: sa.Numeric, **kw): + """Only Decimal(22,9) is supported for table columns""" + return f"Decimal({type_.precision}, {type_.scale})" + + def visit_BINARY(self, type_: sa.BINARY, **kw): + return "String" + + def visit_BLOB(self, type_: sa.BLOB, **kw): + return "String" + + def visit_datetime(self, type_: sa.TIMESTAMP, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_DATETIME(self, type_: sa.DATETIME, **kw): + return "DateTime" + + def visit_TIMESTAMP(self, type_: sa.TIMESTAMP, **kw): + return "Timestamp" + + def visit_list_type(self, type_: types.ListType, **kw): + inner = self.process(type_.item_type, **kw) + return f"List<{inner}>" + + def visit_ARRAY(self, type_: 
sa.ARRAY, **kw): + inner = self.process(type_.item_type, **kw) + return f"List<{inner}>" + + def visit_struct_type(self, type_: types.StructType, **kw): + text = "Struct<" + for field, field_type in type_.fields_types: + text += f"{field}:{self.process(field_type, **kw)}" + return text + ">" + + def get_ydb_type( + self, type_: sa.types.TypeEngine, is_optional: bool + ) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]: + raise NotImplementedError() + + +class BaseYqlCompiler(StrSQLCompiler): + compound_keywords = COMPOUND_KEYWORDS + _type_compiler_cls = BaseYqlTypeCompiler + + def get_from_hint_text(self, table, text): + return text + + def group_by_clause(self, select, **kw): + # Hack to ensure it is possible to define labels in groupby. + kw.update(within_columns_clause=True) + return super(BaseYqlCompiler, self).group_by_clause(select, **kw) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + limit_clause = self._maybe_cast( + select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) + ) + text += "\n LIMIT " + self.process(limit_clause, **kw) + if select._offset_clause is not None: + offset_clause = self._maybe_cast( + select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) + ) + if select._limit_clause is None: + text += "\n LIMIT 1000" # For some reason, YDB do not support LIMIT NULL OFFSET + text += " OFFSET " + self.process(offset_clause, **kw) + return text + + def _maybe_cast( + self, + element: Any, + cast_to: Type[sa.types.TypeEngine], + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, + ) -> Any: + raise NotImplementedError() + + def render_literal_value(self, value, type_): + if isinstance(value, str): + value = "".join(STR_QUOTE_MAP.get(x, x) for x in value) + return f"'{value}'" + return super().render_literal_value(value, type_) + + def visit_parametrized_function(self, func, **kwargs): + name = func.name + name_parts = [] + for name in name.split("::"): + fname = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + + name_parts.append(fname) + + name = "::".join(name_parts) + params = func.params_expr._compiler_dispatch(self, **kwargs) + args = self.function_argspec(func, **kwargs) + return "%(name)s%(params)s%(args)s" % dict(name=name, params=params, args=args) + + def visit_function(self, func, add_to_result_map=None, **kwargs): + # Copypaste of `sa.sql.compiler.SQLCompiler.visit_function` with + # `::` as namespace separator instead of `.` + if add_to_result_map: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, f"visit_{func.name.lower()}_func", None) + if disp: + return disp(func, **kwargs) + + name = sa.sql.compiler.FUNCTIONS.get(func.__class__) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + name += "%(expr)s" + + return "::".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, sa.sql.elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} + + +class BaseYqlDDLCompiler(DDLCompiler): + def visit_create_index(self, create: ddl.CreateIndex, **kw) -> 
str: + index: sa.Index = create.element + ydb_opts = index.dialect_options.get("ydb", {}) + + self._verify_index_table(index) + + if index.name is None: + raise CompileError("ADD INDEX requires that the index has a name") + + table_name = self.preparer.format_table(index.table) + index_name = self._prepared_index_name(index) + + text = f"ALTER TABLE {table_name} ADD INDEX {index_name} GLOBAL" + + text += " SYNC" if not ydb_opts.get("async", False) else " ASYNC" + + columns = [self.preparer.format_column(col) for col in index.columns.values()] + cover_columns = [ + col if isinstance(col, str) else self.preparer.format_column(col) for col in ydb_opts.get("cover", []) + ] + cover_columns = list(dict.fromkeys(cover_columns)) # dict preserves order + + text += " ON (" + ", ".join(columns) + ")" + + if cover_columns: + text += " COVER (" + ", ".join(cover_columns) + ")" + + return text + + def visit_drop_index(self, drop: ddl.DropIndex, **kw) -> str: + index: sa.Index = drop.element + + self._verify_index_table(index) + + table_name = self.preparer.format_table(index.table) + index_name = self._prepared_index_name(index) + + return f"ALTER TABLE {table_name} DROP INDEX {index_name}" + + def post_create_table(self, table: sa.Table) -> str: + ydb_opts = table.dialect_options["ydb"] + with_clause_list = self._render_table_partitioning_settings(ydb_opts) + if with_clause_list: + with_clause_text = ",\n".join(with_clause_list) + return f"\nWITH (\n\t{with_clause_text}\n)" + return "" + + def _render_table_partitioning_settings(self, ydb_opts: Dict[str, Any]) -> List[str]: + table_partitioning_settings = [] + if ydb_opts["auto_partitioning_by_size"] is not None: + auto_partitioning_by_size = "ENABLED" if ydb_opts["auto_partitioning_by_size"] else "DISABLED" + table_partitioning_settings.append(f"AUTO_PARTITIONING_BY_SIZE = {auto_partitioning_by_size}") + if ydb_opts["auto_partitioning_by_load"] is not None: + auto_partitioning_by_load = "ENABLED" if ydb_opts["auto_partitioning_by_load"] else "DISABLED" + table_partitioning_settings.append(f"AUTO_PARTITIONING_BY_LOAD = {auto_partitioning_by_load}") + if ydb_opts["auto_partitioning_partition_size_mb"] is not None: + table_partitioning_settings.append( + f"AUTO_PARTITIONING_PARTITION_SIZE_MB = {ydb_opts['auto_partitioning_partition_size_mb']}" + ) + if ydb_opts["auto_partitioning_min_partitions_count"] is not None: + table_partitioning_settings.append( + f"AUTO_PARTITIONING_MIN_PARTITIONS_COUNT = {ydb_opts['auto_partitioning_min_partitions_count']}" + ) + if ydb_opts["auto_partitioning_max_partitions_count"] is not None: + table_partitioning_settings.append( + f"AUTO_PARTITIONING_MAX_PARTITIONS_COUNT = {ydb_opts['auto_partitioning_max_partitions_count']}" + ) + if ydb_opts["uniform_partitions"] is not None: + table_partitioning_settings.append(f"UNIFORM_PARTITIONS = {ydb_opts['uniform_partitions']}") + if ydb_opts["partition_at_keys"] is not None: + table_partitioning_settings.append(f"PARTITION_AT_KEYS = {ydb_opts['partition_at_keys']}") + return table_partitioning_settings + + +class BaseYqlIdentifierPreparer(IdentifierPreparer): + def __init__(self, dialect): + super(BaseYqlIdentifierPreparer, self).__init__( + dialect, + initial_quote="`", + final_quote="`", + ) + + def format_index(self, index: sa.Index) -> str: + return super().format_index(index).replace("/", "_") diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py new file mode 100644 index 0000000..56febf3 --- /dev/null +++ 
b/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py @@ -0,0 +1,208 @@ +import collections +import sqlalchemy as sa +import ydb +import ydb_dbapi as dbapi + + +from .base import ( + BaseYqlCompiler, + BaseYqlDDLCompiler, + BaseYqlIdentifierPreparer, + BaseYqlTypeCompiler, +) +from typing import ( + Any, + Dict, + List, + Mapping, + Sequence, + Optional, + Tuple, + Type, + Union, +) +from .. import types + + +class YqlTypeCompiler(BaseYqlTypeCompiler): + def get_ydb_type( + self, type_: sa.types.TypeEngine, is_optional: bool + ) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]: + if isinstance(type_, sa.TypeDecorator): + type_ = type_.impl + + if isinstance(type_, (sa.Text, sa.String)): + ydb_type = ydb.PrimitiveType.Utf8 + + # Integers + elif isinstance(type_, types.UInt64): + ydb_type = ydb.PrimitiveType.Uint64 + elif isinstance(type_, types.UInt32): + ydb_type = ydb.PrimitiveType.Uint32 + elif isinstance(type_, types.UInt16): + ydb_type = ydb.PrimitiveType.Uint16 + elif isinstance(type_, types.UInt8): + ydb_type = ydb.PrimitiveType.Uint8 + elif isinstance(type_, types.Int64): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, types.Int32): + ydb_type = ydb.PrimitiveType.Int32 + elif isinstance(type_, types.Int16): + ydb_type = ydb.PrimitiveType.Int16 + elif isinstance(type_, types.Int8): + ydb_type = ydb.PrimitiveType.Int8 + elif isinstance(type_, sa.Integer): + ydb_type = ydb.PrimitiveType.Int64 + # Integers + + # Json + elif isinstance(type_, sa.JSON): + ydb_type = ydb.PrimitiveType.Json + elif isinstance(type_, sa.JSON.JSONStrIndexType): + ydb_type = ydb.PrimitiveType.Utf8 + elif isinstance(type_, sa.JSON.JSONIntIndexType): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, sa.JSON.JSONPathType): + ydb_type = ydb.PrimitiveType.Utf8 + elif isinstance(type_, types.YqlJSON): + ydb_type = ydb.PrimitiveType.Json + elif isinstance(type_, types.YqlJSON.YqlJSONPathType): + ydb_type = ydb.PrimitiveType.Utf8 + # Json + elif isinstance(type_, sa.DATETIME): + ydb_type = ydb.PrimitiveType.Datetime + elif isinstance(type_, sa.TIMESTAMP): + ydb_type = ydb.PrimitiveType.Timestamp + elif isinstance(type_, sa.DateTime): + ydb_type = ydb.PrimitiveType.Timestamp + elif isinstance(type_, sa.Date): + ydb_type = ydb.PrimitiveType.Date + elif isinstance(type_, sa.BINARY): + ydb_type = ydb.PrimitiveType.String + elif isinstance(type_, sa.Float): + ydb_type = ydb.PrimitiveType.Float + elif isinstance(type_, sa.Double): + ydb_type = ydb.PrimitiveType.Double + elif isinstance(type_, sa.Boolean): + ydb_type = ydb.PrimitiveType.Bool + elif isinstance(type_, sa.Numeric): + ydb_type = ydb.DecimalType(type_.precision, type_.scale) + elif isinstance(type_, (types.ListType, sa.ARRAY)): + ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False)) + elif isinstance(type_, types.StructType): + ydb_type = ydb.StructType() + for field, field_type in type_.fields_types.items(): + ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False)) + else: + raise dbapi.NotSupportedError(f"{type_} bind variables not supported") + + if is_optional: + return ydb.OptionalType(ydb_type) + + return ydb_type + + def _maybe_cast( + self, + element: Any, + cast_to: Type[sa.types.TypeEngine], + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, + ) -> Any: + if not skip_types: + skip_types = (cast_to,) + if cast_to not in skip_types: + skip_types = (*skip_types, cast_to) + if not hasattr(element, "type") or not isinstance(element.type, skip_types): + return 
sa.cast(element, cast_to) + return element + + +class YqlIdentifierPreparer(BaseYqlIdentifierPreparer): + ... + + +class YqlCompiler(BaseYqlCompiler): + _type_compiler_cls = YqlTypeCompiler + + def visit_upsert(self, insert_stmt, **kw): + return self.visit_insert(insert_stmt, **kw).replace("INSERT", "UPSERT", 1) + + def _is_bound_to_nullable_column(self, bind_name: str) -> bool: + if bind_name in self.column_keys and hasattr(self.compile_state, "dml_table"): + if bind_name in self.compile_state.dml_table.c: + column = self.compile_state.dml_table.c[bind_name] + return column.nullable and not column.primary_key + return False + + def _guess_bound_variable_type_by_parameters( + self, bind, post_compile_bind_values: list + ) -> Optional[sa.types.TypeEngine]: + bind_type = bind.type + if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): + not_null_values = [v for v in post_compile_bind_values if v is not None] + if not_null_values: + bind_type = sa.bindparam("", not_null_values[0]).type + + if isinstance(bind_type, sa.types.NullType): + return None + + return bind_type + + def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[str, List[Any]]) -> List[Any]: + expanding_bind_names = [] + for parameter_name in parameters_values: + parameter_bind_name = "_".join(parameter_name.split("_")[:-1]) + if parameter_bind_name == bind_name: + expanding_bind_names.append(parameter_name) + return expanding_bind_names + + def render_bind_cast(self, type_, dbapi_type, sqltext): + pass + + def get_bind_types( + self, post_compile_parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] + ) -> Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]]: + """ + This method extracts information about bound variables from the table definition and parameters. + """ + if isinstance(post_compile_parameters, collections.abc.Mapping): + post_compile_parameters = [post_compile_parameters] + + parameters_values = collections.defaultdict(list) + for parameters_entry in post_compile_parameters: + for parameter_name, parameter_value in parameters_entry.items(): + parameters_values[parameter_name].append(parameter_value) + + parameter_types = {} + for bind_name in self.bind_names.values(): + bind = self.binds[bind_name] + + if bind.literal_execute: + continue + + if not bind.expanding: + post_compile_bind_names = [bind_name] + post_compile_bind_values = parameters_values[bind_name] + else: + post_compile_bind_names = self._get_expanding_bind_names(bind_name, parameters_values) + post_compile_bind_values = [] + for parameter_name, parameter_values in parameters_values.items(): + if parameter_name in post_compile_bind_names: + post_compile_bind_values.extend(parameter_values) + + is_optional = self._is_bound_to_nullable_column(bind_name) + if not post_compile_bind_values or None in post_compile_bind_values: + is_optional = True + + bind_type = self._guess_bound_variable_type_by_parameters(bind, post_compile_bind_values) + + if bind_type: + for post_compile_bind_name in post_compile_bind_names: + parameter_types[post_compile_bind_name] = self._type_compiler_cls(self.dialect).get_ydb_type( + bind_type, is_optional + ) + + return parameter_types + + +class YqlDDLCompiler(BaseYqlDDLCompiler): + ... 
diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py new file mode 100644 index 0000000..45c674d --- /dev/null +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py @@ -0,0 +1,252 @@ +import collections +import sqlalchemy as sa +import ydb +import ydb_dbapi as dbapi + +from sqlalchemy.exc import CompileError +from sqlalchemy.sql import literal_column +from sqlalchemy.util.compat import inspect_getfullargspec + +from .base import ( + BaseYqlCompiler, + BaseYqlDDLCompiler, + BaseYqlIdentifierPreparer, + BaseYqlTypeCompiler, +) +from typing import ( + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) +from .. import types + + +class YqlTypeCompiler(BaseYqlTypeCompiler): + def visit_uuid(self, type_: sa.Uuid, **kw): + return "UTF8" + + def get_ydb_type( + self, type_: sa.types.TypeEngine, is_optional: bool + ) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]: + if isinstance(type_, sa.TypeDecorator): + type_ = type_.impl + + if isinstance(type_, (sa.Text, sa.String, sa.Uuid)): + ydb_type = ydb.PrimitiveType.Utf8 + + # Integers + elif isinstance(type_, types.UInt64): + ydb_type = ydb.PrimitiveType.Uint64 + elif isinstance(type_, types.UInt32): + ydb_type = ydb.PrimitiveType.Uint32 + elif isinstance(type_, types.UInt16): + ydb_type = ydb.PrimitiveType.Uint16 + elif isinstance(type_, types.UInt8): + ydb_type = ydb.PrimitiveType.Uint8 + elif isinstance(type_, types.Int64): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, types.Int32): + ydb_type = ydb.PrimitiveType.Int32 + elif isinstance(type_, types.Int16): + ydb_type = ydb.PrimitiveType.Int16 + elif isinstance(type_, types.Int8): + ydb_type = ydb.PrimitiveType.Int8 + elif isinstance(type_, sa.Integer): + ydb_type = ydb.PrimitiveType.Int64 + # Integers + + # Json + elif isinstance(type_, sa.JSON): + ydb_type = ydb.PrimitiveType.Json + elif isinstance(type_, sa.JSON.JSONStrIndexType): + ydb_type = ydb.PrimitiveType.Utf8 + elif isinstance(type_, sa.JSON.JSONIntIndexType): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, sa.JSON.JSONPathType): + ydb_type = ydb.PrimitiveType.Utf8 + elif isinstance(type_, types.YqlJSON): + ydb_type = ydb.PrimitiveType.Json + elif isinstance(type_, types.YqlJSON.YqlJSONPathType): + ydb_type = ydb.PrimitiveType.Utf8 + # Json + elif isinstance(type_, sa.DATETIME): + ydb_type = ydb.PrimitiveType.Datetime + elif isinstance(type_, sa.TIMESTAMP): + ydb_type = ydb.PrimitiveType.Timestamp + elif isinstance(type_, sa.DateTime): + ydb_type = ydb.PrimitiveType.Timestamp + elif isinstance(type_, sa.Date): + ydb_type = ydb.PrimitiveType.Date + elif isinstance(type_, sa.BINARY): + ydb_type = ydb.PrimitiveType.String + elif isinstance(type_, sa.Float): + ydb_type = ydb.PrimitiveType.Float + elif isinstance(type_, sa.Double): + ydb_type = ydb.PrimitiveType.Double + elif isinstance(type_, sa.Boolean): + ydb_type = ydb.PrimitiveType.Bool + elif isinstance(type_, sa.Numeric): + ydb_type = ydb.DecimalType(type_.precision, type_.scale) + elif isinstance(type_, (types.ListType, sa.ARRAY)): + ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False)) + elif isinstance(type_, types.StructType): + ydb_type = ydb.StructType() + for field, field_type in type_.fields_types.items(): + ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False)) + else: + raise dbapi.NotSupportedError(f"{type_} bind variables not supported") + + if is_optional: + return ydb.OptionalType(ydb_type) + + return ydb_type + + 
+class YqlIdentifierPreparer(BaseYqlIdentifierPreparer): + ... + + +class YqlCompiler(BaseYqlCompiler): + _type_compiler_cls = YqlTypeCompiler + + def visit_json_getitem_op_binary(self, binary: sa.BinaryExpression, operator, **kw) -> str: + json_field = self.process(binary.left, **kw) + index = self.process(binary.right, **kw) + return self._yson_convert_to(f"{json_field}[{index}]", binary.type) + + def visit_json_path_getitem_op_binary(self, binary: sa.BinaryExpression, operator, **kw) -> str: + json_field = self.process(binary.left, **kw) + path = self.process(binary.right, **kw) + return self._yson_convert_to(f"Yson::YPath({json_field}, {path})", binary.type) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " REGEXP ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) + + def visit_lambda(self, lambda_, **kw): + func = lambda_.func + spec = inspect_getfullargspec(func) + + if spec.varargs: + raise CompileError("Lambdas with *args are not supported") + if spec.varkw: + raise CompileError("Lambdas with **kwargs are not supported") + + args = [literal_column("$" + arg) for arg in spec.args] + text = f'({", ".join("$" + arg for arg in spec.args)}) -> ' f"{{ RETURN {self.process(func(*args), **kw)} ;}}" + + return text + + def _yson_convert_to(self, statement: str, target_type: sa.types.TypeEngine) -> str: + type_name = target_type.compile(self.dialect) + if isinstance(target_type, sa.Numeric) and not isinstance(target_type, (sa.Float, sa.Double)): + # Since Decimal is stored in JSON either as String or as Float + string_value = f"Yson::ConvertTo({statement}, Optional, Yson::Options(true AS AutoConvert))" + return f"CAST({string_value} AS Optional<{type_name}>)" + return f"Yson::ConvertTo({statement}, Optional<{type_name}>)" + + def _is_bound_to_nullable_column(self, bind_name: str) -> bool: + if bind_name in self.column_keys and hasattr(self.compile_state, "dml_table"): + if bind_name in self.compile_state.dml_table.c: + column = self.compile_state.dml_table.c[bind_name] + return column.nullable and not column.primary_key + return False + + def _guess_bound_variable_type_by_parameters( + self, bind: sa.BindParameter, post_compile_bind_values: list + ) -> Optional[sa.types.TypeEngine]: + bind_type = bind.type + if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): + not_null_values = [v for v in post_compile_bind_values if v is not None] + if not_null_values: + bind_type = sa.BindParameter("", not_null_values[0]).type + + if isinstance(bind_type, sa.types.NullType): + return None + + return bind_type + + def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[str, List[Any]]) -> List[Any]: + expanding_bind_names = [] + for parameter_name in parameters_values: + parameter_bind_name = "_".join(parameter_name.split("_")[:-1]) + if parameter_bind_name == bind_name: + expanding_bind_names.append(parameter_name) + return expanding_bind_names + + def render_bind_cast(self, type_, dbapi_type, sqltext): + pass + + def get_bind_types( + self, post_compile_parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] + ) -> Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]]: + """ + This method extracts information about bound variables from the table definition and parameters. 
+ """ + if isinstance(post_compile_parameters, collections.abc.Mapping): + post_compile_parameters = [post_compile_parameters] + + parameters_values = collections.defaultdict(list) + for parameters_entry in post_compile_parameters: + for parameter_name, parameter_value in parameters_entry.items(): + parameters_values[parameter_name].append(parameter_value) + + parameter_types = {} + for bind_name in self.bind_names.values(): + bind = self.binds[bind_name] + + if bind.literal_execute: + continue + + if not bind.expanding: + post_compile_bind_names = [bind_name] + post_compile_bind_values = parameters_values[bind_name] + else: + post_compile_bind_names = self._get_expanding_bind_names(bind_name, parameters_values) + post_compile_bind_values = [] + for parameter_name, parameter_values in parameters_values.items(): + if parameter_name in post_compile_bind_names: + post_compile_bind_values.extend(parameter_values) + + is_optional = self._is_bound_to_nullable_column(bind_name) + if not post_compile_bind_values or None in post_compile_bind_values: + is_optional = True + + bind_type = self._guess_bound_variable_type_by_parameters(bind, post_compile_bind_values) + + if bind_type: + for post_compile_bind_name in post_compile_bind_names: + parameter_types[post_compile_bind_name] = self._type_compiler_cls(self.dialect).get_ydb_type( + bind_type, is_optional + ) + + return parameter_types + + def _maybe_cast( + self, + element: Any, + cast_to: Type[sa.types.TypeEngine], + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, + ) -> Any: + if not skip_types: + skip_types = (cast_to,) + if cast_to not in skip_types: + skip_types = (*skip_types, cast_to) + if not hasattr(element, "type") or not isinstance(element.type, skip_types): + return sa.Cast(element, cast_to) + return element + + def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw): + return self.visit_insert(insert_stmt, visited_bindparam, **kw).replace("INSERT", "UPSERT", 1) + + +class YqlDDLCompiler(BaseYqlDDLCompiler): + ... 
diff --git a/ydb_sqlalchemy/sqlalchemy/datetime_types.py b/ydb_sqlalchemy/sqlalchemy/datetime_types.py index d2f8283..6cd10cb 100644 --- a/ydb_sqlalchemy/sqlalchemy/datetime_types.py +++ b/ydb_sqlalchemy/sqlalchemy/datetime_types.py @@ -1,13 +1,11 @@ import datetime from typing import Optional -from sqlalchemy import Dialect from sqlalchemy import types as sqltypes -from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType class YqlTimestamp(sqltypes.TIMESTAMP): - def result_processor(self, dialect: Dialect, coltype: str) -> Optional[_ResultProcessorType[datetime.datetime]]: + def result_processor(self, dialect, coltype): def process(value: Optional[datetime.datetime]) -> Optional[datetime.datetime]: if value is None: return None @@ -19,7 +17,7 @@ def process(value: Optional[datetime.datetime]) -> Optional[datetime.datetime]: class YqlDateTime(YqlTimestamp, sqltypes.DATETIME): - def bind_processor(self, dialect: Dialect) -> Optional[_BindProcessorType[datetime.datetime]]: + def bind_processor(self, dialect): def process(value: Optional[datetime.datetime]) -> Optional[int]: if value is None: return None diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 557ce3d..34e26b6 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -1,6 +1,13 @@ from typing import Any, Mapping, Type, Union -from sqlalchemy import ARRAY, ColumnElement, exc, types +from sqlalchemy import __version__ as sa_version + +if sa_version.startswith("2."): + from sqlalchemy import ColumnElement +else: + from sqlalchemy.sql.expression import ColumnElement + +from sqlalchemy import ARRAY, exc, types from sqlalchemy.sql import type_api from .datetime_types import YqlDateTime, YqlTimestamp # noqa: F401
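
Reviewer note (not part of the patch): the core of this change is the version gate in `ydb_sqlalchemy/sqlalchemy/compiler/__init__.py`, which re-exports `YqlCompiler`, `YqlDDLCompiler`, `YqlTypeCompiler`, and `YqlIdentifierPreparer` from either `sa20` or `sa14` depending on the installed SQLAlchemy. A minimal sketch for checking the gate locally, assuming only the package layout introduced above (run once against a 1.4.x environment and once against a 2.x environment):

```python
# Minimal sketch: confirm which compiler variant the version gate selected.
# Assumes the compiler package layout added in this patch; nothing here is
# part of the patch itself.
import sqlalchemy as sa

from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler

print(sa.__version__)          # e.g. "1.4.54" or "2.0.36"
print(YqlCompiler.__module__)  # "...compiler.sa14" on 1.4.x, "...compiler.sa20" on 2.x
```

On an unsupported SQLAlchemy version the import raises `RuntimeError("Unsupported SQLAlchemy version.")`, so a mismatched environment fails at import time rather than at query compilation.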