diff --git a/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py b/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py
index 60a999ade..64152ba6f 100644
--- a/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py
+++ b/lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+import functools
 from http import HTTPStatus
 from typing import Iterable
 
@@ -70,7 +71,7 @@ def test_lod_fixed_single_dim_in_two_dim_query(self, control_api, data_api, save
             assert float(row[3]) == pytest.approx(sum_by_city[row[0]])
             assert float(row[4]) == pytest.approx(sum_by_category[row[1]])
 
-    def test_null_dimensions(self, control_api, data_api, db, saved_connection_id):
+    def test_null_dimensions(self, request, control_api, data_api, db, saved_connection_id):
         connection_id = saved_connection_id
 
         raw_data = [
@@ -88,6 +89,7 @@ def test_null_dimensions(self, control_api, data_api, db, saved_connection_id):
             C("sales", UserDataType.integer, vg=lambda rn, **kwargs: raw_data[rn]["sales"]),
         ]
         db_table = make_table(db, columns=columns, rows=len(raw_data))
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
 
         ds = Dataset()
         ds.sources["source_1"] = ds.source(connection_id=connection_id, **data_source_settings_from_table(db_table))
@@ -165,12 +167,14 @@ def get_total_value() -> float:
 
     def test_total_lod_2(
         self,
+        request,
         control_api,
         data_api,
         saved_connection_id,
         db,
     ):
         db_table = make_table(db=db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
         ds = create_basic_dataset(
             api_v1=control_api,
             connection_id=saved_connection_id,
@@ -213,12 +217,14 @@ def check_equality_of_totals(*field_names: str) -> None:
 
     def test_lod_in_order_by(
         self,
+        request,
         control_api,
         data_api,
         saved_connection_id,
         db,
     ):
         db_table = make_table(db=db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
         data_api = data_api
 
         ds = create_basic_dataset(
@@ -258,8 +264,9 @@ def get_data(order_by: list[ResultField]) -> list:
 
 
 class DefaultBasicLookupFunctionTestSuite(DataApiTestBase, DatasetTestBase, DbServiceFixtureTextClass):
-    def test_ago_any_db(self, saved_connection_id, control_api, data_api, db):
+    def test_ago_any_db(self, request, saved_connection_id, control_api, data_api, db):
         db_table = make_table(db=db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
         ds = create_basic_dataset(
             api_v1=control_api,
             connection_id=saved_connection_id,
@@ -287,8 +294,9 @@ def test_ago_any_db(self, saved_connection_id, control_api, data_api, db):
 
         check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=2, day_offset=2)
 
-    def test_triple_ago_any_db(self, saved_connection_id, control_api, data_api, db):
+    def test_triple_ago_any_db(self, request, saved_connection_id, control_api, data_api, db):
         db_table = make_table(db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
         ds = create_basic_dataset(
             api_v1=control_api,
             connection_id=saved_connection_id,
@@ -322,10 +330,17 @@ def test_triple_ago_any_db(self, saved_connection_id, control_api, data_api, db)
         check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=3, day_offset=2)
         check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=4, day_offset=3)
 
-    def test_ago_any_db_multisource(self, saved_connection_id, control_api, data_api, db):
+    def test_ago_any_db_multisource(self, request, saved_connection_id, control_api, data_api, db):
         connection_id = saved_connection_id
         table_1 = make_table(db)
         table_2 = make_table(db)
+
+        def teardown(db, *tables):
+            for table in tables:
+                db.drop_table(table)
+
+        request.addfinalizer(functools.partial(teardown, db, table_1.table, table_2.table))
+
         ds = Dataset()
         ds.sources["source_1"] = ds.source(
             connection_id=connection_id,
@@ -369,8 +384,9 @@ def test_ago_any_db_multisource(self, saved_connection_id, control_api, data_api
 
         check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=2, day_offset=2)
 
-    def test_nested_ago(self, saved_connection_id, control_api, data_api, db):
+    def test_nested_ago(self, request, saved_connection_id, control_api, data_api, db):
         db_table = make_table(db)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
         ds = create_basic_dataset(
             api_v1=control_api,
             connection_id=saved_connection_id,
@@ -404,8 +420,9 @@ def test_nested_ago(self, saved_connection_id, control_api, data_api, db):
         check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=3, day_offset=2)
         check_ago_data(data_rows=data_rows, date_idx=0, value_idx=1, ago_idx=4, day_offset=3)
 
-    def test_month_ago_for_shorter_month(self, db, saved_connection_id, control_api, data_api):
+    def test_month_ago_for_shorter_month(self, request, db, saved_connection_id, control_api, data_api):
         any_db_table_200 = make_table(db, rows=200)
+        request.addfinalizer(functools.partial(db.drop_table, any_db_table_200.table))
 
         # FIXME
         # if any_db.conn_type == CONNECTION_TYPE_ORACLE:
diff --git a/lib/dl_api_lib_testing/dl_api_lib_testing/connector/data_api_suites.py b/lib/dl_api_lib_testing/dl_api_lib_testing/connector/data_api_suites.py
index edd900b9b..702dd445a 100644
--- a/lib/dl_api_lib_testing/dl_api_lib_testing/connector/data_api_suites.py
+++ b/lib/dl_api_lib_testing/dl_api_lib_testing/connector/data_api_suites.py
@@ -1,4 +1,5 @@
 import datetime
+import functools
 import json
 from typing import Any
 
@@ -60,6 +61,7 @@ def test_basic_result(
 
     def _test_contains(
         self,
+        request,
         db: Db,
         saved_connection_id: str,
         dataset_params: dict,
@@ -82,6 +84,8 @@ def _test_contains(
             ),
         ]
         db_table = make_table(db, columns=columns)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
+
         params = self.get_dataset_params(dataset_params, db_table)
         ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)
 
@@ -156,6 +160,7 @@ def test_array_not_contains_filter(
     @for_features(array_support)
     def test_array_contains_field(
         self,
+        request,
         db: Db,
         saved_connection_id: str,
         dataset_params: dict,
@@ -179,6 +184,7 @@ def test_array_contains_field(
             ),
         ]
         db_table = make_table(db, columns=columns)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
         params = self.get_dataset_params(dataset_params, db_table)
         ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)
 
@@ -384,6 +390,7 @@ def test_distinct_with_nonexistent_filter(
 
     def test_date_filter_distinct(
         self,
+        request,
         db: Db,
         saved_connection_id: str,
         dataset_params: dict,
@@ -398,6 +405,8 @@ def test_date_filter_distinct(
             {"date_val": datetime.date(2023, 4, 2)},
         ]
         db_table = make_table(db, columns=columns, data=data)
+        request.addfinalizer(functools.partial(db.drop_table, db_table.table))
+
         params = self.get_dataset_params(dataset_params, db_table)
         ds = self.make_basic_dataset(control_api, connection_id=saved_connection_id, dataset_params=params)
diff --git a/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/base.py b/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/base.py
index af30596e9..db5506f3c 100644
--- a/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/base.py
+++ b/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/base.py
@@ -143,6 +143,15 @@ async def s3_bucket(self, s3_client: AsyncS3Client) -> str:
     def sample_table_spec(self) -> FixtureTableSpec:
         return attr.evolve(TABLE_SPEC_SAMPLE_SUPERSTORE, nullable=True)
 
+    def _get_s3_func_schema_for_table(self, table: DbTable) -> str:
+        field_id_gen = get_field_id_generator(self.conn_type)
+        tbl_schema = ", ".join(
+            "{} {}".format(field_id_gen.make_field_id(dict(title=col.name, index=idx)), col.type.compile())
+            for idx, col in enumerate(table.table.columns)
+        )
+        tbl_schema = tbl_schema.replace("()", "")  # String() -> String: type arguments are not needed here
+        return tbl_schema
+
     @pytest.fixture(scope="function")
     async def sample_s3_file(
         self,
@@ -153,13 +162,7 @@ async def sample_s3_file(
     ) -> AsyncGenerator[str, None]:
         filename = f"my_file_{uuid.uuid4()}.native"
 
-        field_id_gen = get_field_id_generator(self.conn_type)
-        tbl_schema = ", ".join(
-            "{} {}".format(field_id_gen.make_field_id(dict(title=col.name, index=idx)), col.type.compile())
-            for idx, col in enumerate(sample_table.table.columns)
-        )
-        tbl_schema = tbl_schema.replace("()", "")  # String() -> String: type arguments are not needed here
-
+        tbl_schema = self._get_s3_func_schema_for_table(sample_table)
         create_s3_native_from_ch_table(filename, s3_bucket, s3_settings, sample_table, tbl_schema)
 
         yield filename
diff --git a/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/connection_executor.py b/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/connection_executor.py
index 55da78e3f..8bc6f8511 100644
--- a/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/connection_executor.py
+++ b/lib/dl_connector_bundle_chs3/dl_connector_bundle_chs3_tests/db/base/core/connection_executor.py
@@ -1,14 +1,23 @@
 import abc
 
 import pytest
+import sqlalchemy as sa
 
+from dl_configs.settings_submodels import S3Settings
+from dl_constants.enums import UserDataType
+from dl_core.connection_executors import (
+    AsyncConnExecutorBase,
+    ConnExecutorQuery,
+)
 from dl_core.connection_models import DBIdent
+from dl_core_testing.database import DbTable
 from dl_core_testing.testcases.connection_executor import (
     DefaultAsyncConnectionExecutorTestSuite,
     DefaultSyncAsyncConnectionExecutorCheckBase,
     DefaultSyncConnectionExecutorTestSuite,
 )
 from dl_testing.regulated_test import RegulatedTestParams
+from dl_testing.s3_utils import s3_tbl_func_maker
 
 from dl_connector_bundle_chs3_tests.db.base.core.base import (
     FILE_CONN_TV,
@@ -43,6 +52,7 @@ class CHS3SyncConnectionExecutorTestBase(
     test_params = RegulatedTestParams(
         mark_tests_skipped={
             DefaultSyncConnectionExecutorTestSuite.test_type_recognition: "Not implemented",
+            DefaultSyncConnectionExecutorTestSuite.test_get_table_names: "Not implemented",
         },
     )
 
@@ -52,4 +62,67 @@ class CHS3AsyncConnectionExecutorTestBase(
     DefaultAsyncConnectionExecutorTestSuite[FILE_CONN_TV],
     metaclass=abc.ABCMeta,
 ):
-    pass
+    async def test_select_data(
+        self,
+        sample_table: DbTable,
+        saved_connection: FILE_CONN_TV,
+        async_connection_executor: AsyncConnExecutorBase,
+        s3_settings: S3Settings,
+        sample_s3_file: str,
+    ) -> None:
+        schema_line = self._get_s3_func_schema_for_table(sample_table)
+        s3_tbl_func = s3_tbl_func_maker(s3_settings)(
+            for_="dba",
+            conn_dto=saved_connection.get_conn_dto(),
+            filename=sample_s3_file,
+            file_fmt="Native",
+            schema_line=schema_line,
+        )
+        file_columns = [sa.column(col_desc.split(" ")[0]) for col_desc in schema_line.split(", ")]
+        # ^ "col1 String, col2 Int32, col3 Date32" -> [col1, col2, col3]
+
+        n_rows = 3
+        result = await async_connection_executor.execute(
+            ConnExecutorQuery(
+                query=sa.select(columns=file_columns)
+                .select_from(sa.text(s3_tbl_func))
+                .order_by(file_columns[0])
+                .limit(n_rows),
+                chunk_size=6,
+            )
+        )
+        rows = await result.get_all()
+        assert len(rows) == n_rows
+
+    @pytest.mark.asyncio
+    async def test_cast_row_to_output(
+        self,
+        sample_table: DbTable,
+        saved_connection: FILE_CONN_TV,
+        async_connection_executor: AsyncConnExecutorBase,
+        s3_settings: S3Settings,
+        sample_s3_file: str,
+    ) -> None:
+        schema_line = self._get_s3_func_schema_for_table(sample_table)
+        s3_tbl_func = s3_tbl_func_maker(s3_settings)(
+            for_="dba",
+            conn_dto=saved_connection.get_conn_dto(),
+            filename=sample_s3_file,
+            file_fmt="Native",
+            schema_line=schema_line,
+        )
+
+        result = await async_connection_executor.execute(
+            ConnExecutorQuery(
+                sa.select(columns=[sa.literal(1), sa.literal(2), sa.literal(3)])
+                .select_from(sa.text(s3_tbl_func))
+                .limit(1),
+                user_types=[
+                    UserDataType.boolean,
+                    UserDataType.float,
+                    UserDataType.integer,
+                ],
+            )
+        )
+        rows = await result.get_all()
+        assert rows == [(True, 2.0, 3)], rows
diff --git a/lib/dl_connector_clickhouse/dl_connector_clickhouse_tests/db/core/test_connection_executor.py b/lib/dl_connector_clickhouse/dl_connector_clickhouse_tests/db/core/test_connection_executor.py
index 5b106765c..23383676b 100644
--- a/lib/dl_connector_clickhouse/dl_connector_clickhouse_tests/db/core/test_connection_executor.py
+++ b/lib/dl_connector_clickhouse/dl_connector_clickhouse_tests/db/core/test_connection_executor.py
@@ -1,14 +1,40 @@
 import asyncio
+import enum
 import os
-from typing import Optional
+from typing import (
+    Optional,
+    Sequence,
+)
 
+import attr
+from clickhouse_sqlalchemy import types as ch_types
 import pytest
+import sqlalchemy as sa
 
+from dl_constants.enums import (
+    ConnectionType,
+    UserDataType,
+)
 from dl_core.connection_executors import (
     AsyncConnExecutorBase,
+    ConnExecutorQuery,
     SyncConnExecutorBase,
 )
-from dl_core.connection_models.common_models import DBIdent
+from dl_core.connection_models.common_models import (
+    DBIdent,
+    SchemaIdent,
+)
+from dl_core.db.native_type import (
+    ClickHouseDateTime64WithTZNativeType,
+    ClickHouseDateTimeWithTZNativeType,
+    ClickHouseNativeType,
+    GenericNativeType,
+    norm_native_type,
+)
+from dl_core_testing.database import (
+    Db,
+    DbTable,
+)
 from dl_core_testing.testcases.connection_executor import (
     DefaultAsyncConnectionExecutorTestSuite,
     DefaultSyncAsyncConnectionExecutorCheckBase,
@@ -47,7 +73,128 @@ class TestClickHouseSyncConnectionExecutor(
     ClickHouseSyncAsyncConnectionExecutorCheckBase,
     DefaultSyncConnectionExecutorTestSuite[ConnectionClickhouse],
 ):
-    pass
+    @attr.s(frozen=True)
+    class CD(DefaultSyncConnectionExecutorTestSuite.CD):
+        def get_expected_native_type(self, conn_type: ConnectionType) -> GenericNativeType:
+            actual_type = (
+                self.sa_type.nested_type
+                if isinstance(self.sa_type, (ch_types.Nullable, ch_types.LowCardinality))
+                else self.sa_type
+            )
+            return self.nt or ClickHouseNativeType(
+                conn_type=conn_type,
+                name=norm_native_type(self.nt_name if self.nt_name is not None else actual_type),
+                nullable=isinstance(self.sa_type, ch_types.Nullable),  # note: self.nullable is not taken into account
+                lowcardinality=False,
+            )
+
+    def get_schemas_for_type_recognition(self) -> dict[str, Sequence[DefaultSyncConnectionExecutorTestSuite.CD]]:
+        enum_values = ["allchars '\"\t,= etc", "test", "value1"]
+        tst_enum8 = enum.Enum("TstEnum8", enum_values)
+        tst_enum16 = enum.Enum("TstEnum16", enum_values)
+
+        return {
+            "ch_types_numbers": [
+                self.CD(ch_types.Nullable(ch_types.Int8()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.Int16()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.Int32()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.Int64()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.UInt8()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.UInt16()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.UInt32()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.UInt64()), UserDataType.integer),
+                self.CD(ch_types.Nullable(ch_types.Float32()), UserDataType.float),
+                self.CD(ch_types.Nullable(ch_types.Float64()), UserDataType.float),
+                self.CD(ch_types.Decimal(8, 4), UserDataType.float, nt_name="float"),
+            ],
+            "ch_types_string": [
+                self.CD(sa.String(length=256), UserDataType.string),
+                self.CD(ch_types.Nullable(ch_types.String()), UserDataType.string),
+                # Note: `Nullable(LowCardinality(String))` is actually not allowed:
+                # "Nested type LowCardinality(String) cannot be inside Nullable type"
+                self.CD(
+                    ch_types.LowCardinality(ch_types.Nullable(ch_types.String())),
+                    UserDataType.string,
+                    nt=ClickHouseNativeType(
+                        conn_type=self.conn_type,
+                        name="string",
+                        nullable=True,
+                        lowcardinality=True,
+                    ),
+                ),
+                self.CD(
+                    ch_types.LowCardinality(ch_types.String()),
+                    UserDataType.string,
+                    nt=ClickHouseNativeType(
+                        conn_type=self.conn_type,
+                        name="string",
+                        nullable=False,
+                        lowcardinality=True,
+                    ),
+                ),
+            ],
+            "ch_types_date": [
+                # not nullable so we can check 0000-00-00
+                self.CD(ch_types.Date(), UserDataType.date),
+                self.CD(ch_types.Date32(), UserDataType.date, nt_name="date"),
+                # not nullable so we can check 0000-00-00 00:00:00
+                self.CD(
+                    ch_types.DateTime(),
+                    UserDataType.genericdatetime,
+                    nt=ClickHouseDateTimeWithTZNativeType(
+                        conn_type=self.conn_type,
+                        name="datetimewithtz",
+                        nullable=False,
+                        lowcardinality=False,
+                        timezone_name="UTC",  # the CH system timezone
+                        explicit_timezone=False,
+                    ),
+                ),
+                self.CD(
+                    ch_types.DateTimeWithTZ("Europe/Moscow"),
+                    UserDataType.genericdatetime,
+                    nt=ClickHouseDateTimeWithTZNativeType(
+                        conn_type=self.conn_type,
+                        name="datetimewithtz",
+                        nullable=False,
+                        lowcardinality=False,
+                        timezone_name="Europe/Moscow",
+                        explicit_timezone=True,
+                    ),
+                ),
+                self.CD(
+                    ch_types.DateTime64(6),
+                    UserDataType.genericdatetime,
+                    nt=ClickHouseDateTime64WithTZNativeType(
+                        conn_type=self.conn_type,
+                        name="datetime64withtz",
+                        nullable=False,
+                        lowcardinality=False,
+                        precision=6,
+                        timezone_name="UTC",  # the CH system timezone
+                        explicit_timezone=False,
+                    ),
+                ),
+                self.CD(
+                    ch_types.DateTime64WithTZ(6, "Europe/Moscow"),
+                    UserDataType.genericdatetime,
+                    nt=ClickHouseDateTime64WithTZNativeType(
+                        conn_type=self.conn_type,
+                        name="datetime64withtz",
+                        nullable=False,
+                        lowcardinality=False,
+                        precision=6,
+                        timezone_name="Europe/Moscow",
+                        explicit_timezone=True,
+                    ),
+                ),
+            ],
+            "ch_types_other": [
+                self.CD(ch_types.Enum8(tst_enum8), UserDataType.string, nt_name="string"),
+                self.CD(ch_types.Enum16(tst_enum16), UserDataType.string, nt_name="string"),
+                self.CD(ch_types.Nullable(ch_types.Bool()), UserDataType.boolean),
+            ],
+        }
 
 
 class TestClickHouseAsyncConnectionExecutor(
@@ -65,6 +212,18 @@ class TestClickHouseAsyncConnectionExecutor(
         },
     )
 
+    @pytest.mark.asyncio
+    async def test_sa_mod(self, async_connection_executor: AsyncConnExecutorBase) -> None:
+        result = await async_connection_executor.execute(ConnExecutorQuery(sa.select([sa.literal(3) % sa.literal(2)])))
+        rows = await result.get_all()
+        assert rows == [(1,)]
+
+    @pytest.mark.asyncio
+    async def test_inf(self, async_connection_executor: AsyncConnExecutorBase) -> None:
+        result = await async_connection_executor.execute(ConnExecutorQuery("select 1 / 0"))
+        rows = await result.get_all()
+        assert rows == [(None,)]
+
 
 @pytest.mark.skipif(os.environ.get("WE_ARE_IN_CI"), reason="can't use localhost")
 class TestSslClickHouseSyncConnectionExecutor(
diff --git a/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_connection_executor.py b/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_connection_executor.py
index cdd3886c1..a5d81980e 100644
--- a/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_connection_executor.py
+++ b/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_connection_executor.py
@@ -4,19 +4,15 @@
     Optional,
     Sequence,
 )
+import uuid
 
 import pytest
-import shortuuid
-import sqlalchemy as sa
-from sqlalchemy.types import TypeEngine
+from sqlalchemy.dialects import postgresql as pg_types
 
 from dl_constants.enums import UserDataType
 from dl_core.connection_executors import AsyncConnExecutorBase
 from dl_core.connection_executors.sync_base import SyncConnExecutorBase
-from dl_core.connection_models.common_models import (
-    DBIdent,
-    TableIdent,
-)
+from dl_core.connection_models.common_models import DBIdent
 from dl_core_testing.database import Db
 from dl_core_testing.testcases.connection_executor import (
     DefaultAsyncConnectionExecutorTestSuite,
@@ -52,41 +48,12 @@ def check_db_version(self, db_version: Optional[str]) -> None:
         assert db_version is not None
         assert "." in db_version
 
-    def get_schemas_for_type_recognition(self) -> dict[str, Sequence[tuple[TypeEngine, UserDataType]]]:
-        return {
-            "types_postgres": [
-                (sa.Integer(), UserDataType.integer),
-                (sa.Float(), UserDataType.float),
-                (sa.String(length=256), UserDataType.string),
-                (sa.Date(), UserDataType.date),
-                (sa.DateTime(), UserDataType.genericdatetime),
-                (CITEXT(), UserDataType.string),
-            ],
-        }
-
     @pytest.fixture(scope="function")
     def enabled_citext_extension(self, db: Db) -> None:
         db.execute("CREATE EXTENSION IF NOT EXISTS CITEXT;")
 
-    def test_type_recognition(
-        self, db: Db, sync_connection_executor: SyncConnExecutorBase, enabled_citext_extension
-    ) -> None:
-        for schema_name, type_schema in sorted(self.get_schemas_for_type_recognition().items()):
-            columns = [
-                sa.Column(name=f"c_{shortuuid.uuid().lower()}", type_=sa_type) for sa_type, user_type in type_schema
-            ]
-            sa_table = db.table_from_columns(columns=columns)
-            db.create_table(sa_table)
-            table_def = TableIdent(db_name=db.name, schema_name=sa_table.schema, table_name=sa_table.name)
-            detected_columns = sync_connection_executor.get_table_schema_info(table_def=table_def).schema
-            assert len(detected_columns) == len(type_schema), f"Incorrect number of columns in schema {schema_name}"
-            for col_idx, ((_sa_type, user_type), detected_col) in enumerate(zip(type_schema, detected_columns)):
-                assert detected_col.user_type == user_type, (
-                    f"Incorrect user type detected for schema {schema_name} col #{col_idx}: "
-                    f"expected {user_type.name}, got {detected_col.user_type.name}"
-                )
-
 
+@pytest.mark.usefixtures("enabled_citext_extension")
 class TestPostgreSQLSyncConnectionExecutor(
     PostgreSQLSyncAsyncConnectionExecutorCheckBase,
     DefaultSyncConnectionExecutorTestSuite[ConnectionPostgreSQL],
@@ -97,6 +64,33 @@ class TestPostgreSQLSyncConnectionExecutor(
         },
     )
 
+    def get_schemas_for_type_recognition(self) -> dict[str, Sequence[DefaultSyncConnectionExecutorTestSuite.CD]]:
+        return {
+            "types_postgres_number": [
+                self.CD(pg_types.SMALLINT(), UserDataType.integer),
+                self.CD(pg_types.INTEGER(), UserDataType.integer),
+                self.CD(pg_types.BIGINT(), UserDataType.integer),
+                self.CD(pg_types.REAL(), UserDataType.float),
+                self.CD(pg_types.DOUBLE_PRECISION(), UserDataType.float),
+                self.CD(pg_types.NUMERIC(), UserDataType.float),
+            ],
+            "types_postgres_string": [
+                self.CD(pg_types.CHAR(), UserDataType.string),
+                self.CD(pg_types.VARCHAR(100), UserDataType.string),
+                self.CD(pg_types.TEXT(), UserDataType.string),
+                self.CD(CITEXT(), UserDataType.string),
+            ],
+            "types_postgres_date": [
+                self.CD(pg_types.DATE(), UserDataType.date),
+                self.CD(pg_types.TIMESTAMP(timezone=False), UserDataType.genericdatetime),
+                self.CD(pg_types.TIMESTAMP(timezone=True), UserDataType.genericdatetime),
+            ],
+            "types_postgres_other": [
+                self.CD(pg_types.BOOLEAN(), UserDataType.boolean),
+                self.CD(pg_types.ENUM("var1", "var2", name=str(uuid.uuid4())), UserDataType.string),
+            ],
+        }
+
 
 class TestPostgreSQLAsyncConnectionExecutor(
     PostgreSQLSyncAsyncConnectionExecutorCheckBase,
diff --git a/lib/dl_core_testing/dl_core_testing/testcases/connection_executor.py b/lib/dl_core_testing/dl_core_testing/testcases/connection_executor.py
index 9e6361749..6702ed4fb 100644
--- a/lib/dl_core_testing/dl_core_testing/testcases/connection_executor.py
+++ b/lib/dl_core_testing/dl_core_testing/testcases/connection_executor.py
@@ -2,10 +2,12 @@
 
 import asyncio
 import contextlib
+import functools
 from typing import (
     TYPE_CHECKING,
     AsyncGenerator,
     Callable,
+    ClassVar,
     Generator,
     Generic,
     Optional,
@@ -13,17 +15,28 @@
     TypeVar,
 )
 
+import attr
 import pytest
 import shortuuid
 import sqlalchemy as sa
 from sqlalchemy.types import TypeEngine
 
-from dl_constants.enums import UserDataType
+from dl_constants.enums import (
+    ConnectionType,
+    UserDataType,
+)
 from dl_core.connection_executors.common_base import ConnExecutorQuery
 from dl_core.connection_models.common_models import (
     DBIdent,
+    SATextTableDefinition,
+    SchemaIdent,
     TableIdent,
 )
+from dl_core.db.native_type import (
+    CommonNativeType,
+    GenericNativeType,
+    norm_native_type,
+)
 import dl_core.exc as core_exc
 from dl_core.us_connection_base import ConnectionBase
 from dl_core_testing.database import (
@@ -84,7 +97,9 @@ def db_table_columns(self, db: Db) -> list[C]:
 
     @pytest.fixture(scope="function")
     def db_table(self, db: Db, db_table_columns: list[C]) -> DbTable:
-        return make_table(db, columns=db_table_columns)
+        db_table = make_table(db, columns=db_table_columns)
+        yield db_table
+        db.drop_table(db_table.table)
 
     @pytest.fixture(scope="function")
     def nonexistent_table_ident(self, existing_table_ident: TableIdent) -> TableIdent:
@@ -117,6 +132,22 @@ def test_get_db_version(self, sync_connection_executor: SyncConnExecutorBase, db
     def test_test(self, sync_connection_executor: SyncConnExecutorBase) -> None:
         sync_connection_executor.test()
 
+    def test_get_table_names(
+        self,
+        sample_table: DbTable,
+        db: Db,
+        sync_connection_executor: SyncConnExecutorBase,
+    ):
+        # at the moment, checks that sample table is listed among the others
+
+        tables = [sample_table]
+        expected_table_names = set(table.name for table in tables)
+
+        actual_tables = sync_connection_executor.get_tables(SchemaIdent(db_name=db.name, schema_name=None))
+        actual_table_names = [tid.table_name for tid in actual_tables]
+
+        assert set(actual_table_names).issuperset(expected_table_names)
+
     def test_table_exists(
         self,
         sync_connection_executor: SyncConnExecutorBase,
@@ -132,33 +163,57 @@ def test_table_not_exists(
     ) -> None:
         assert not sync_connection_executor.is_table_exists(nonexistent_table_ident)
 
-    def get_schemas_for_type_recognition(self) -> dict[str, Sequence[tuple[TypeEngine, UserDataType]]]:
+    @attr.s(frozen=True)
+    class CD:
+        sa_type: TypeEngine = attr.ib()
+        # Expected data
+        user_type: UserDataType = attr.ib()
+        nullable: bool = attr.ib(default=True)
+        nt_name: Optional[str] = attr.ib(default=None)
+        nt: Optional[GenericNativeType] = attr.ib(default=None)
+
+        def get_expected_native_type(self, conn_type: ConnectionType) -> GenericNativeType:
+            return self.nt or CommonNativeType(
+                conn_type=conn_type,
+                name=norm_native_type(self.nt_name if self.nt_name is not None else self.sa_type),
+                nullable=self.nullable,
+            )
+
+    def get_schemas_for_type_recognition(self) -> dict[str, Sequence[CD]]:
         return {
             "standard_types": [
-                (sa.Integer(), UserDataType.integer),
-                (sa.Float(), UserDataType.float),
-                (sa.String(length=256), UserDataType.string),
-                (sa.Date(), UserDataType.date),
-                (sa.DateTime(), UserDataType.genericdatetime),
+                self.CD(sa.Integer(), UserDataType.integer),
+                self.CD(sa.Float(), UserDataType.float),
+                self.CD(sa.String(length=256), UserDataType.string),
+                self.CD(sa.Date(), UserDataType.date),
+                self.CD(sa.DateTime(), UserDataType.genericdatetime),
             ],
         }
 
-    def test_type_recognition(self, db: Db, sync_connection_executor: SyncConnExecutorBase) -> None:
-        for schema_name, type_schema in sorted(self.get_schemas_for_type_recognition().items()):
+    def test_type_recognition(self, request, db: Db, sync_connection_executor: SyncConnExecutorBase) -> None:
+        for schema_name, type_schema in self.get_schemas_for_type_recognition().items():
             columns = [
-                sa.Column(name=f"c_{shortuuid.uuid().lower()}", type_=sa_type) for sa_type, user_type in type_schema
+                sa.Column(name=f"c_{shortuuid.uuid().lower()}", type_=column_data.sa_type)
+                for column_data in type_schema
             ]
             sa_table = db.table_from_columns(columns=columns)
+
             db.create_table(sa_table)
+            request.addfinalizer(functools.partial(db.drop_table, sa_table))
+
             detected_columns = sync_connection_executor.get_table_schema_info(
                 table_def=TableIdent(db_name=db.name, schema_name=sa_table.schema, table_name=sa_table.name)
             ).schema
             assert len(detected_columns) == len(type_schema), f"Incorrect number of columns in schema {schema_name}"
-            for col_idx, ((_sa_type, user_type), detected_col) in enumerate(zip(type_schema, detected_columns)):
-                assert detected_col.user_type == user_type, (
+            for col_idx, (expected_col, detected_col) in enumerate(zip(type_schema, detected_columns)):
+                assert detected_col.user_type == expected_col.user_type, (
                     f"Incorrect user type detected for schema {schema_name} col #{col_idx}: "
-                    f"expected {user_type.name}, got {detected_col.user_type.name}"
+                    f"expected {expected_col.user_type.name}, got {detected_col.user_type.name}"
                 )
+                expected_native_type = expected_col.get_expected_native_type(self.conn_type)
+                assert (
+                    detected_col.native_type == expected_native_type
+                ), f"Incorrect native type detected for schema {schema_name} col #{col_idx}: expected {repr(expected_native_type)}, got {repr(detected_col.native_type)}"
 
     def test_simple_select(self, sync_connection_executor: SyncConnExecutorBase) -> None:
         query = ConnExecutorQuery(query=sa.select([sa.literal(1)]))
@@ -183,15 +238,25 @@ def test_closing_sql_sessions(
         for _i in range(5):
             sync_connection_executor.execute(ConnExecutorQuery(query=query_for_session_check))
 
+    subselect_query_for_schema_test: ClassVar[str] = "(SELECT 1 AS num) AS source"
+
+    @pytest.mark.parametrize("case", ["table", "subselect"])
     def test_get_table_schema_info(
         self,
+        case: str,
         sync_connection_executor: SyncConnExecutorBase,
         existing_table_ident: TableIdent,
         db_table: DbTable,
     ) -> None:
         # Just tests that the adapter can successfully retrieve the schema.
         # Data source tests check this in more detail
-        detected_columns = sync_connection_executor.get_table_schema_info(table_def=existing_table_ident).schema
+        table_def = {
+            "table": existing_table_ident,
+            "subselect": SATextTableDefinition(
+                text=sa.sql.elements.TextClause(self.subselect_query_for_schema_test),
+            ),
+        }[case]
+        detected_columns = sync_connection_executor.get_table_schema_info(table_def=table_def).schema
         assert len(detected_columns) > 0
 
     def test_get_table_schema_info_for_nonexistent_table(
@@ -239,6 +304,40 @@ async def test_simple_select(self, async_connection_executor: AsyncConnExecutorB
         assert len(result) == 1
         assert result[0] == (1,)
 
+    async def test_select_data(self, sample_table: DbTable, async_connection_executor: AsyncConnExecutorBase) -> None:
+        n_rows = 3
+        result = await async_connection_executor.execute(
+            ConnExecutorQuery(
+                query=sa.select(columns=sample_table.table.columns)
+                .select_from(sample_table.table)
+                .order_by(sample_table.table.columns[0])
+                .limit(n_rows),
+                chunk_size=6,
+            )
+        )
+        rows = await result.get_all()
+        assert len(rows) == n_rows
+
+    async def test_cast_row_to_output(
+        self,
+        sample_table: DbTable,
+        async_connection_executor: AsyncConnExecutorBase,
+    ) -> None:
+        result = await async_connection_executor.execute(
+            ConnExecutorQuery(
+                sa.select(columns=[sa.literal(1), sa.literal(2), sa.literal(3)])
+                .select_from(sample_table.table)
+                .limit(1),
+                user_types=[
+                    UserDataType.boolean,
+                    UserDataType.float,
+                    UserDataType.integer,
+                ],
+            )
+        )
+        rows = await result.get_all()
+        assert rows == [(True, 2.0, 3)], rows
+
     async def test_error_on_select_from_nonexistent_source(
         self,
         db: Db,