Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to support sqlalchemy 1.4+ #60

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
sqlalchemy >= 2.0.7, < 3.0.0
sqlalchemy >= 1.4.0, < 3.0.0
ydb >= 3.18.8
ydb-dbapi >= 0.1.1
ydb-dbapi >= 0.1.2
56 changes: 30 additions & 26 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
from ydb_sqlalchemy import sqlalchemy as ydb_sa
from ydb_sqlalchemy.sqlalchemy import types

if sa.__version__ >= "2.":
from sqlalchemy import NullPool
from sqlalchemy import QueuePool
else:
from sqlalchemy.pool import NullPool
from sqlalchemy.pool import QueuePool


def clear_sql(stm):
return stm.replace("\n", " ").replace(" ", " ").strip()
Expand Down Expand Up @@ -94,7 +101,7 @@ def test_sa_crud(self, connection):
(5, "c"),
]

def test_cached_query(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_cached_query(self, connection_no_trans, connection):
table = self.tables.test

with connection_no_trans.begin() as transaction:
Expand Down Expand Up @@ -249,7 +256,7 @@ def test_primitive_types(self, connection):
assert row == (42, "Hello World!", 3.5, True)

def test_integer_types(self, connection):
stmt = sa.Select(
stmt = sa.select(
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))),
Expand All @@ -263,8 +270,8 @@ def test_integer_types(self, connection):
result = connection.execute(stmt).fetchone()
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")

def test_datetime_types(self, connection: sa.Connection):
stmt = sa.Select(
def test_datetime_types(self, connection):
stmt = sa.select(
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_TIMESTAMP", datetime.datetime.now(), sa.TIMESTAMP))),
Expand All @@ -273,7 +280,7 @@ def test_datetime_types(self, connection: sa.Connection):
result = connection.execute(stmt).fetchone()
assert result == (b"Timestamp", b"Datetime", b"Timestamp")

def test_datetime_types_timezone(self, connection: sa.Connection):
def test_datetime_types_timezone(self, connection):
table = self.tables.test_datetime_types
tzinfo = datetime.timezone(datetime.timedelta(hours=3, minutes=42))

Expand Down Expand Up @@ -476,7 +483,8 @@ def define_tables(cls, metadata: sa.MetaData):
Column("id", Integer, primary_key=True),
)

def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection):
@pytest.mark.skipif(sa.__version__ < "2.", reason="Something was different in SA<2, good to fix")
def test_rollback(self, connection_no_trans, connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
Expand All @@ -491,7 +499,7 @@ def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Conne
result = cursor.fetchall()
assert result == []

def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_commit(self, connection_no_trans, connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
Expand All @@ -506,9 +514,7 @@ def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connect
assert set(result) == {(3,), (4,)}

@pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY))
def test_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
def test_interactive_transaction(self, connection_no_trans, connection, isolation_level):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

Expand All @@ -535,9 +541,7 @@ def test_interactive_transaction(
IsolationLevel.AUTOCOMMIT,
),
)
def test_not_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
def test_not_interactive_transaction(self, connection_no_trans, connection, isolation_level):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

Expand Down Expand Up @@ -573,7 +577,7 @@ class IsolationSettings(NamedTuple):
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True),
}

def test_connection_set(self, connection_no_trans: sa.Connection):
def test_connection_set(self, connection_no_trans):
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items():
Expand Down Expand Up @@ -614,8 +618,8 @@ def ydb_pool(self, ydb_driver):
session_pool.stop()

def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine1 = sa.create_engine(config.db_url, poolclass=QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=QueuePool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
Expand All @@ -629,8 +633,8 @@ def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
assert not ydb_driver._stopped

def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine1 = sa.create_engine(config.db_url, poolclass=NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=NullPool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
Expand Down Expand Up @@ -861,7 +865,7 @@ def test_insert_in_name_and_field(self, connection):
class TestSecondaryIndex(TestBase):
__backend__ = True

def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
def test_column_indexes(self, connection, metadata: sa.MetaData):
table = Table(
"test_column_indexes/table",
metadata,
Expand All @@ -884,7 +888,7 @@ def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
index1 = indexes_map["ix_test_column_indexes_table_index_col2"]
assert index1.index_columns == ["index_col2"]

def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
def test_async_index(self, connection, metadata: sa.MetaData):
table = Table(
"test_async_index/table",
metadata,
Expand All @@ -903,7 +907,7 @@ def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
assert set(index.index_columns) == {"index_col1", "index_col2"}
# TODO: Check type after https://github.com/ydb-platform/ydb-python-sdk/issues/351

def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
def test_cover_index(self, connection, metadata: sa.MetaData):
table = Table(
"test_cover_index/table",
metadata,
Expand All @@ -922,7 +926,7 @@ def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
assert set(index.index_columns) == {"index_col1"}
# TODO: Check covered columns after https://github.com/ydb-platform/ydb-python-sdk/issues/409

def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaData):
def test_indexes_reflection(self, connection, metadata: sa.MetaData):
table = Table(
"test_indexes_reflection/table",
metadata,
Expand All @@ -948,7 +952,7 @@ def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaDa
"test_async_cover_index": {"index_col1"},
}

def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_simple_usage(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_simple_usage/persons",
metadata,
Expand Down Expand Up @@ -979,7 +983,7 @@ def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaDa
cursor = connection.execute(select_stmt)
assert cursor.scalar_one() == "Sarah Connor"

def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_with_join_usage(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_with_join_usage/persons",
metadata,
Expand Down Expand Up @@ -1033,7 +1037,7 @@ def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.Met
cursor = connection.execute(select_stmt)
assert cursor.one() == ("Sarah Connor", "wanted")

def test_index_deletion(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_deletion(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_deletion/persons",
metadata,
Expand Down Expand Up @@ -1062,7 +1066,7 @@ def define_tables(cls, metadata: sa.MetaData):
Table("table", metadata, sa.Column("id", sa.Integer, primary_key=True))

@classmethod
def insert_data(cls, connection: sa.Connection):
def insert_data(cls, connection):
table = cls.tables["some_dir/nested_dir/table"]
root_table = cls.tables["table"]

Expand Down
60 changes: 35 additions & 25 deletions test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest
from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest
from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest

from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest
from sqlalchemy.testing.suite.test_types import StringTest as _StringTest
from sqlalchemy.testing.suite.test_types import (
Expand All @@ -78,14 +78,16 @@
TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
)
from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest

from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types

test_types_suite = sqlalchemy.testing.suite.test_types
col_creator = test_types_suite.Column


OLD_SA = sa.__version__ < "2."


def column_getter(*args, **kwargs):
col = col_creator(*args, **kwargs)
if col.name == "x":
Expand Down Expand Up @@ -275,30 +277,35 @@ class BinaryTest(_BinaryTest):
pass


class TrueDivTest(_TrueDivTest):
@pytest.mark.skip("Unsupported builtin: FLOOR")
def test_floordiv_numeric(self, connection, left, right, expected):
pass
if not OLD_SA:
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest

@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer(self, connection, left, right, expected):
pass
class TrueDivTest(_TrueDivTest):
@pytest.mark.skip("Unsupported builtin: FLOOR")
def test_floordiv_numeric(self, connection, left, right, expected):
pass

@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer_bound(self, connection):
pass
@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer(self, connection, left, right, expected):
pass

@pytest.mark.skip("Numeric is not Decimal")
def test_truediv_numeric(self):
# SqlAlchemy maybe eat Decimal and throw Double
pass
@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer_bound(self, connection):
pass

@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
def test_truediv_float(self, connection, left, right, expected):
eq_(
connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))),
expected,
)
@pytest.mark.skip("Numeric is not Decimal")
def test_truediv_numeric(self):
# SqlAlchemy maybe eat Decimal and throw Double
pass

@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
def test_truediv_float(self, connection, left, right, expected):
eq_(
connection.scalar(
select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))
),
expected,
)


class ExistsTest(_ExistsTest):
Expand Down Expand Up @@ -539,9 +546,12 @@ def test_from_as_table(self, connection):
eq_(connection.execute(sa.select(table)).fetchall(), [(1,), (2,), (3,)])


@pytest.mark.skip("uuid unsupported for columns")
class NativeUUIDTest(_NativeUUIDTest):
pass
if not OLD_SA:
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest

@pytest.mark.skip("uuid unsupported for columns")
class NativeUUIDTest(_NativeUUIDTest):
pass


@pytest.mark.skip("unsupported Time data type")
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ max-line-length = 120
ignore=E203,W503
per-file-ignores =
ydb_sqlalchemy/__init__.py: F401
ydb_sqlalchemy/sqlalchemy/compiler/__init__.py: F401
exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,docs/*
Loading
Loading